From e24c2504e0d027629f74390c1764823139f9b262 Mon Sep 17 00:00:00 2001
From: Bryan Strassner
Date: Mon, 21 Aug 2017 18:53:03 -0500
Subject: [PATCH] Add update_site dag, concurrency check

The concurrency check uses the postgres db to look at the dag_run table
to determine what is currently running.

update_site and deploy_site have been added/refactored to be a linear
flow, with a function serving as the error handler for each step. This
was done to prevent runs of downstream error-handling steps when their
parent was set to upstream_failed status.

Added a simple test for the failure handler.
Added testing for the concurrency check logic.
Added requirements for testing purposes.

Change-Id: I29a2509df999b4714a9ade2ebabac7522504e24e
---
 .../dags/dag_concurrency_check.py             |  26 +---
 shipyard_airflow/dags/deploy_site.py          |  83 +++--------
 shipyard_airflow/dags/failure_handlers.py     |  21 +++
 shipyard_airflow/dags/preflight_checks.py     |  14 --
 shipyard_airflow/dags/update_site.py          |  94 ++++++++++++
 shipyard_airflow/dags/validate_site_design.py |  16 ---
 .../plugins/concurrency_check_operator.py     | 135 ++++++++++++++++++
 test-requirements.txt                         |   3 +
 tests/unit/dags/__init__.py                   |   0
 tests/unit/dags/test_failure_handlers.py      |  27 ++++
 tests/unit/plugins/__init__.py                |   0
 .../test_concurrency_check_operator.py        | 132 +++++++++++++
 12 files changed, 434 insertions(+), 117 deletions(-)
 create mode 100644 shipyard_airflow/dags/failure_handlers.py
 create mode 100644 shipyard_airflow/dags/update_site.py
 create mode 100644 shipyard_airflow/plugins/concurrency_check_operator.py
 create mode 100644 tests/unit/dags/__init__.py
 create mode 100644 tests/unit/dags/test_failure_handlers.py
 create mode 100644 tests/unit/plugins/__init__.py
 create mode 100644 tests/unit/plugins/test_concurrency_check_operator.py

diff --git a/shipyard_airflow/dags/dag_concurrency_check.py b/shipyard_airflow/dags/dag_concurrency_check.py
index 744eb1c3..33c8c658 100644
--- a/shipyard_airflow/dags/dag_concurrency_check.py
+++ b/shipyard_airflow/dags/dag_concurrency_check.py
@@ -13,8 +13,7 @@
 # limitations under the License.
 from airflow.models import DAG
-from airflow.operators import PlaceholderOperator
-from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators import ConcurrencyCheckOperator
 
 
 def dag_concurrency_check(parent_dag_name, child_dag_name, args):
@@ -26,28 +25,7 @@ def dag_concurrency_check(parent_dag_name, child_dag_name, args):
         '{}.{}'.format(parent_dag_name, child_dag_name),
         default_args=args, )
-    # TODO () Replace this operator with a real operator that will:
-    # 1) Look for an instance of the parent_dag_name running currently in
-    #    airflow
-    # 2) Fail if the parent_dag_name is running
-    # 3) Succeed if there are no instances of parent_dag_name running
-    dag_concurrency_check_operator = PlaceholderOperator(
+    dag_concurrency_check_operator = ConcurrencyCheckOperator(
         task_id='dag_concurrency_check', dag=dag)
 
     return dag
-
-
-def dag_concurrency_check_failure_handler(parent_dag_name, child_dag_name,
-                                          args):
-    '''
-    Peforms the actions necessary when concurrency checks fail
-    '''
-    dag = DAG(
-        '{}.{}'.format(parent_dag_name, child_dag_name),
-        default_args=args, )
-
-    operator = DummyOperator(
-        task_id='dag_concurrency_check_failure_handler',
-        dag=dag, )
-
-    return dag
diff --git a/shipyard_airflow/dags/deploy_site.py b/shipyard_airflow/dags/deploy_site.py
index 53bb97a3..19edc262 100644
--- a/shipyard_airflow/dags/deploy_site.py
+++ b/shipyard_airflow/dags/deploy_site.py
@@ -11,35 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from datetime import timedelta
 
 import airflow
 from airflow import DAG
-from dag_concurrency_check import dag_concurrency_check
-from dag_concurrency_check import dag_concurrency_check_failure_handler
+import failure_handlers
 from preflight_checks import all_preflight_checks
-from preflight_checks import preflight_failure_handler
 from validate_site_design import validate_site_design
-from validate_site_design import validate_site_design_failure_handler
 from airflow.operators.subdag_operator import SubDagOperator
+from airflow.operators import ConcurrencyCheckOperator
 from airflow.operators import DeckhandOperator
 from airflow.operators import PlaceholderOperator
-from airflow.utils.trigger_rule import TriggerRule
-'''
+"""
 deploy_site is the top-level orchestration DAG for deploying a site using the
 Undercloud platform.
-'''
+"""
 
 PARENT_DAG_NAME = 'deploy_site'
 DAG_CONCURRENCY_CHECK_DAG_NAME = 'dag_concurrency_check'
-CONCURRENCY_FAILURE_DAG_NAME = 'concurrency_check_failure_handler'
 ALL_PREFLIGHT_CHECKS_DAG_NAME = 'preflight'
-PREFLIGHT_FAILURE_DAG_NAME = 'preflight_failure_handler'
 DECKHAND_GET_DESIGN_VERSION = 'deckhand_get_design_version'
 VALIDATE_SITE_DESIGN_DAG_NAME = 'validate_site_design'
-VALIDATION_FAILED_DAG_NAME = 'validate_site_design_failure_handler'
-DECKHAND_MARK_LAST_KNOWN_GOOD = 'deckhand_mark_last_known_good'
 
 default_args = {
     'owner': 'airflow',
@@ -54,84 +46,49 @@ default_args = {
 
 dag = DAG(PARENT_DAG_NAME, default_args=default_args, schedule_interval=None)
 
-concurrency_check = SubDagOperator(
-    subdag=dag_concurrency_check(
-        PARENT_DAG_NAME, DAG_CONCURRENCY_CHECK_DAG_NAME, args=default_args),
+concurrency_check = ConcurrencyCheckOperator(
     task_id=DAG_CONCURRENCY_CHECK_DAG_NAME,
-    dag=dag, )
-
-concurrency_check_failure_handler = SubDagOperator(
-    subdag=dag_concurrency_check_failure_handler(
-        PARENT_DAG_NAME, CONCURRENCY_FAILURE_DAG_NAME, args=default_args),
-    task_id=CONCURRENCY_FAILURE_DAG_NAME,
-    trigger_rule=TriggerRule.ONE_FAILED,
-    dag=dag, )
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
 
 preflight = SubDagOperator(
     subdag=all_preflight_checks(
         PARENT_DAG_NAME, ALL_PREFLIGHT_CHECKS_DAG_NAME, args=default_args),
     task_id=ALL_PREFLIGHT_CHECKS_DAG_NAME,
-    dag=dag, )
-
-preflight_failure = SubDagOperator(
-    subdag=preflight_failure_handler(
-        PARENT_DAG_NAME, PREFLIGHT_FAILURE_DAG_NAME, args=default_args),
-    task_id=PREFLIGHT_FAILURE_DAG_NAME,
-    trigger_rule=TriggerRule.ONE_FAILED,
+    on_failure_callback=failure_handlers.step_failure_handler,
     dag=dag, )
 
 get_design_version = DeckhandOperator(
-    task_id=DECKHAND_GET_DESIGN_VERSION, dag=dag)
+    task_id=DECKHAND_GET_DESIGN_VERSION,
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
 
 validate_site_design = SubDagOperator(
     subdag=validate_site_design(
         PARENT_DAG_NAME, VALIDATE_SITE_DESIGN_DAG_NAME, args=default_args),
     task_id=VALIDATE_SITE_DESIGN_DAG_NAME,
+    on_failure_callback=failure_handlers.step_failure_handler,
     dag=dag)
 
-validate_site_design_failure = SubDagOperator(
-    subdag=validate_site_design_failure_handler(
-        dag.dag_id, VALIDATION_FAILED_DAG_NAME, args=default_args),
-    task_id=VALIDATION_FAILED_DAG_NAME,
-    trigger_rule=TriggerRule.ONE_FAILED,
-    dag=dag)
-
-drydock_build = PlaceholderOperator(task_id='drydock_build', dag=dag)
-
-drydock_failure_handler = PlaceholderOperator(
-    task_id='drydock_failure_handler',
-    trigger_rule=TriggerRule.ONE_FAILED,
+drydock_build = PlaceholderOperator(
+    task_id='drydock_build',
+    on_failure_callback=failure_handlers.step_failure_handler,
     dag=dag)
 
 query_node_status = PlaceholderOperator(
-    task_id='deployed_node_status', dag=dag)
-
-nodes_not_healthy = PlaceholderOperator(
-    task_id='deployed_nodes_not_healthy',
-    trigger_rule=TriggerRule.ONE_FAILED,
+    task_id='deployed_node_status',
+    on_failure_callback=failure_handlers.step_failure_handler,
     dag=dag)
 
-armada_build = PlaceholderOperator(task_id='armada_build', dag=dag)
-
-armada_failure_handler = PlaceholderOperator(
-    task_id='armada_failure_handler',
-    trigger_rule=TriggerRule.ONE_FAILED,
+armada_build = PlaceholderOperator(
+    task_id='armada_build',
+    on_failure_callback=failure_handlers.step_failure_handler,
     dag=dag)
 
-mark_last_known_good = DeckhandOperator(
-    task_id=DECKHAND_MARK_LAST_KNOWN_GOOD, dag=dag)
-
 # DAG Wiring
-concurrency_check_failure_handler.set_upstream(concurrency_check)
 preflight.set_upstream(concurrency_check)
-preflight_failure.set_upstream(preflight)
 get_design_version.set_upstream(preflight)
 validate_site_design.set_upstream(get_design_version)
-validate_site_design_failure.set_upstream(validate_site_design)
 drydock_build.set_upstream(validate_site_design)
-drydock_failure_handler.set_upstream(drydock_build)
 query_node_status.set_upstream(drydock_build)
-nodes_not_healthy.set_upstream(query_node_status)
 armada_build.set_upstream(query_node_status)
-armada_failure_handler.set_upstream(armada_build)
-mark_last_known_good.set_upstream(armada_build)
diff --git a/shipyard_airflow/dags/failure_handlers.py b/shipyard_airflow/dags/failure_handlers.py
new file mode 100644
index 00000000..f1355fe9
--- /dev/null
+++ b/shipyard_airflow/dags/failure_handlers.py
@@ -0,0 +1,21 @@
+# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+
+def step_failure_handler(context):
+    """
+    Callable used to handle failure of this step.
+    """
+    logging.info('%s step failed', context['task_instance'].task_id)
diff --git a/shipyard_airflow/dags/preflight_checks.py b/shipyard_airflow/dags/preflight_checks.py
index d076e630..80d5d0f8 100644
--- a/shipyard_airflow/dags/preflight_checks.py
+++ b/shipyard_airflow/dags/preflight_checks.py
@@ -14,7 +14,6 @@
 from airflow.models import DAG
 from airflow.operators.subdag_operator import SubDagOperator
-from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators import PlaceholderOperator
 
 
@@ -150,16 +149,3 @@ def all_preflight_checks(parent_dag_name, child_dag_name, args):
         dag=dag, )
 
     return dag
-
-
-def preflight_failure_handler(parent_dag_name, child_dag_name, args):
-    '''
-    Peforms the actions necessary when preflight checks fail
-    '''
-    dag = DAG(
-        '{}.{}'.format(parent_dag_name, child_dag_name),
-        default_args=args, )
-
-    operator = DummyOperator(task_id='preflight_failure_handler', dag=dag)
-
-    return dag
diff --git a/shipyard_airflow/dags/update_site.py b/shipyard_airflow/dags/update_site.py
new file mode 100644
index 00000000..2452c417
--- /dev/null
+++ b/shipyard_airflow/dags/update_site.py
@@ -0,0 +1,94 @@
+# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import timedelta
+
+import airflow
+from airflow import DAG
+import failure_handlers
+from preflight_checks import all_preflight_checks
+from validate_site_design import validate_site_design
+from airflow.operators.subdag_operator import SubDagOperator
+from airflow.operators import ConcurrencyCheckOperator
+from airflow.operators import DeckhandOperator
+from airflow.operators import PlaceholderOperator
+"""
+update_site is the top-level orchestration DAG for updating a site using the
+Undercloud platform.
+"""
+
+PARENT_DAG_NAME = 'update_site'
+DAG_CONCURRENCY_CHECK_DAG_NAME = 'dag_concurrency_check'
+ALL_PREFLIGHT_CHECKS_DAG_NAME = 'preflight'
+DECKHAND_GET_DESIGN_VERSION = 'deckhand_get_design_version'
+VALIDATE_SITE_DESIGN_DAG_NAME = 'validate_site_design'
+
+default_args = {
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'start_date': airflow.utils.dates.days_ago(1),
+    'email': [''],
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'retries': 0,
+    'retry_delay': timedelta(minutes=1),
+}
+
+dag = DAG(PARENT_DAG_NAME, default_args=default_args, schedule_interval=None)
+
+concurrency_check = ConcurrencyCheckOperator(
+    task_id=DAG_CONCURRENCY_CHECK_DAG_NAME,
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
+
+preflight = SubDagOperator(
+    subdag=all_preflight_checks(
+        PARENT_DAG_NAME, ALL_PREFLIGHT_CHECKS_DAG_NAME, args=default_args),
+    task_id=ALL_PREFLIGHT_CHECKS_DAG_NAME,
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag, )
+
+get_design_version = DeckhandOperator(
+    task_id=DECKHAND_GET_DESIGN_VERSION,
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
+
+validate_site_design = SubDagOperator(
+    subdag=validate_site_design(
+        PARENT_DAG_NAME, VALIDATE_SITE_DESIGN_DAG_NAME, args=default_args),
+    task_id=VALIDATE_SITE_DESIGN_DAG_NAME,
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
+
+drydock_build = PlaceholderOperator(
+    task_id='drydock_build',
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
+
+query_node_status = PlaceholderOperator(
+    task_id='deployed_node_status',
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
+
+armada_build = PlaceholderOperator(
+    task_id='armada_build',
+    on_failure_callback=failure_handlers.step_failure_handler,
+    dag=dag)
+
+# DAG Wiring
+preflight.set_upstream(concurrency_check)
+get_design_version.set_upstream(preflight)
+validate_site_design.set_upstream(get_design_version)
+drydock_build.set_upstream(validate_site_design)
+query_node_status.set_upstream(drydock_build)
+armada_build.set_upstream(query_node_status)
diff --git a/shipyard_airflow/dags/validate_site_design.py b/shipyard_airflow/dags/validate_site_design.py
index 0e1efaeb..b7590a11 100644
--- a/shipyard_airflow/dags/validate_site_design.py
+++ b/shipyard_airflow/dags/validate_site_design.py
@@ -13,26 +13,10 @@
 # limitations under the License.
 from airflow.models import DAG
-from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators import DeckhandOperator
 from airflow.operators import PlaceholderOperator
 
 
-def validate_site_design_failure_handler(parent_dag_name, child_dag_name,
-                                         args):
-    '''
-    Peforms the actions necessary when any of the site design checks fail
-    '''
-    dag = DAG(
-        '{}.{}'.format(parent_dag_name, child_dag_name),
-        default_args=args, )
-
-    operator = DummyOperator(
-        task_id='site_design_validation_failure_handler', dag=dag)
-
-    return dag
-
-
 def validate_site_design(parent_dag_name, child_dag_name, args):
     '''
     Subdag to delegate design verification to the UCP components
diff --git a/shipyard_airflow/plugins/concurrency_check_operator.py b/shipyard_airflow/plugins/concurrency_check_operator.py
new file mode 100644
index 00000000..50ced2e6
--- /dev/null
+++ b/shipyard_airflow/plugins/concurrency_check_operator.py
@@ -0,0 +1,135 @@
+# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.plugins_manager import AirflowPlugin
+from airflow.hooks.postgres_hook import PostgresHook
+from airflow.exceptions import AirflowException
+
+# constants related to the dag_run table.
+DAG_RUN_SELECT_RUNNING_SQL = ("select dag_id, execution_date "
+                              "from dag_run "
+                              "where state='running'")
+
+# connection name for airflow's own sql db
+AIRFLOW_DB = 'airflows_own_db'
+
+# each set in this list of sets indicates DAGs that shouldn't execute together
+CONFLICTING_DAG_SETS = [set(['deploy_site', 'update_site', 'redeploy_server'])]
+
+
+def find_conflicting_dag_set(dag_name, conflicting_dag_sets=None):
+    """
+    Using dag_name, finds all other dag names that are in any set of
+    conflicting dags from input conflicting_dag_sets
+    """
+    if conflicting_dag_sets is None:
+        conflicting_dag_sets = CONFLICTING_DAG_SETS
+    full_set = set()
+    for single_set in conflicting_dag_sets:
+        if dag_name in single_set:
+            full_set = full_set | single_set
+    full_set.discard(dag_name)
+    logging.info('Potential conflicts: %s', ', '.join(full_set))
+    return full_set
+
+
+class ConcurrencyCheckOperator(BaseOperator):
+    """
+    Provides a way to indicate which DAGs should not be executing
+    simultaneously.
+    """
+
+    @apply_defaults
+    def __init__(self, conflicting_dag_set=None, *args, **kwargs):
+        super(ConcurrencyCheckOperator, self).__init__(*args, **kwargs)
+        if conflicting_dag_set is not None:
+            self.conflicting_dag_set = conflicting_dag_set
+        else:
+            self.check_dag_id = self.dag.dag_id
+            logging.debug('dag_id is %s', self.check_dag_id)
+            if '.' in self.dag.dag_id:
+                self.check_dag_id = self.dag.dag_id.split('.', 1)[0]
+                logging.debug('dag_id modified to %s', self.check_dag_id)
+
+            logging.info('from dag %s, assuming %s for concurrency check',
+                         self.dag.dag_id, self.check_dag_id)
+            self.conflicting_dag_set = find_conflicting_dag_set(
+                self.check_dag_id)
+
+    def execute(self, context):
+        """
+        Run the check to see if this DAG has any concurrency issues with
+        other DAGs. Stop the workflow if there are.
+        """
+        logging.info('Checking for running of dags: %s',
+                     ', '.join(self.conflicting_dag_set))
+
+        conflicting_dag = self.find_conflicting_dag(self.check_dag_id)
+        if conflicting_dag is None:
+            logging.info('No conflicts found. Continuing Execution')
+        else:
+            self.abort_conflict(
+                dag_name=self.check_dag_id, conflict=conflicting_dag)
+
+    def get_executing_dags(self):
+        """
+        Encapsulation of getting database records of running dags.
+        Returns a list of records of dag_id and execution_date
+        """
+        logging.info('Executing: %s', DAG_RUN_SELECT_RUNNING_SQL)
+        airflow_pg_hook = PostgresHook(postgres_conn_id=AIRFLOW_DB)
+        return airflow_pg_hook.get_records(DAG_RUN_SELECT_RUNNING_SQL)
+
+    def find_conflicting_dag(self, dag_id_to_check):
+        """
+        Checks for a DAG that is conflicting and returns the first one
+        found. Also returns the dag_id_to_check as conflicting if more
+        than one instance of it is running.
+        """
+        self_dag_count = 0
+        for dag_id, execution_date in self.get_executing_dags():
+            logging.info('Checking %s @ %s vs. current %s', dag_id,
+                         execution_date, dag_id_to_check)
+            if dag_id == dag_id_to_check:
+                self_dag_count += 1
+                logging.info(
+                    "Found an instance of the dag_id being checked. Tally: %s",
+                    self_dag_count)
+            if dag_id in self.conflicting_dag_set:
+                logging.info("Conflict found: %s @ %s", dag_id, execution_date)
+                return dag_id
+        if self_dag_count > 1:
+            return dag_id_to_check
+        return None
+
+    def abort_conflict(self, dag_name, conflict):
+        """
+        Log and raise an exception that there is a conflicting workflow.
+        """
+        conflict_string = '{} conflicts with running {}. Aborting run'.format(
+            dag_name, conflict)
+        logging.warning(conflict_string)
+        raise AirflowException(conflict_string)
+
+
+class ConcurrencyCheckPlugin(AirflowPlugin):
+    """
+    Register this plugin for this operator.
+    """
+    name = 'concurrency_check_operator_plugin'
+    operators = [ConcurrencyCheckOperator]
diff --git a/test-requirements.txt b/test-requirements.txt
index a1748afc..57cb50c7 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1,6 +1,9 @@
 # Testing
 pytest==3.2.1
 mock==2.0.0
+testfixtures==5.1.1
+apache-airflow[crypto,celery,postgres,hive,hdfs,jdbc]==1.8.1
+psycopg2==2.7.3
 
 # Linting
 flake8==3.3.0
diff --git a/tests/unit/dags/__init__.py b/tests/unit/dags/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/dags/test_failure_handlers.py b/tests/unit/dags/test_failure_handlers.py
new file mode 100644
index 00000000..951ee298
--- /dev/null
+++ b/tests/unit/dags/test_failure_handlers.py
@@ -0,0 +1,27 @@
+# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from testfixtures import LogCapture
+from types import SimpleNamespace
+from shipyard_airflow.dags import failure_handlers
+
+CONTEXT = {'task_instance': SimpleNamespace(task_id='cheese')}
+
+
+def test_step_failure_handler():
+    """
+    Ensure that the failure handler is logging as intended.
+    """
+    with LogCapture() as log_capturer:
+        failure_handlers.step_failure_handler(CONTEXT)
+        log_capturer.check(('root', 'INFO', 'cheese step failed'))
diff --git a/tests/unit/plugins/__init__.py b/tests/unit/plugins/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/plugins/test_concurrency_check_operator.py b/tests/unit/plugins/test_concurrency_check_operator.py
new file mode 100644
index 00000000..5b4fb87c
--- /dev/null
+++ b/tests/unit/plugins/test_concurrency_check_operator.py
@@ -0,0 +1,132 @@
+# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from shipyard_airflow.plugins import concurrency_check_operator as operator
+from shipyard_airflow.plugins.concurrency_check_operator import (
+    ConcurrencyCheckOperator
+)
+from airflow.exceptions import AirflowException
+
+COLOR_SETS = [
+    set(['blue', 'green']),
+    set(['blue', 'purple']),
+    set(['red', 'purple']),
+    set(['red', 'orange']),
+    set(['yellow', 'green']),
+    set(['yellow', 'orange']),
+]
+
+CONFLICT_SET = set(['cow', 'monkey', 'chicken'])
+
+
+def test_find_conflicting_dag_set():
+    """
+    Ensure that the right values are determined by find_conflicting_dag_set
+    """
+    # Should not be found in the default set - no conflicts
+    assert operator.DAG_RUN_SELECT_RUNNING_SQL is not None
+    assert not operator.find_conflicting_dag_set("this_is_completely_cheese")
+
+    # Check for contents vs the COLOR_SETS
+    not_in_green_response_set = set(['purple', 'red', 'orange'])
+    response_set = operator.find_conflicting_dag_set(
+        dag_name='green', conflicting_dag_sets=COLOR_SETS)
+    assert 'blue' in response_set
+    assert 'yellow' in response_set
+    assert not_in_green_response_set.isdisjoint(response_set)
+    assert len(response_set) == 2
+
+
+def get_executing_dags_stub_running_twice():
+    return [
+        ('buffalo', 'now'),
+        ('buffalo', 'earlier'),
+        ('squirrel', 'ages ago'),
+    ]
+
+
+def get_executing_dags_stub():
+    return [
+        ('buffalo', 'now'),
+        ('chicken', 'earlier'),
+        ('monkey', 'ages ago'),
+    ]
+
+
+def get_executing_dags_stub_no_conflicts():
+    return [
+        ('buffalo', 'now'),
+        ('hedgehog', 'earlier'),
+        ('panda', 'ages ago'),
+    ]
+
+
+def test_find_conflicting_dag():
+    """
+    Ensure that find_conflicting_dag:
+    1) responds with a found conflict
+    2) responds None if there is no conflict
+    3) responds with the dag_id_to_check being searched for if it is running
+       more than once.
+    """
+    cco = ConcurrencyCheckOperator(
+        conflicting_dag_set=CONFLICT_SET,
+        task_id='bogus')
+
+    # no conflicts
+    cco.get_executing_dags = get_executing_dags_stub_no_conflicts
+    assert cco.find_conflicting_dag('buffalo') is None
+
+    # self is running twice
+    cco.get_executing_dags = get_executing_dags_stub_running_twice
+    assert cco.find_conflicting_dag('buffalo') != 'squirrel'
+    assert cco.find_conflicting_dag('buffalo') == 'buffalo'
+
+    # a conflict from the list
+    cco.get_executing_dags = get_executing_dags_stub
+    assert cco.find_conflicting_dag('buffalo') != 'monkey'
+    assert cco.find_conflicting_dag('buffalo') == 'chicken'
+
+def test_execute_exception():
+    """
+    Run the whole execute function for testing
+    """
+    cco = ConcurrencyCheckOperator(
+        conflicting_dag_set=CONFLICT_SET,
+        task_id='bogus')
+    # a dag_id of cow should conflict with the running chicken dag.
+    cco.check_dag_id = 'cow'
+    cco.get_executing_dags = get_executing_dags_stub
+    try:
+        cco.execute(None)
+        pytest.fail('AirflowException should have been raised')
+    except AirflowException as airflow_exception:
+        assert 'Aborting run' in airflow_exception.args[0]
+
+def test_execute_success():
+    """
+    Run the whole execute function for testing - successfully!
+    """
+    cco = ConcurrencyCheckOperator(
+        conflicting_dag_set=set(['car', 'truck']),
+        task_id='bogus')
+
+    # dag_id of airplane should have no conflicts
+    cco.check_dag_id = 'airplane'
+    cco.get_executing_dags = get_executing_dags_stub
+    try:
+        cco.execute(None)
+        assert True
+    except AirflowException:
+        pytest.fail('AirflowException should not have been raised')
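
Reviewer note: the snippet below is a minimal, hypothetical sketch of how
the new operator is expected to be wired into a top-level DAG, assuming
Airflow 1.8.1 with this plugin installed (plugin operators become importable
from airflow.operators). It is not part of the patch itself; the dag id
'example_site_action' is invented for illustration, and by default only
deploy_site, update_site, and redeploy_server appear in CONFLICTING_DAG_SETS,
so any other dag id yields an empty conflict set.

    from datetime import timedelta

    import airflow
    from airflow import DAG
    from airflow.operators import ConcurrencyCheckOperator

    default_args = {
        'owner': 'airflow',
        'start_date': airflow.utils.dates.days_ago(1),
        'retries': 0,
        'retry_delay': timedelta(minutes=1),
    }

    # 'example_site_action' is a made-up dag id for this sketch.
    dag = DAG('example_site_action', default_args=default_args,
              schedule_interval=None)

    # With no explicit conflicting_dag_set, __init__ derives check_dag_id
    # from dag.dag_id (trimming any '.subdag' suffix) and builds the
    # conflict set via find_conflicting_dag_set(). At run time, execute()
    # queries the dag_run table through the 'airflows_own_db' connection
    # and raises AirflowException if a conflicting dag, or a second
    # instance of this dag, has a row in the 'running' state.
    concurrency_check = ConcurrencyCheckOperator(
        task_id='dag_concurrency_check', dag=dag)

One caveat: when conflicting_dag_set is passed explicitly, __init__ does not
set check_dag_id, so callers (as the unit tests above do) must assign it
before execute() is called.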