Create Xcom Pusher/Puller for concurrency_check
This change adds support in the concurrency_check plugin to push its status to xcom so other components can easily check if it has passed. This is especially useful for components that need to always run, but still want to respect the concurrency check. This is also a slight refactor of XcomPuller to make it import its dag names from dag_names.py Change-Id: I9ca6b43d7789d9499121384d4427835e296c44b8
This commit is contained in:
parent
c19580378b
commit
b5469c39ec
@ -14,6 +14,7 @@
|
||||
|
||||
# Subdags
|
||||
ALL_PREFLIGHT_CHECKS_DAG_NAME = 'preflight'
|
||||
UCP_PREFLIGHT_NAME = 'ucp_preflight_check'
|
||||
ARMADA_BUILD_DAG_NAME = 'armada_build'
|
||||
DESTROY_SERVER_DAG_NAME = 'destroy_server'
|
||||
DRYDOCK_BUILD_DAG_NAME = 'drydock_build'
|
||||
|
@ -20,6 +20,8 @@ from airflow.hooks.postgres_hook import PostgresHook
|
||||
from airflow.exceptions import AirflowException
|
||||
|
||||
# constants related to the dag_run table.
|
||||
from shipyard_airflow.plugins.xcom_pusher import XcomPusher
|
||||
|
||||
DAG_RUN_SELECT_RUNNING_SQL = ("select dag_id, execution_date "
|
||||
"from dag_run "
|
||||
"where state='running'")
|
||||
@ -58,12 +60,15 @@ class ConcurrencyCheckOperator(BaseOperator):
|
||||
def __init__(self, conflicting_dag_set=None, *args, **kwargs):
|
||||
super(ConcurrencyCheckOperator, self).__init__(*args, **kwargs)
|
||||
self.conflicting_dag_set = conflicting_dag_set
|
||||
self.xcom_push = None
|
||||
|
||||
def execute(self, context):
|
||||
"""
|
||||
Run the check to see if this DAG has an concurrency issues with other
|
||||
DAGs. Stop the workflow if there is.
|
||||
"""
|
||||
self.xcom_push = XcomPusher(context['task_instance'])
|
||||
self._xcom_push_status(False)
|
||||
if self.conflicting_dag_set is None:
|
||||
self.check_dag_id = self.dag.dag_id
|
||||
logging.debug('dag_id is %s', self.check_dag_id)
|
||||
@ -85,6 +90,7 @@ class ConcurrencyCheckOperator(BaseOperator):
|
||||
else:
|
||||
self.abort_conflict(
|
||||
dag_name=self.check_dag_id, conflict=conflicting_dag)
|
||||
self._xcom_push_status(True)
|
||||
|
||||
def get_executing_dags(self):
|
||||
"""
|
||||
@ -127,6 +133,14 @@ class ConcurrencyCheckOperator(BaseOperator):
|
||||
logging.error(conflict_string)
|
||||
raise AirflowException(conflict_string)
|
||||
|
||||
def _xcom_push_status(self, status):
|
||||
"""
|
||||
Push the status of the concurrency check
|
||||
:param status: bool of whether or not this task is successful
|
||||
:return:
|
||||
"""
|
||||
self.xcom_push.xcom_push(key="concurrency_check_success", value=status)
|
||||
|
||||
|
||||
class ConcurrencyCheckPlugin(AirflowPlugin):
|
||||
"""
|
||||
|
@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from shipyard_airflow.dags import dag_names
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -66,7 +68,7 @@ class XcomPuller(object):
|
||||
|
||||
def get_deployment_configuration(self):
|
||||
"""Retrieve the deployment configuration dictionary"""
|
||||
source_task = 'deployment_configuration'
|
||||
source_task = dag_names.DEPLOYMENT_CONFIGURATION
|
||||
source_dag = None
|
||||
key = None
|
||||
return self._get_xcom(source_task=source_task,
|
||||
@ -80,7 +82,7 @@ class XcomPuller(object):
|
||||
that contains information about the workflow such as action_id, name
|
||||
and other related parameters
|
||||
"""
|
||||
source_task = 'action_xcom'
|
||||
source_task = dag_names.ACTION_XCOM
|
||||
source_dag = None
|
||||
key = 'action'
|
||||
return self._get_xcom(source_task=source_task,
|
||||
@ -89,7 +91,7 @@ class XcomPuller(object):
|
||||
|
||||
def get_action_type(self):
|
||||
"""Retrieve the action type"""
|
||||
source_task = 'action_xcom'
|
||||
source_task = dag_names.ACTION_XCOM
|
||||
source_dag = None
|
||||
key = 'action_type'
|
||||
return self._get_xcom(source_task=source_task,
|
||||
@ -98,9 +100,18 @@ class XcomPuller(object):
|
||||
|
||||
def get_check_drydock_continue_on_fail(self):
|
||||
"""Check if 'drydock_continue_on_fail' key exists"""
|
||||
source_task = 'ucp_preflight_check'
|
||||
source_task = dag_names.UCP_PREFLIGHT_NAME
|
||||
source_dag = None
|
||||
key = 'drydock_continue_on_fail'
|
||||
return self._get_xcom(source_task=source_task,
|
||||
dag_id=source_dag,
|
||||
key=key)
|
||||
|
||||
def get_concurrency_status(self):
|
||||
"""Retrieve the success status of concurrency_check"""
|
||||
source_task = dag_names.CONCURRENCY_CHECK
|
||||
source_dag = None
|
||||
key = 'concurrency_check_success'
|
||||
return self._get_xcom(source_task=source_task,
|
||||
dag_id=source_dag,
|
||||
key=key)
|
||||
|
@ -11,6 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from shipyard_airflow.plugins import concurrency_check_operator as operator
|
||||
from shipyard_airflow.plugins.concurrency_check_operator import (
|
||||
@ -99,7 +101,8 @@ def test_find_conflicting_dag():
|
||||
assert cco.find_conflicting_dag('buffalo') == 'chicken'
|
||||
|
||||
|
||||
def test_execute_exception():
|
||||
@mock.patch('shipyard_airflow.plugins.concurrency_check_operator.XcomPusher')
|
||||
def test_execute_exception(xcom_pusher):
|
||||
"""
|
||||
Run the whole execute function for testing
|
||||
"""
|
||||
@ -110,13 +113,15 @@ def test_execute_exception():
|
||||
cco.check_dag_id = 'cow'
|
||||
cco.get_executing_dags = get_executing_dags_stub
|
||||
try:
|
||||
cco.execute(None)
|
||||
context = {'task_instance': None}
|
||||
cco.execute(context)
|
||||
pytest.fail('AirflowException should have been raised')
|
||||
except AirflowException as airflow_exception:
|
||||
assert 'Aborting run' in airflow_exception.args[0]
|
||||
|
||||
|
||||
def test_execute_success():
|
||||
@mock.patch('shipyard_airflow.plugins.concurrency_check_operator.XcomPusher')
|
||||
def test_execute_success(xcom_pusher):
|
||||
"""
|
||||
Run the whole execute function for testing - successfully!
|
||||
"""
|
||||
@ -128,7 +133,8 @@ def test_execute_success():
|
||||
cco.check_dag_id = 'airplane'
|
||||
cco.get_executing_dags = get_executing_dags_stub
|
||||
try:
|
||||
cco.execute(None)
|
||||
context = {'task_instance': None}
|
||||
cco.execute(context)
|
||||
assert True
|
||||
except AirflowException:
|
||||
pytest.fail('AirflowException should not have been raised')
|
||||
|
Loading…
Reference in New Issue
Block a user