diff --git a/src/bin/shipyard_airflow/shipyard_airflow/dags/dag_names.py b/src/bin/shipyard_airflow/shipyard_airflow/dags/dag_names.py index ac8ca05f..0c37ce16 100644 --- a/src/bin/shipyard_airflow/shipyard_airflow/dags/dag_names.py +++ b/src/bin/shipyard_airflow/shipyard_airflow/dags/dag_names.py @@ -14,6 +14,7 @@ # Subdags ALL_PREFLIGHT_CHECKS_DAG_NAME = 'preflight' +UCP_PREFLIGHT_NAME = 'ucp_preflight_check' ARMADA_BUILD_DAG_NAME = 'armada_build' DESTROY_SERVER_DAG_NAME = 'destroy_server' DRYDOCK_BUILD_DAG_NAME = 'drydock_build' diff --git a/src/bin/shipyard_airflow/shipyard_airflow/plugins/concurrency_check_operator.py b/src/bin/shipyard_airflow/shipyard_airflow/plugins/concurrency_check_operator.py index c99abced..a4b5dce9 100644 --- a/src/bin/shipyard_airflow/shipyard_airflow/plugins/concurrency_check_operator.py +++ b/src/bin/shipyard_airflow/shipyard_airflow/plugins/concurrency_check_operator.py @@ -20,6 +20,8 @@ from airflow.hooks.postgres_hook import PostgresHook from airflow.exceptions import AirflowException # constants related to the dag_run table. +from shipyard_airflow.plugins.xcom_pusher import XcomPusher + DAG_RUN_SELECT_RUNNING_SQL = ("select dag_id, execution_date " "from dag_run " "where state='running'") @@ -58,12 +60,15 @@ class ConcurrencyCheckOperator(BaseOperator): def __init__(self, conflicting_dag_set=None, *args, **kwargs): super(ConcurrencyCheckOperator, self).__init__(*args, **kwargs) self.conflicting_dag_set = conflicting_dag_set + self.xcom_push = None def execute(self, context): """ Run the check to see if this DAG has an concurrency issues with other DAGs. Stop the workflow if there is. """ + self.xcom_push = XcomPusher(context['task_instance']) + self._xcom_push_status(False) if self.conflicting_dag_set is None: self.check_dag_id = self.dag.dag_id logging.debug('dag_id is %s', self.check_dag_id) @@ -85,6 +90,7 @@ class ConcurrencyCheckOperator(BaseOperator): else: self.abort_conflict( dag_name=self.check_dag_id, conflict=conflicting_dag) + self._xcom_push_status(True) def get_executing_dags(self): """ @@ -127,6 +133,14 @@ class ConcurrencyCheckOperator(BaseOperator): logging.error(conflict_string) raise AirflowException(conflict_string) + def _xcom_push_status(self, status): + """ + Push the status of the concurrency check + :param status: bool of whether or not this task is successful + :return: + """ + self.xcom_push.xcom_push(key="concurrency_check_success", value=status) + class ConcurrencyCheckPlugin(AirflowPlugin): """ diff --git a/src/bin/shipyard_airflow/shipyard_airflow/plugins/xcom_puller.py b/src/bin/shipyard_airflow/shipyard_airflow/plugins/xcom_puller.py index 8fedb2c1..9f85bf84 100644 --- a/src/bin/shipyard_airflow/shipyard_airflow/plugins/xcom_puller.py +++ b/src/bin/shipyard_airflow/shipyard_airflow/plugins/xcom_puller.py @@ -13,6 +13,8 @@ # limitations under the License. import logging +from shipyard_airflow.dags import dag_names + LOG = logging.getLogger(__name__) @@ -66,7 +68,7 @@ class XcomPuller(object): def get_deployment_configuration(self): """Retrieve the deployment configuration dictionary""" - source_task = 'deployment_configuration' + source_task = dag_names.DEPLOYMENT_CONFIGURATION source_dag = None key = None return self._get_xcom(source_task=source_task, @@ -80,7 +82,7 @@ class XcomPuller(object): that contains information about the workflow such as action_id, name and other related parameters """ - source_task = 'action_xcom' + source_task = dag_names.ACTION_XCOM source_dag = None key = 'action' return self._get_xcom(source_task=source_task, @@ -89,7 +91,7 @@ class XcomPuller(object): def get_action_type(self): """Retrieve the action type""" - source_task = 'action_xcom' + source_task = dag_names.ACTION_XCOM source_dag = None key = 'action_type' return self._get_xcom(source_task=source_task, @@ -98,9 +100,18 @@ class XcomPuller(object): def get_check_drydock_continue_on_fail(self): """Check if 'drydock_continue_on_fail' key exists""" - source_task = 'ucp_preflight_check' + source_task = dag_names.UCP_PREFLIGHT_NAME source_dag = None key = 'drydock_continue_on_fail' return self._get_xcom(source_task=source_task, dag_id=source_dag, key=key) + + def get_concurrency_status(self): + """Retrieve the success status of concurrency_check""" + source_task = dag_names.CONCURRENCY_CHECK + source_dag = None + key = 'concurrency_check_success' + return self._get_xcom(source_task=source_task, + dag_id=source_dag, + key=key) diff --git a/src/bin/shipyard_airflow/tests/unit/plugins/test_concurrency_check_operator.py b/src/bin/shipyard_airflow/tests/unit/plugins/test_concurrency_check_operator.py index bf6050ef..c5bfc2eb 100644 --- a/src/bin/shipyard_airflow/tests/unit/plugins/test_concurrency_check_operator.py +++ b/src/bin/shipyard_airflow/tests/unit/plugins/test_concurrency_check_operator.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import pytest from shipyard_airflow.plugins import concurrency_check_operator as operator from shipyard_airflow.plugins.concurrency_check_operator import ( @@ -99,7 +101,8 @@ def test_find_conflicting_dag(): assert cco.find_conflicting_dag('buffalo') == 'chicken' -def test_execute_exception(): +@mock.patch('shipyard_airflow.plugins.concurrency_check_operator.XcomPusher') +def test_execute_exception(xcom_pusher): """ Run the whole execute function for testing """ @@ -110,13 +113,15 @@ def test_execute_exception(): cco.check_dag_id = 'cow' cco.get_executing_dags = get_executing_dags_stub try: - cco.execute(None) + context = {'task_instance': None} + cco.execute(context) pytest.fail('AirflowException should have been raised') except AirflowException as airflow_exception: assert 'Aborting run' in airflow_exception.args[0] -def test_execute_success(): +@mock.patch('shipyard_airflow.plugins.concurrency_check_operator.XcomPusher') +def test_execute_success(xcom_pusher): """ Run the whole execute function for testing - successfully! """ @@ -128,7 +133,8 @@ def test_execute_success(): cco.check_dag_id = 'airplane' cco.get_executing_dags = get_executing_dags_stub try: - cco.execute(None) + context = {'task_instance': None} + cco.execute(context) assert True except AirflowException: pytest.fail('AirflowException should not have been raised')