shipyard/shipyard_airflow/db/airflow_db.py

288 lines
7.9 KiB
Python

# Copyright 2017 AT&T Intellectual Property. All other rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Airflow database access - see db.py for instances to use
"""
import sqlalchemy
from oslo_config import cfg
from shipyard_airflow.db.common_db import DbAccess
from shipyard_airflow.db.errors import AirflowStateError
# Global oslo.config configuration object; AirflowDbAccess reads its
# database connection string from CONF.base.postgresql_airflow_db.
CONF = cfg.CONF
class AirflowDbAccess(DbAccess):
    """
    Airflow database access
    WARNING: This is a large set of assumptions based on the way airflow
    arranges its database and are subject to change with airflow future
    releases - i.e. we're leveraging undocumented/non-exposed interfaces
    for airflow to work around lack of API and feature functionality.
    """

    SELECT_ALL_DAG_RUNS = sqlalchemy.sql.text('''
        SELECT
            "id",
            "dag_id",
            "execution_date",
            "state",
            "run_id",
            "external_trigger",
            "conf",
            "end_date",
            "start_date"
        FROM
            dag_run
        ''')

    SELECT_DAG_RUNS_BY_ID = sqlalchemy.sql.text('''
        SELECT
            "id",
            "dag_id",
            "execution_date",
            "state",
            "run_id",
            "external_trigger",
            "conf",
            "end_date",
            "start_date"
        FROM
            dag_run
        WHERE
            dag_id = :dag_id
        AND
            execution_date = :execution_date
        ''')

    # The like parameter must have '%' appropriately applied to the args
    # used to merge into this query.
    # NOTE(review): in SQL LIKE, '_' is also a single-character wildcard, and
    # a prefix match also hits any dag whose id merely starts with dag_id
    # (not only child dags) - confirm this is acceptable for callers.
    SELECT_DAG_RUNS_LIKE_ID = sqlalchemy.sql.text('''
        SELECT
            "id",
            "dag_id",
            "execution_date",
            "state",
            "run_id",
            "external_trigger",
            "conf",
            "end_date",
            "start_date"
        FROM
            dag_run
        WHERE
            dag_id LIKE :dag_id
        AND
            execution_date = :execution_date
        ''')

    SELECT_ALL_TASKS = sqlalchemy.sql.text('''
        SELECT
            "task_id",
            "dag_id",
            "execution_date",
            "start_date",
            "end_date",
            "duration",
            "state",
            "try_number",
            "hostname",
            "unixname",
            "job_id",
            "pool",
            "queue",
            "priority_weight",
            "operator",
            "queued_dttm",
            "pid",
            "max_tries"
        FROM
            task_instance
        ORDER BY
            priority_weight desc,
            start_date
        ''')

    # The like parameter must have '%' appropriately applied to the args
    # used to merge into this query.
    SELECT_TASKS_BY_ID = sqlalchemy.sql.text('''
        SELECT
            "task_id",
            "dag_id",
            "execution_date",
            "start_date",
            "end_date",
            "duration",
            "state",
            "try_number",
            "hostname",
            "unixname",
            "job_id",
            "pool",
            "queue",
            "priority_weight",
            "operator",
            "queued_dttm",
            "pid",
            "max_tries"
        FROM
            task_instance
        WHERE
            dag_id LIKE :dag_id
        AND
            execution_date = :execution_date
        ORDER BY
            priority_weight desc,
            start_date
        ''')

    UPDATE_DAG_RUN_STATUS = sqlalchemy.sql.text('''
        UPDATE
            dag_run
        SET
            state = :state
        WHERE
            dag_id = :dag_id
        AND
            execution_date = :execution_date
        ''')

    def __init__(self):
        DbAccess.__init__(self)

    def get_connection_string(self):
        """
        Returns the connection string for this db connection, taken from
        CONF.base.postgresql_airflow_db.
        """
        return CONF.base.postgresql_airflow_db

    def get_all_dag_runs(self):
        """
        Retrieves all dag runs.

        :returns: the rows of the dag_run table, as produced by
            get_as_dict_array (presumably one dict per row keyed by column
            name - confirm against DbAccess)
        """
        return self.get_as_dict_array(AirflowDbAccess.SELECT_ALL_DAG_RUNS)

    def get_dag_runs_by_id(self, dag_id, execution_date):
        """
        Retrieves dag runs by exact dag id and execution date.

        :param dag_id: the exact id of the dag
        :param execution_date: the execution date of the dag run
        :returns: the matching dag_run rows from get_as_dict_array
        """
        return self.get_as_dict_array(
            AirflowDbAccess.SELECT_DAG_RUNS_BY_ID,
            dag_id=dag_id,
            execution_date=execution_date)

    def get_dag_runs_like_id(self, dag_id, execution_date):
        """
        Retrieves dag runs, for parent and child dags by the parent
        dag id and execution date.

        :param dag_id: the parent dag id; '%' is appended so that any dag
            id starting with this value matches
        :param execution_date: the execution date of the dag runs
        :returns: the matching dag_run rows from get_as_dict_array
        """
        return self.get_as_dict_array(
            AirflowDbAccess.SELECT_DAG_RUNS_LIKE_ID,
            dag_id=dag_id + '%',
            execution_date=execution_date)

    def get_all_tasks(self):
        """
        Retrieves all tasks, ordered by descending priority_weight and then
        by start_date.
        """
        return self.get_as_dict_array(AirflowDbAccess.SELECT_ALL_TASKS)

    def get_tasks_by_id(self, dag_id, execution_date):
        """
        Retrieves tasks by dag id and execution date.

        :param dag_id: dag id prefix; '%' is appended so task instances of
            any dag whose id starts with this value are included
        :param execution_date: the execution date of the task instances
        :returns: the matching task_instance rows from get_as_dict_array
        """
        return self.get_as_dict_array(
            AirflowDbAccess.SELECT_TASKS_BY_ID,
            dag_id=dag_id + '%',
            execution_date=execution_date)

    def stop_dag_run(self, dag_id, execution_date):
        """
        Triggers an update to set a dag_run to failed state
        causing dag_run to be stopped
        running -> failed

        :raises AirflowStateError: if the dag_run does not exist or is not
            running
        """
        self._control_dag_run(
            dag_id=dag_id,
            execution_date=execution_date,
            expected_state='running',
            desired_state='failed')

    def pause_dag_run(self, dag_id, execution_date):
        """
        Triggers an update to set a dag_run to paused state
        causing dag_run to be paused
        running -> paused

        :raises AirflowStateError: if the dag_run does not exist or is not
            running
        """
        self._control_dag_run(
            dag_id=dag_id,
            execution_date=execution_date,
            expected_state='running',
            desired_state='paused')

    def unpause_dag_run(self, dag_id, execution_date):
        """
        Triggers an update to set a dag_run to running state
        causing dag_run to be unpaused
        paused -> running

        :raises AirflowStateError: if the dag_run does not exist or is not
            paused
        """
        self._control_dag_run(
            dag_id=dag_id,
            execution_date=execution_date,
            expected_state='paused',
            desired_state='running')

    def check_dag_run_state(self, dag_id, execution_date, expected_state):
        """
        Examines a dag_run for state. Throws exception if it's not right.

        :param dag_id: the exact id of the dag
        :param execution_date: the execution date of the dag run
        :param expected_state: the state the dag_run must currently be in
        :raises AirflowStateError: if no matching dag_run exists, or if the
            first matching dag_run's state differs from expected_state
        """
        dag_run_list = self.get_dag_runs_by_id(
            dag_id=dag_id, execution_date=execution_date)
        if not dag_run_list:
            # not found
            raise AirflowStateError(message='dag_run does not exist')
        dag_run = dag_run_list[0]
        if dag_run['state'] != expected_state:
            # Report the state that was actually required; this used to say
            # "running" unconditionally, which was wrong for unpause.
            raise AirflowStateError(
                message='dag_run state must be {}, but is {}'.format(
                    expected_state, dag_run['state']))

    def _control_dag_run(self, dag_id, execution_date, expected_state,
                         desired_state):
        """
        checks a dag_run's state for expected state, and sets it to the
        desired state

        :raises AirflowStateError: propagated from check_dag_run_state
        """
        # NOTE(review): the check and the update are separate statements, so
        # the dag_run's state could change between them.
        self.check_dag_run_state(
            dag_id=dag_id,
            execution_date=execution_date,
            expected_state=expected_state)
        self._set_dag_run_state(
            state=desired_state, dag_id=dag_id, execution_date=execution_date)

    def _set_dag_run_state(self, state, dag_id, execution_date):
        """
        Sets a dag run to the specified state.
        WARNING: this assumes that airflow works by reading state from the
        dag_run table dynamically, is not caching results, and doesn't
        start to use the states we're using in a new way.
        """
        # NOTE(review): perform_insert is used to execute an UPDATE here;
        # presumably it runs any non-returning statement - confirm in DbAccess.
        self.perform_insert(
            AirflowDbAccess.UPDATE_DAG_RUN_STATUS,
            state=state,
            dag_id=dag_id,
            execution_date=execution_date)