Working on Data Flow (step 3)

* Refactored tests hierarcy (moved get_resource to the base class)
* Fixed logic in abstract engine (transaction scoope was wrong)
* Added stub methods for context manipulations (not everywhere yet)
* Refactored actions
* Created ECHO action to be able to emulate required output
* Added simple data flow test (two dependent tasks)

TODO:
* More tests
* Refactor actions according to the latest discussions

Change-Id: I5ba5e330110889014eeca53501ddae54dc9a1236
This commit is contained in:
Renat Akhmerov 2014-02-24 11:50:06 +07:00
parent 69f6aa6477
commit 31ed10aa18
26 changed files with 511 additions and 201 deletions

View File

@ -22,14 +22,11 @@ from mistral import dsl
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.engine import states from mistral.engine import states
from mistral.engine import workflow from mistral.engine import workflow
from mistral.engine import data_flow
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
# TODO(rakhmerov): Upcoming Data Flow changes:
# 1. Calculate "in_context" for all the tasks submitted for execution.
# 2. Transfer "in_context" along with task data over AMQP.
class AbstractEngine(object): class AbstractEngine(object):
@classmethod @classmethod
@ -41,10 +38,10 @@ class AbstractEngine(object):
def start_workflow_execution(cls, workbook_name, task_name, context): def start_workflow_execution(cls, workbook_name, task_name, context):
db_api.start_tx() db_api.start_tx()
wb_dsl = cls._get_wb_dsl(workbook_name)
# Persist execution and tasks in DB. # Persist execution and tasks in DB.
try: try:
wb_dsl = cls._get_wb_dsl(workbook_name)
execution = cls._create_execution(workbook_name, execution = cls._create_execution(workbook_name,
task_name, task_name,
context) context)
@ -52,9 +49,14 @@ class AbstractEngine(object):
tasks = cls._create_tasks( tasks = cls._create_tasks(
workflow.find_workflow_tasks(wb_dsl, task_name), workflow.find_workflow_tasks(wb_dsl, task_name),
wb_dsl, wb_dsl,
workbook_name, execution['id'] workbook_name,
execution['id']
) )
tasks_to_start = workflow.find_resolved_tasks(tasks)
data_flow.prepare_tasks(tasks_to_start, context)
db_api.commit_tx() db_api.commit_tx()
except Exception as e: except Exception as e:
raise exc.EngineException("Failed to create necessary DB objects:" raise exc.EngineException("Failed to create necessary DB objects:"
@ -62,10 +64,7 @@ class AbstractEngine(object):
finally: finally:
db_api.end_tx() db_api.end_tx()
# TODO(rakhmerov): This doesn't look correct anymore, we shouldn't cls._run_tasks(tasks_to_start)
# start tasks which don't have dependencies but are reachable only
# via direct transitions.
cls._run_tasks(workflow.find_resolved_tasks(tasks))
return execution return execution
@ -74,6 +73,7 @@ class AbstractEngine(object):
task_id, state, result): task_id, state, result):
db_api.start_tx() db_api.start_tx()
try:
wb_dsl = cls._get_wb_dsl(workbook_name) wb_dsl = cls._get_wb_dsl(workbook_name)
#TODO(rakhmerov): validate state transition #TODO(rakhmerov): validate state transition
@ -83,22 +83,30 @@ class AbstractEngine(object):
execution = db_api.execution_get(workbook_name, execution_id) execution = db_api.execution_get(workbook_name, execution_id)
cls._create_next_tasks(task, wb_dsl, workbook_name, execution_id) # Calculate task outbound context.
# TODO(rakhmerov): publish result into context selectively
outbound_context = \
data_flow.merge_into_context(task['in_context'], result)
cls._create_next_tasks(task, wb_dsl)
# Determine what tasks need to be started. # Determine what tasks need to be started.
tasks = db_api.tasks_get(workbook_name, execution_id) tasks = db_api.tasks_get(workbook_name, execution_id)
# TODO(nmakhotkin) merge result into context
try:
new_exec_state = cls._determine_execution_state(execution, tasks) new_exec_state = cls._determine_execution_state(execution, tasks)
if execution['state'] != new_exec_state: if execution['state'] != new_exec_state:
execution = \
db_api.execution_update(workbook_name, execution_id, { db_api.execution_update(workbook_name, execution_id, {
"state": new_exec_state "state": new_exec_state
}) })
LOG.info("Changed execution state: %s" % execution) LOG.info("Changed execution state: %s" % execution)
tasks_to_start = workflow.find_resolved_tasks(tasks)
data_flow.prepare_tasks(tasks_to_start, outbound_context)
db_api.commit_tx() db_api.commit_tx()
except Exception as e: except Exception as e:
raise exc.EngineException("Failed to create necessary DB objects:" raise exc.EngineException("Failed to create necessary DB objects:"
@ -109,8 +117,8 @@ class AbstractEngine(object):
if states.is_stopped_or_finished(execution["state"]): if states.is_stopped_or_finished(execution["state"]):
return task return task
if tasks: if tasks_to_start:
cls._run_tasks(workflow.find_resolved_tasks(tasks)) cls._run_tasks(tasks_to_start)
return task return task
@ -149,11 +157,11 @@ class AbstractEngine(object):
}) })
@classmethod @classmethod
def _create_next_tasks(cls, task, wb_dsl, workbook_name, execution_id): def _create_next_tasks(cls, task, wb_dsl):
dsl_tasks = workflow.find_tasks_after_completion(task, wb_dsl) dsl_tasks = workflow.find_tasks_after_completion(task, wb_dsl)
tasks = cls._create_tasks(dsl_tasks, wb_dsl, workbook_name, tasks = cls._create_tasks(dsl_tasks, wb_dsl, task['workbook_name'],
execution_id) task['execution_id'])
return workflow.find_resolved_tasks(tasks) return workflow.find_resolved_tasks(tasks)

View File

@ -34,6 +34,7 @@ def create_action(task):
def _get_mapping(): def _get_mapping():
return { return {
action_types.ECHO: get_echo_action,
action_types.REST_API: get_rest_action, action_types.REST_API: get_rest_action,
action_types.MISTRAL_REST_API: get_mistral_rest_action, action_types.MISTRAL_REST_API: get_mistral_rest_action,
action_types.OSLO_RPC: get_amqp_action, action_types.OSLO_RPC: get_amqp_action,
@ -48,6 +49,15 @@ def _find_action_result_helper(task, action):
return {} return {}
def get_echo_action(task):
action_type = a_h.get_action_type(task)
action_name = task['task_dsl']['action'].split(':')[1]
output = task['service_dsl']['actions'][action_name].get('output', {})
return actions.EchoAction(action_type, action_name, output=output)
def get_rest_action(task): def get_rest_action(task):
action_type = a_h.get_action_type(task) action_type = a_h.get_action_type(task)
action_name = task['task_dsl']['action'].split(':')[1] action_name = task['task_dsl']['action'].split(':')[1]
@ -63,7 +73,7 @@ def get_rest_action(task):
method = action_dsl['parameters'].get('method', "GET") method = action_dsl['parameters'].get('method', "GET")
# input_yaql = task.get('input') # input_yaql = task.get('input')
# TODO(nmakhotkin) extract input from context within the YAQL expression # TODO(nmakhotkin) extract input from context with the YAQL expression
task_input = {} # expressions.evaluate(input_expr, ctx) task_input = {} # expressions.evaluate(input_expr, ctx)
task_data = {} task_data = {}

View File

@ -15,9 +15,6 @@
# limitations under the License. # limitations under the License.
from mistral.engine.actions import action_types as a_t from mistral.engine.actions import action_types as a_t
from mistral import exceptions as exc
from mistral.engine import states
from mistral.engine import expressions as expr
def get_action_type(task): def get_action_type(task):
@ -26,19 +23,3 @@ def get_action_type(task):
def is_task_synchronous(task): def is_task_synchronous(task):
return get_action_type(task) != a_t.MISTRAL_REST_API return get_action_type(task) != a_t.MISTRAL_REST_API
def extract_state_result(action, action_result):
# All non-Mistral tasks are sync-auto because service doesn't know
# about Mistral and we need to receive the result immediately.
if action.type != a_t.MISTRAL_REST_API:
if action.result_helper.get('select'):
result = expr.evaluate(action.result_helper['select'],
action_result)
else:
result = action_result
# TODO(nmakhotkin) get state for other actions
state = states.get_state_by_http_status_code(action.status)
return state, result
raise exc.InvalidActionException("Error. Wrong type of action to "
"retrieve the result")

View File

@ -17,12 +17,13 @@
"""Valid action types.""" """Valid action types."""
ECHO = 'ECHO'
REST_API = 'REST_API' REST_API = 'REST_API'
OSLO_RPC = 'OSLO_RPC' OSLO_RPC = 'OSLO_RPC'
MISTRAL_REST_API = 'MISTRAL_REST_API' MISTRAL_REST_API = 'MISTRAL_REST_API'
SEND_EMAIL = "SEND_EMAIL" SEND_EMAIL = "SEND_EMAIL"
_ALL = [REST_API, OSLO_RPC, MISTRAL_REST_API, SEND_EMAIL] _ALL = [ECHO, REST_API, OSLO_RPC, MISTRAL_REST_API, SEND_EMAIL]
def is_valid(action_type): def is_valid(action_type):

View File

@ -14,36 +14,61 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#TODO(dzimine):separate actions across different files/modules
import abc
from email.mime.text import MIMEText
import smtplib
from amqplib import client_0_8 as amqp from amqplib import client_0_8 as amqp
import requests import requests
#TODO(dzimine):separate actions across different files/modules
import smtplib
from email.mime.text import MIMEText
from mistral.openstack.common import log as logging from mistral.openstack.common import log as logging
from mistral import exceptions as exc
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class BaseAction(object): class Action(object):
status = None status = None
def __init__(self, action_type, action_name): def __init__(self, action_type, action_name):
self.type = action_type self.type = action_type
self.name = action_name self.name = action_name
# Result_helper is a dict for retrieving result within YAQL expression @abc.abstractmethod
# and it belongs to action (for defining this attribute immediately
# at action creation).
self.result_helper = {}
def run(self): def run(self):
"""Run action logic.
:return: result of the action. Note that for asynchronous actions
it will always be None.
In case if action failed this method must throw a ActionException
to indicate that.
"""
pass pass
class RestAction(BaseAction): class EchoAction(Action):
"""Echo action.
This action just returns a configured value as a result without doing
anything else. The value of such action implementation is that it
can be used in development (for testing), demonstration and designing
of workflows themselves where echo action can play the role of temporary
stub.
"""
def __init__(self, action_type, action_name, output):
super(EchoAction, self).__init__(action_type, action_name)
self.output = output
def run(self):
return self.output
class RestAction(Action):
def __init__(self, action_type, action_name, url, params={}, def __init__(self, action_type, action_name, url, params={},
method="GET", headers={}, data={}): method="GET", headers={}, data={}):
super(RestAction, self).__init__(action_type, action_name) super(RestAction, self).__init__(action_type, action_name)
@ -57,20 +82,32 @@ class RestAction(BaseAction):
LOG.info("Sending action HTTP request " LOG.info("Sending action HTTP request "
"[method=%s, url=%s, params=%s, headers=%s]" % "[method=%s, url=%s, params=%s, headers=%s]" %
(self.method, self.url, self.params, self.headers)) (self.method, self.url, self.params, self.headers))
resp = requests.request(self.method, self.url, params=self.params,
headers=self.headers, data=self.data) try:
resp = requests.request(self.method,
self.url,
params=self.params,
headers=self.headers,
data=self.data)
except Exception as e:
raise exc.ActionException("Failed to send HTTP request: %s" % e)
LOG.info("Received HTTP response:\n%s\n%s" % LOG.info("Received HTTP response:\n%s\n%s" %
(resp.status_code, resp.content)) (resp.status_code, resp.content))
# TODO(rakhmerov):Here we need to apply logic related with
# extracting a result as configured in DSL.
# Return rather json than text, but response can contain text also. # Return rather json than text, but response can contain text also.
self.status = resp.status_code self.status = resp.status_code
try: try:
return resp.json() return resp.json()
except: except:
LOG.debug("HTTP response content is not json") LOG.debug("HTTP response content is not json.")
return resp.content return resp.content
class OsloRPCAction(BaseAction): class OsloRPCAction(Action):
def __init__(self, action_type, action_name, host, userid, password, def __init__(self, action_type, action_name, host, userid, password,
virtual_host, message, routing_key=None, port=5672, virtual_host, message, routing_key=None, port=5672,
exchange=None, queue_name=None): exchange=None, queue_name=None):
@ -116,7 +153,7 @@ class OsloRPCAction(BaseAction):
self.status = None self.status = None
class SendEmailAction(BaseAction): class SendEmailAction(Action):
def __init__(self, action_type, action_name, params, settings): def __init__(self, action_type, action_name, params, settings):
super(SendEmailAction, self).__init__(action_type, action_name) super(SendEmailAction, self).__init__(action_type, action_name)
#TODO(dzimine): validate parameters #TODO(dzimine): validate parameters
@ -143,8 +180,10 @@ class SendEmailAction(BaseAction):
message['Subject'] = self.subject message['Subject'] = self.subject
message['From'] = self.sender message['From'] = self.sender
message['To'] = self.to message['To'] = self.to
try: try:
s = smtplib.SMTP(self.smtp_server) s = smtplib.SMTP(self.smtp_server)
if self.password is not None: if self.password is not None:
# Sequence to request TLS connection and log in (RFC-2487). # Sequence to request TLS connection and log in (RFC-2487).
s.ehlo() s.ehlo()
@ -156,7 +195,5 @@ class SendEmailAction(BaseAction):
to_addrs=self.to, to_addrs=self.to,
msg=message.as_string()) msg=message.as_string())
except (smtplib.SMTPException, IOError) as e: except (smtplib.SMTPException, IOError) as e:
LOG.error("Error sending email message: %s" % e) raise exc.ActionException("Failed to send an email message: %s"
#NOTE(DZ): Raise Misral exception instead re-throwing SMTP? % e)
# For now just logging the error here and re-thorw the original
raise

View File

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.db import api as db_api
from mistral.engine import expressions as expr
from mistral.openstack.common import log as logging
LOG = logging.getLogger(__name__)
def evaluate_task_input(task, context):
res = {}
params = task['task_dsl'].get('input', {})
if not params:
return res
for name, val in params.iteritems():
if expr.is_expression(val):
res[name] = expr.evaluate(val, context)
else:
res[name] = val
return res
def prepare_tasks(tasks, context):
for task in tasks:
# TODO(rakhmerov): Inbound context should be a merge of outbound
# contexts of task dependencies, if any.
task['in_context'] = context
task['input'] = evaluate_task_input(task, context)
db_api.task_update(task['workbook_name'],
task['execution_id'],
task['id'],
{'in_context': task['in_context'],
'input': task['input']})
def merge_into_context(context, values):
if not context:
return None
# TODO(rakhmerov): Take care of nested variables.
context.update(values)
return context

View File

@ -53,5 +53,10 @@ class YAQLEvaluator(Evaluator):
_EVALUATOR = YAQLEvaluator() _EVALUATOR = YAQLEvaluator()
def is_expression(s):
# TODO(rakhmerov): It should be generalized since it may not be YAQL.
return s and s.startswith('$.')
def evaluate(expression, context): def evaluate(expression, context):
return _EVALUATOR.evaluate(expression, context) return _EVALUATOR.evaluate(expression, context)

View File

@ -35,26 +35,32 @@ class LocalEngine(abs_eng.AbstractEngine):
@classmethod @classmethod
def _run_task(cls, task): def _run_task(cls, task):
action = a_f.create_action(task) action = a_f.create_action(task)
LOG.info("Task is started - %s" % task['name']) LOG.info("Task is started - %s" % task['name'])
db_api.task_update(task['workbook_name'], task['execution_id'],
task['id'], {'state': states.RUNNING})
if a_h.is_task_synchronous(task): if a_h.is_task_synchronous(task):
# In case of sync execution we run task try:
# and change state right after that. state, result = states.SUCCESS, action.run()
action_result = action.run() except exc.ActionException:
state, result = a_h.extract_state_result(action, action_result) state, result = states.ERROR, None
# TODO(nmakhotkin) save the result in the context with key
# action.result_helper['store_as']
if states.is_valid(state): cls.convey_task_result(task['workbook_name'],
return cls.convey_task_result(task['workbook_name'],
task['execution_id'], task['execution_id'],
task['id'], state, result) task['id'],
state, result)
else: else:
raise exc.EngineException("Action has returned invalid " try:
"state: %s" % state) action.run()
return action.run() db_api.task_update(task['workbook_name'],
task['execution_id'],
task['id'],
{'state': states.RUNNING})
except exc.ActionException:
cls.convey_task_result(task['workbook_name'],
task['execution_id'],
task['id'],
states.ERROR, None)
def get_engine(): def get_engine():

View File

@ -27,30 +27,37 @@ from mistral.engine.actions import action_helper as a_h
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
# TODO(rakhmerov): Upcoming Data Flow changes:
# 1. Receive "in_context" along with task data.
# 2. Apply task input expression to "in_context" and calculate "input".
def do_task_action(task): def do_task_action(task):
LOG.info("Starting task action [task_id=%s, action='%s', service='%s'" % LOG.info("Starting task action [task_id=%s, action='%s', service='%s'" %
(task['id'], task['task_dsl']['action'], task['service_dsl'])) (task['id'], task['task_dsl']['action'], task['service_dsl']))
action = a_f.create_action(task)
if a_h.is_task_synchronous(task):
action_result = action.run()
state, result = a_h.extract_state_result(action, action_result)
# TODO(nmakhotkin) save the result in the context with key
# action.result_helper['store_as']
if states.is_valid(state): action = a_f.create_action(task)
return engine.convey_task_result(task['workbook_name'],
if a_h.is_task_synchronous(task):
try:
state, result = states.SUCCESS, action.run()
except exc.ActionException:
state, result = states.ERROR, None
engine.convey_task_result(task['workbook_name'],
task['execution_id'], task['execution_id'],
task['id'], state, result) task['id'],
state, result)
else: else:
raise exc.EngineException("Action has returned invalid " try:
"state: %s" % state)
action.run() action.run()
db_api.task_update(task['workbook_name'],
task['execution_id'],
task['id'],
{'state': states.RUNNING})
except exc.ActionException:
engine.convey_task_result(task['workbook_name'],
task['execution_id'],
task['id'],
states.ERROR, None)
def handle_task_error(task, exception): def handle_task_error(task, exception):
try: try:

View File

@ -89,10 +89,13 @@ def find_tasks_after_completion(task, wb_dsl):
found_tasks += _get_tasks_to_schedule(tasks_on_finish, wb_dsl) found_tasks += _get_tasks_to_schedule(tasks_on_finish, wb_dsl)
LOG.debug("Found tasks: %s" % found_tasks) LOG.debug("Found tasks: %s" % found_tasks)
workflow_tasks = [] workflow_tasks = []
for t in found_tasks: for t in found_tasks:
workflow_tasks += find_workflow_tasks(wb_dsl, t['name']) workflow_tasks += find_workflow_tasks(wb_dsl, t['name'])
LOG.debug("Workflow tasks to schedule: %s" % workflow_tasks) LOG.debug("Workflow tasks to schedule: %s" % workflow_tasks)
return workflow_tasks return workflow_tasks

View File

@ -59,6 +59,15 @@ class DBDuplicateEntry(MistralException):
self.message = message self.message = message
class ActionException(MistralException):
code = "ACTION_ERROR"
def __init__(self, message=None):
super(MistralException, self).__init__(message)
if message:
self.message = message
class EngineException(MistralException): class EngineException(MistralException):
code = "ENGINE_ERROR" code = "ENGINE_ERROR"

View File

@ -21,7 +21,7 @@ from webtest.app import AppError
from oslo.config import cfg from oslo.config import cfg
from mistral.openstack.common import importutils from mistral.openstack.common import importutils
from mistral.tests.unit import base as test_base from mistral.tests import base
# We need to make sure that all configuration properties are registered. # We need to make sure that all configuration properties are registered.
importutils.import_module("mistral.config") importutils.import_module("mistral.config")
@ -34,7 +34,7 @@ cfg.CONF.register_opt(cfg.BoolOpt('auth_enable', default=False), group='pecan')
__all__ = ['FunctionalTest'] __all__ = ['FunctionalTest']
class FunctionalTest(test_base.DbTestCase): class FunctionalTest(base.DbTestCase):
"""Used for functional tests where you need to test your """Used for functional tests where you need to test your
literal application and its integration with the framework. literal application and its integration with the framework.
""" """

View File

@ -15,10 +15,25 @@
# limitations under the License. # limitations under the License.
import unittest2 import unittest2
import pkg_resources as pkg
import os
import tempfile
from mistral import version
from mistral.db.sqlalchemy import api as db_api
from mistral.openstack.common.db.sqlalchemy import session
RESOURCES_PATH = 'tests/resources/'
def get_resource(resource_name):
return open(pkg.resource_filename(
version.version_info.package,
RESOURCES_PATH + resource_name)).read()
class BaseTest(unittest2.TestCase): class BaseTest(unittest2.TestCase):
def setUp(self): def setUp(self):
super(BaseTest, self).setUp() super(BaseTest, self).setUp()
@ -28,3 +43,19 @@ class BaseTest(unittest2.TestCase):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
# TODO: add whatever is needed for all Mistral tests in here # TODO: add whatever is needed for all Mistral tests in here
class DbTestCase(BaseTest):
def setUp(self):
self.db_fd, self.db_path = tempfile.mkstemp()
session.set_defaults('sqlite:///' + self.db_path, self.db_path)
db_api.setup_db()
def tearDown(self):
db_api.drop_db()
os.close(self.db_fd)
os.unlink(self.db_path)
def is_db_session_open(self):
return db_api._get_thread_local_session() is not None

View File

@ -0,0 +1,39 @@
Services:
MyService:
type: ECHO
actions:
build_full_name:
output:
full_name: John Doe
build_greeting:
output:
greeting: Hello, John Doe!
Workflow:
# context = {
# 'person': {
# 'first_name': 'John',
# 'last_name': 'Doe',
# 'address': {
# 'street': '124352 Broadway Street',
# 'city': 'Gloomington',
# 'country': 'USA'
# }
# }
# }
tasks:
build_full_name:
action: MyService:build_full_name
input:
first_name: $.person.first_name
last_name: $.person.last_name
output: full_name
build_greeting:
requires: [build_full_name]
action: MyService:build_greeting
input:
full_name: $.full_name
output: greeting

View File

@ -17,7 +17,7 @@
from mistral.openstack.common import timeutils from mistral.openstack.common import timeutils
from mistral.db.sqlalchemy import api as db_api from mistral.db.sqlalchemy import api as db_api
from mistral.tests.unit import base as test_base from mistral.tests import base as test_base
EVENTS = [ EVENTS = [

View File

@ -14,15 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import mock
import unittest2 import unittest2
from mistral.engine.actions import actions
from mistral.engine.actions import action_factory from mistral.engine.actions import action_factory
from mistral.engine.actions import action_helper
from mistral.engine.actions import action_types from mistral.engine.actions import action_types
from mistral.engine import states
SAMPLE_TASK = { SAMPLE_TASK = {
@ -126,31 +121,3 @@ class ActionFactoryTest(unittest2.TestCase):
self.assertIn(email, action.to) self.assertIn(email, action.to)
self.assertEqual(task['service_dsl']['parameters']['smtp_server'], self.assertEqual(task['service_dsl']['parameters']['smtp_server'],
action.smtp_server) action.smtp_server)
@mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value=SAMPLE_RESULT))
def test_action_result_with_results(self):
task = copy.deepcopy(SAMPLE_TASK)
task['service_dsl'].update({'type': action_types.REST_API})
create_vm = task['service_dsl']['actions']['create-vm']
create_vm.update(SAMPLE_RESULT_HELPER)
action = action_factory.create_action(task)
action_result = action.run()
action.status = 200
state, result = action_helper.extract_state_result(action,
action_result)
self.assertEqual(state, states.SUCCESS)
self.assertEqual(result, SAMPLE_RESULT['server']['id'])
@mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value=SAMPLE_RESULT))
def test_action_result_without_results(self):
task = copy.deepcopy(SAMPLE_TASK)
task['service_dsl'].update({'type': action_types.REST_API})
action = action_factory.create_action(task)
action_result = action.run()
action.status = 200
state, result = action_helper.extract_state_result(action,
action_result)
self.assertEqual(state, states.SUCCESS)
self.assertEqual(result, SAMPLE_RESULT)

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright 2013 - Mirantis, Inc. # Copyright 2013 - StackStorm, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,27 +14,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mistral.engine.actions import actions
from mistral.engine.actions import action_types
import os from mistral.tests import base
import tempfile
import unittest2
from mistral.db.sqlalchemy import api as db_api
from mistral.openstack.common.db.sqlalchemy import session
class DbTestCase(unittest2.TestCase): class FakeActionTest(base.BaseTest):
def setUp(self): def test_send_email_real(self):
self.db_fd, self.db_path = tempfile.mkstemp() expected = "my output"
session.set_defaults('sqlite:///' + self.db_path, self.db_path)
db_api.setup_db()
def tearDown(self): action = actions.EchoAction(action_types.ECHO, "test", output=expected)
db_api.drop_db()
os.close(self.db_fd)
os.unlink(self.db_path)
def is_db_session_open(self): self.assertEqual(action.run(), expected)
return db_api._get_thread_local_session() is not None

View File

@ -22,6 +22,7 @@ from email.parser import Parser
from mistral.engine.actions import actions from mistral.engine.actions import actions
from mistral.engine.actions import action_types from mistral.engine.actions import action_types
from mistral import exceptions as exc
ACTION_TYPE = action_types.SEND_EMAIL ACTION_TYPE = action_types.SEND_EMAIL
ACTION_NAME = "TEMPORARY" ACTION_NAME = "TEMPORARY"
@ -117,9 +118,7 @@ class SendEmailActionTest(unittest2.TestCase):
ACTION_TYPE, ACTION_NAME, self.params, self.settings) ACTION_TYPE, ACTION_NAME, self.params, self.settings)
try: try:
action.run() action.run()
except IOError: except exc.ActionException:
pass pass
else: else:
self.assertFalse("Must throw exception") self.assertFalse("Must throw exception")
self.assertTrue(log.error.called)

View File

@ -14,35 +14,26 @@
# limitations under the License. # limitations under the License.
import mock import mock
import pkg_resources as pkg
from mistral.db import api as db_api from mistral.db import api as db_api
from mistral.engine.actions import actions from mistral.engine.actions import actions
from mistral.engine.local import engine from mistral.engine.local import engine
from mistral.engine import states from mistral.engine import states
from mistral import version from mistral.tests import base
from mistral.tests.unit import base as test_base
ENGINE = engine.get_engine() ENGINE = engine.get_engine()
CFG_PREFIX = "tests/resources/"
WB_NAME = "my_workbook" WB_NAME = "my_workbook"
CONTEXT = None # TODO(rakhmerov): Use a meaningful value. CONTEXT = None # TODO(rakhmerov): Use a meaningful value.
#TODO(rakhmerov): add more tests for errors, execution stop etc. #TODO(rakhmerov): add more tests for errors, execution stop etc.
def get_cfg(cfg_suffix): class TestLocalEngine(base.DbTestCase):
return open(pkg.resource_filename(
version.version_info.package,
CFG_PREFIX + cfg_suffix)).read()
class TestLocalEngine(test_base.DbTestCase):
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(actions.RestAction, "run", @mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value={'state': states.RUNNING})) mock.MagicMock(return_value={'state': states.RUNNING}))
@ -63,7 +54,7 @@ class TestLocalEngine(test_base.DbTestCase):
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(actions.RestAction, "run", @mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value={'state': states.RUNNING})) mock.MagicMock(return_value={'state': states.RUNNING}))
@ -105,7 +96,7 @@ class TestLocalEngine(test_base.DbTestCase):
mock.MagicMock(return_value={'state': states.SUCCESS})) mock.MagicMock(return_value={'state': states.SUCCESS}))
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(states, "get_state_by_http_status_code", @mock.patch.object(states, "get_state_by_http_status_code",
mock.MagicMock(return_value=states.SUCCESS)) mock.MagicMock(return_value=states.SUCCESS))
@ -121,7 +112,7 @@ class TestLocalEngine(test_base.DbTestCase):
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(actions.RestAction, "run", @mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value={'state': states.SUCCESS})) mock.MagicMock(return_value={'state': states.SUCCESS}))
@ -172,7 +163,7 @@ class TestLocalEngine(test_base.DbTestCase):
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(actions.RestAction, "run", @mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value={'state': states.SUCCESS})) mock.MagicMock(return_value={'state': states.SUCCESS}))

View File

@ -14,19 +14,16 @@
# limitations under the License. # limitations under the License.
import mock import mock
import pkg_resources as pkg
from mistral.db import api as db_api from mistral.db import api as db_api
from mistral.engine.actions import actions from mistral.engine.actions import actions
from mistral.engine.scalable import engine from mistral.engine.scalable import engine
from mistral.engine import states from mistral.engine import states
from mistral import version from mistral.tests import base
from mistral.tests.unit import base as test_base
ENGINE = engine.get_engine() ENGINE = engine.get_engine()
CFG_PREFIX = "tests/resources/"
WB_NAME = "my_workbook" WB_NAME = "my_workbook"
CONTEXT = None # TODO(rakhmerov): Use a meaningful value. CONTEXT = None # TODO(rakhmerov): Use a meaningful value.
@ -34,18 +31,12 @@ CONTEXT = None # TODO(rakhmerov): Use a meaningful value.
#TODO(rakhmerov): add more tests for errors, execution stop etc. #TODO(rakhmerov): add more tests for errors, execution stop etc.
def get_cfg(cfg_suffix): class TestScalableEngine(base.DbTestCase):
return open(pkg.resource_filename(
version.version_info.package,
CFG_PREFIX + cfg_suffix)).read()
class TestScalableEngine(test_base.DbTestCase):
@mock.patch.object(engine.ScalableEngine, "_notify_task_executors", @mock.patch.object(engine.ScalableEngine, "_notify_task_executors",
mock.MagicMock(return_value="")) mock.MagicMock(return_value=""))
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(actions.RestAction, "run", @mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value="result")) mock.MagicMock(return_value="result"))
@ -68,7 +59,7 @@ class TestScalableEngine(test_base.DbTestCase):
mock.MagicMock(return_value="")) mock.MagicMock(return_value=""))
@mock.patch.object(db_api, "workbook_get", @mock.patch.object(db_api, "workbook_get",
mock.MagicMock(return_value={ mock.MagicMock(return_value={
'definition': get_cfg("test_rest.yaml") 'definition': base.get_resource("test_rest.yaml")
})) }))
@mock.patch.object(actions.RestAction, "run", @mock.patch.object(actions.RestAction, "run",
mock.MagicMock(return_value="result")) mock.MagicMock(return_value="result"))

View File

@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mistral.tests.unit import base as test_base from mistral.tests import base
class TestTaskExecutor(test_base.DbTestCase): class TestTaskExecutor(base.DbTestCase):
def setUp(self): def setUp(self):
super(TestTaskExecutor, self).setUp() super(TestTaskExecutor, self).setUp()
self.wb_name = "my_workbook" self.wb_name = "my_workbook"

View File

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.openstack.common import log as logging
from mistral.tests import base
from mistral.db import api as db_api
from mistral.engine.local import engine
from mistral.engine import states
LOG = logging.getLogger(__name__)
ENGINE = engine.get_engine()()
CONTEXT = {
'person': {
'first_name': 'John',
'last_name': 'Doe',
'address': {
'street': '124352 Broadway Street',
'city': 'Gloomington',
'country': 'USA'
}
}
}
def create_workbook(definition_path):
return db_api.workbook_create({
'name': 'my_workbook',
'definition': base.get_resource(definition_path)
})
class DataFlowTest(base.DbTestCase):
def test_two_dependent_tasks(self):
wb = create_workbook('data_flow/two_dependent_tasks.yaml')
execution = ENGINE.start_workflow_execution(wb['name'],
'build_greeting',
CONTEXT)
# We have to reread execution to get its latest version.
execution = db_api.execution_get(execution['workbook_name'],
execution['id'])
self.assertEqual(execution['state'], states.SUCCESS)
self.assertDictEqual(execution['context'], CONTEXT)
tasks = db_api.tasks_get(wb['name'], execution['id'])
self.assertEqual(2, len(tasks))
if tasks[0]['name'] == 'build_full_name':
build_full_name_task = tasks[0]
build_greeting_task = tasks[1]
else:
build_full_name_task = tasks[1]
build_greeting_task = tasks[0]
self.assertEqual(build_full_name_task['name'], 'build_full_name')
self.assertEqual(build_greeting_task['name'], 'build_greeting')
# Check the first task.
self.assertEqual(states.SUCCESS, build_full_name_task['state'])
self.assertDictEqual(CONTEXT, build_full_name_task['in_context'])
self.assertDictEqual({'first_name': 'John', 'last_name': 'Doe'},
build_full_name_task['input'])
self.assertDictEqual({'full_name': 'John Doe'},
build_full_name_task['output'])
# Check the second task.
in_context = CONTEXT.copy()
in_context['full_name'] = 'John Doe'
self.assertEqual(states.SUCCESS, build_greeting_task['state'])
self.assertDictEqual(in_context, build_greeting_task['in_context'])
self.assertDictEqual({'full_name': 'John Doe'},
build_greeting_task['input'])
self.assertDictEqual({'greeting': 'Hello, John Doe!'},
build_greeting_task['output'])
# TODO(rakhmerov): add more checks

View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.engine import data_flow
from mistral.tests import base
from mistral.db import api as db_api
from mistral.openstack.common import log as logging
LOG = logging.getLogger(__name__)
WB_NAME = 'my_workbook'
EXEC_ID = '1'
CONTEXT = {
'param1': 'val1',
'param2': 'val2',
'param3': {
'param31': 'val31',
'param32': 'val32'
}
}
TASK = {
'workbook_name': WB_NAME,
'execution_id': EXEC_ID,
'name': 'my_task',
'task_dsl': {
'input': {
'p1': 'My string',
'p2': '$.param3.param32'
}
}
}
class DataFlowTest(base.DbTestCase):
def test_prepare_task_input(self):
input = data_flow.evaluate_task_input(TASK, CONTEXT)
self.assertEqual(len(input), 2)
self.assertEqual(input['p1'], 'My string')
self.assertEqual(input['p2'], 'val32')
def test_prepare_tasks(self):
task = db_api.task_create(WB_NAME, EXEC_ID, TASK.copy())
tasks = [task]
data_flow.prepare_tasks(tasks, CONTEXT)
db_task = db_api.task_get(WB_NAME, EXEC_ID, tasks[0]['id'])
self.assertDictEqual(db_task['in_context'], CONTEXT)
self.assertDictEqual(db_task['input'], {
'p1': 'My string',
'p2': 'val32'
})
def test_merge_into_context(self):
ctx = data_flow.merge_into_context(CONTEXT.copy(),
{'new_key1': 'new_val1'})
self.assertEqual(ctx['new_key1'], 'new_val1')

View File

@ -18,7 +18,7 @@ import pkg_resources as pkg
from mistral import dsl from mistral import dsl
from mistral import version from mistral import version
from mistral.tests.unit import base from mistral.tests import base
from mistral.engine import states from mistral.engine import states
from mistral.engine import workflow from mistral.engine import workflow

View File

@ -18,7 +18,7 @@ import pkg_resources as pkg
from mistral.db import api as db_api from mistral.db import api as db_api
from mistral import dsl from mistral import dsl
from mistral.tests.unit import base from mistral.tests import base
from mistral import version from mistral import version
from mistral.services import scheduler from mistral.services import scheduler

View File

@ -20,7 +20,7 @@ from datetime import timedelta
from mistral.openstack.common import timeutils from mistral.openstack.common import timeutils
from mistral.services import scheduler as s from mistral.services import scheduler as s
from mistral.tests.unit import base as test_base from mistral.tests import base as test_base
SAMPLE_EVENT = { SAMPLE_EVENT = {