Create Xcom Pusher/Puller for concurrency_check

This change adds support in the concurrency_check plugin to push its
status to xcom so other components can easily check if it has passed.
This is especially useful for components that need to always run, but
still want to respect the concurrency check.

This is also a slight refactor of XcomPuller to make it import its dag
names from dag_names.py

Change-Id: I9ca6b43d7789d9499121384d4427835e296c44b8
This commit is contained in:
Michael Beaver 2019-05-23 17:42:33 -05:00
parent c19580378b
commit b5469c39ec
4 changed files with 40 additions and 8 deletions
src/bin/shipyard_airflow

View File

@ -14,6 +14,7 @@
# Subdags
ALL_PREFLIGHT_CHECKS_DAG_NAME = 'preflight'
UCP_PREFLIGHT_NAME = 'ucp_preflight_check'
ARMADA_BUILD_DAG_NAME = 'armada_build'
DESTROY_SERVER_DAG_NAME = 'destroy_server'
DRYDOCK_BUILD_DAG_NAME = 'drydock_build'

View File

@ -20,6 +20,8 @@ from airflow.hooks.postgres_hook import PostgresHook
from airflow.exceptions import AirflowException
# constants related to the dag_run table.
from shipyard_airflow.plugins.xcom_pusher import XcomPusher
DAG_RUN_SELECT_RUNNING_SQL = ("select dag_id, execution_date "
"from dag_run "
"where state='running'")
@ -58,12 +60,15 @@ class ConcurrencyCheckOperator(BaseOperator):
def __init__(self, conflicting_dag_set=None, *args, **kwargs):
super(ConcurrencyCheckOperator, self).__init__(*args, **kwargs)
self.conflicting_dag_set = conflicting_dag_set
self.xcom_push = None
def execute(self, context):
"""
Run the check to see if this DAG has an concurrency issues with other
DAGs. Stop the workflow if there is.
"""
self.xcom_push = XcomPusher(context['task_instance'])
self._xcom_push_status(False)
if self.conflicting_dag_set is None:
self.check_dag_id = self.dag.dag_id
logging.debug('dag_id is %s', self.check_dag_id)
@ -85,6 +90,7 @@ class ConcurrencyCheckOperator(BaseOperator):
else:
self.abort_conflict(
dag_name=self.check_dag_id, conflict=conflicting_dag)
self._xcom_push_status(True)
def get_executing_dags(self):
"""
@ -127,6 +133,14 @@ class ConcurrencyCheckOperator(BaseOperator):
logging.error(conflict_string)
raise AirflowException(conflict_string)
def _xcom_push_status(self, status):
"""
Push the status of the concurrency check
:param status: bool of whether or not this task is successful
:return:
"""
self.xcom_push.xcom_push(key="concurrency_check_success", value=status)
class ConcurrencyCheckPlugin(AirflowPlugin):
"""

View File

@ -13,6 +13,8 @@
# limitations under the License.
import logging
from shipyard_airflow.dags import dag_names
LOG = logging.getLogger(__name__)
@ -66,7 +68,7 @@ class XcomPuller(object):
def get_deployment_configuration(self):
"""Retrieve the deployment configuration dictionary"""
source_task = 'deployment_configuration'
source_task = dag_names.DEPLOYMENT_CONFIGURATION
source_dag = None
key = None
return self._get_xcom(source_task=source_task,
@ -80,7 +82,7 @@ class XcomPuller(object):
that contains information about the workflow such as action_id, name
and other related parameters
"""
source_task = 'action_xcom'
source_task = dag_names.ACTION_XCOM
source_dag = None
key = 'action'
return self._get_xcom(source_task=source_task,
@ -89,7 +91,7 @@ class XcomPuller(object):
def get_action_type(self):
"""Retrieve the action type"""
source_task = 'action_xcom'
source_task = dag_names.ACTION_XCOM
source_dag = None
key = 'action_type'
return self._get_xcom(source_task=source_task,
@ -98,9 +100,18 @@ class XcomPuller(object):
def get_check_drydock_continue_on_fail(self):
"""Check if 'drydock_continue_on_fail' key exists"""
source_task = 'ucp_preflight_check'
source_task = dag_names.UCP_PREFLIGHT_NAME
source_dag = None
key = 'drydock_continue_on_fail'
return self._get_xcom(source_task=source_task,
dag_id=source_dag,
key=key)
def get_concurrency_status(self):
"""Retrieve the success status of concurrency_check"""
source_task = dag_names.CONCURRENCY_CHECK
source_dag = None
key = 'concurrency_check_success'
return self._get_xcom(source_task=source_task,
dag_id=source_dag,
key=key)

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
import pytest
from shipyard_airflow.plugins import concurrency_check_operator as operator
from shipyard_airflow.plugins.concurrency_check_operator import (
@ -99,7 +101,8 @@ def test_find_conflicting_dag():
assert cco.find_conflicting_dag('buffalo') == 'chicken'
def test_execute_exception():
@mock.patch('shipyard_airflow.plugins.concurrency_check_operator.XcomPusher')
def test_execute_exception(xcom_pusher):
"""
Run the whole execute function for testing
"""
@ -110,13 +113,15 @@ def test_execute_exception():
cco.check_dag_id = 'cow'
cco.get_executing_dags = get_executing_dags_stub
try:
cco.execute(None)
context = {'task_instance': None}
cco.execute(context)
pytest.fail('AirflowException should have been raised')
except AirflowException as airflow_exception:
assert 'Aborting run' in airflow_exception.args[0]
def test_execute_success():
@mock.patch('shipyard_airflow.plugins.concurrency_check_operator.XcomPusher')
def test_execute_success(xcom_pusher):
"""
Run the whole execute function for testing - successfully!
"""
@ -128,7 +133,8 @@ def test_execute_success():
cco.check_dag_id = 'airplane'
cco.get_executing_dags = get_executing_dags_stub
try:
cco.execute(None)
context = {'task_instance': None}
cco.execute(context)
assert True
except AirflowException:
pytest.fail('AirflowException should not have been raised')