shipyard/src/bin/shipyard_airflow/shipyard_airflow/plugins/concurrency_check_operator.py

151 lines
5.7 KiB
Python

# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.plugins_manager import AirflowPlugin
from airflow.hooks.postgres_hook import PostgresHook
from airflow.exceptions import AirflowException
from shipyard_airflow.plugins.xcom_pusher import XcomPusher
# constants related to the dag_run table.
DAG_RUN_SELECT_RUNNING_SQL = ("select dag_id, execution_date "
"from dag_run "
"where state='running'")
# connection name for airflow's own sql db
AIRFLOW_DB = 'airflows_own_db'
# each set in this list of sets indicates DAGs that shouldn't execute together
CONFLICTING_DAG_SETS = [set(['deploy_site', 'update_site', 'update_software',
'redeploy_server', 'relabel_nodes'])]
def find_conflicting_dag_set(dag_name, conflicting_dag_sets=None):
"""
Using dag_name, finds all other dag names that are in any set of
conflicting dags from input conflicting_dag_sets
"""
if conflicting_dag_sets is None:
conflicting_dag_sets = CONFLICTING_DAG_SETS
full_set = set()
for single_set in conflicting_dag_sets:
if dag_name in single_set:
full_set = full_set | single_set
full_set.discard(dag_name)
logging.info('Potential conflicts: %s', ', '.join(full_set))
return full_set
class ConcurrencyCheckOperator(BaseOperator):
"""
Provides a way to indicate which DAGs should not be executing
simultaneously.
"""
@apply_defaults
def __init__(self, conflicting_dag_set=None, *args, **kwargs):
super(ConcurrencyCheckOperator, self).__init__(*args, **kwargs)
self.conflicting_dag_set = conflicting_dag_set
self.xcom_push = None
def execute(self, context):
"""
Run the check to see if this DAG has an concurrency issues with other
DAGs. Stop the workflow if there is.
"""
self.xcom_push = XcomPusher(context['task_instance'])
self._xcom_push_status(False)
if self.conflicting_dag_set is None:
self.check_dag_id = self.dag.dag_id
logging.debug('dag_id is %s', self.check_dag_id)
if '.' in self.dag.dag_id:
self.check_dag_id = self.dag.dag_id.split('.', 1)[0]
logging.debug('dag_id modified to %s', self.check_dag_id)
logging.info('from dag %s, assuming %s for concurrency check',
self.dag.dag_id, self.check_dag_id)
self.conflicting_dag_set = find_conflicting_dag_set(
self.check_dag_id)
logging.info('Checking for running of dags: %s',
', '.join(self.conflicting_dag_set))
conflicting_dag = self.find_conflicting_dag(self.check_dag_id)
if conflicting_dag is None:
logging.info('No conflicts found. Continuing Execution')
else:
self.abort_conflict(
dag_name=self.check_dag_id, conflict=conflicting_dag)
self._xcom_push_status(True)
def get_executing_dags(self):
"""
Encapsulation of getting database records of running dags.
Returns a list of records of dag_id and execution_date
"""
logging.info('Executing: %s', DAG_RUN_SELECT_RUNNING_SQL)
airflow_pg_hook = PostgresHook(postgres_conn_id=AIRFLOW_DB)
return airflow_pg_hook.get_records(DAG_RUN_SELECT_RUNNING_SQL)
def find_conflicting_dag(self, dag_id_to_check):
"""
Checks for a DAGs that is conflicting and exits based on the first
one found.
Also will return the dag_id_to_check as conflicting if more than 1
instance is running
"""
self_dag_count = 0
for dag_id, execution_date in self.get_executing_dags():
logging.info('Checking %s @ %s vs. current %s', dag_id,
execution_date, dag_id_to_check)
if dag_id == dag_id_to_check:
self_dag_count += 1
logging.info(
"Found an instance of the dag_id being checked. Tally: %s",
self_dag_count)
if dag_id in self.conflicting_dag_set:
logging.info("Conflict found: %s @ %s", dag_id, execution_date)
return dag_id
if self_dag_count > 1:
return dag_id_to_check
return None
def abort_conflict(self, dag_name, conflict):
"""
Log and raise an exception that there is a conflicting workflow.
"""
conflict_string = '{} conflicts with running {}. Aborting run'.format(
dag_name, conflict)
logging.error(conflict_string)
raise AirflowException(conflict_string)
def _xcom_push_status(self, status):
"""
Push the status of the concurrency check
:param status: bool of whether or not this task is successful
:return:
"""
self.xcom_push.xcom_push(key="concurrency_check_success", value=status)
class ConcurrencyCheckPlugin(AirflowPlugin):
"""
Register this plugin for this operator.
"""
name = 'concurrency_check_operator_plugin'
operators = [ConcurrencyCheckOperator]