You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
|
|
from airflow.models import BaseOperator
|
|
from airflow.utils.decorators import apply_defaults
|
|
from airflow.plugins_manager import AirflowPlugin
|
|
from airflow.hooks.postgres_hook import PostgresHook
|
|
from airflow.exceptions import AirflowException
|
|
|
|
from shipyard_airflow.plugins.xcom_pusher import XcomPusher
|
|
|
|
# Query against airflow's dag_run table returning the dag_id and
# execution_date of every DAG run whose state is 'running'.
DAG_RUN_SELECT_RUNNING_SQL = (
    "select dag_id, execution_date "
    "from dag_run "
    "where state='running'"
)

# Name of the airflow connection that points at airflow's own sql db.
AIRFLOW_DB = 'airflows_own_db'

# Each set in this list names DAGs that must not execute at the same time
# as any other member of that set.
CONFLICTING_DAG_SETS = [
    {
        'deploy_site',
        'update_site',
        'update_software',
        'redeploy_server',
        'relabel_nodes',
    },
]
|
|
|
|
|
|
def find_conflicting_dag_set(dag_name, conflicting_dag_sets=None):
    """
    Collect the names of all DAGs that share a conflict set with dag_name.

    :param dag_name: name of the DAG whose potential conflicts are wanted
    :param conflicting_dag_sets: iterable of sets of DAG names that must not
        run together. CONFLICTING_DAG_SETS is used when this is None.
    :return: a new set containing the union of every input set that
        includes dag_name, with dag_name itself removed. Empty when no
        input set mentions dag_name.
    """
    candidate_sets = (CONFLICTING_DAG_SETS if conflicting_dag_sets is None
                      else conflicting_dag_sets)

    conflicts = set()
    for group in candidate_sets:
        if dag_name not in group:
            # this group says nothing about dag_name
            continue
        conflicts |= group

    # A DAG is not reported as a potential conflict of itself here.
    conflicts.discard(dag_name)
    logging.info('Potential conflicts: %s', ', '.join(conflicts))
    return conflicts
|
|
|
|
|
|
class ConcurrencyCheckOperator(BaseOperator):
    """
    Provides a way to indicate which DAGs should not be executing
    simultaneously.

    :param conflicting_dag_set: optional collection of dag_ids that must not
        be running while this operator's DAG runs. When None, the set is
        looked up with find_conflicting_dag_set during execute.
    """

    @apply_defaults
    def __init__(self, conflicting_dag_set=None, *args, **kwargs):
        super(ConcurrencyCheckOperator, self).__init__(*args, **kwargs)
        self.conflicting_dag_set = conflicting_dag_set
        # Replaced by an XcomPusher bound to the task instance in execute.
        self.xcom_push = None

    def execute(self, context):
        """
        Run the check to see if this DAG has any concurrency issues with
        other DAGs. Stop the workflow if there are any.

        Pushes concurrency_check_success=False to xcom before checking, and
        True once the check has passed.

        :param context: task context; context['task_instance'] is used to
            push xcom values
        :raises AirflowException: when a conflicting DAG is running
        """
        self.xcom_push = XcomPusher(context['task_instance'])
        self._xcom_push_status(False)

        # The dag_id to check is resolved on every run, regardless of
        # whether a conflicting_dag_set was supplied. It is needed below in
        # both cases; previously it was only set when the conflict set had
        # to be looked up, so an explicit conflicting_dag_set led to an
        # AttributeError.
        self.check_dag_id = self.dag.dag_id
        logging.debug('dag_id is %s', self.check_dag_id)
        if '.' in self.dag.dag_id:
            # Only the part before the first '.' is used (presumably the
            # parent DAG's id when running as a subdag - confirm).
            self.check_dag_id = self.dag.dag_id.split('.', 1)[0]
            logging.debug('dag_id modified to %s', self.check_dag_id)

        if self.conflicting_dag_set is None:
            logging.info('from dag %s, assuming %s for concurrency check',
                         self.dag.dag_id, self.check_dag_id)
            self.conflicting_dag_set = find_conflicting_dag_set(
                self.check_dag_id)

        logging.info('Checking for running of dags: %s',
                     ', '.join(self.conflicting_dag_set))

        conflicting_dag = self.find_conflicting_dag(self.check_dag_id)
        if conflicting_dag is None:
            logging.info('No conflicts found. Continuing Execution')
        else:
            self.abort_conflict(
                dag_name=self.check_dag_id, conflict=conflicting_dag)
        self._xcom_push_status(True)

    def get_executing_dags(self):
        """
        Encapsulation of getting database records of running dags.

        Runs DAG_RUN_SELECT_RUNNING_SQL through a PostgresHook on the
        AIRFLOW_DB connection.

        :return: a list of records of dag_id and execution_date
        """
        logging.info('Executing: %s', DAG_RUN_SELECT_RUNNING_SQL)
        airflow_pg_hook = PostgresHook(postgres_conn_id=AIRFLOW_DB)
        return airflow_pg_hook.get_records(DAG_RUN_SELECT_RUNNING_SQL)

    def find_conflicting_dag(self, dag_id_to_check):
        """
        Checks for a DAG that is conflicting and exits based on the first
        one found.

        Also will return the dag_id_to_check as conflicting if more than 1
        instance is running.

        :param dag_id_to_check: dag_id of the DAG performing the check
        :return: the dag_id of the conflicting DAG, or None if there is no
            conflict
        """
        self_dag_count = 0
        for dag_id, execution_date in self.get_executing_dags():
            logging.info('Checking %s @ %s vs. current %s', dag_id,
                         execution_date, dag_id_to_check)
            if dag_id == dag_id_to_check:
                self_dag_count += 1
                logging.info(
                    "Found an instance of the dag_id being checked. Tally: %s",
                    self_dag_count)
            if dag_id in self.conflicting_dag_set:
                logging.info("Conflict found: %s @ %s", dag_id, execution_date)
                return dag_id
        # Only reached when no member of conflicting_dag_set is running.
        # More than one running instance of the checked DAG is itself
        # reported as a conflict.
        if self_dag_count > 1:
            return dag_id_to_check
        return None

    def abort_conflict(self, dag_name, conflict):
        """
        Log and raise an exception that there is a conflicting workflow.

        :param dag_name: dag_id of the DAG that was being checked
        :param conflict: dag_id of the running DAG it conflicts with
        :raises AirflowException: always
        """
        conflict_string = '{} conflicts with running {}. Aborting run'.format(
            dag_name, conflict)
        logging.error(conflict_string)
        raise AirflowException(conflict_string)

    def _xcom_push_status(self, status):
        """
        Push the status of the concurrency check

        :param status: bool of whether or not this task is successful
        :return: None
        """
        self.xcom_push.xcom_push(key="concurrency_check_success", value=status)
|
|
|
|
|
|
class ConcurrencyCheckPlugin(AirflowPlugin):
    """
    Airflow plugin definition that registers ConcurrencyCheckOperator.
    """
    # operators contributed by this plugin
    operators = [ConcurrencyCheckOperator]
    # name under which the plugin is registered
    name = 'concurrency_check_operator_plugin'
|