Add engine related features

* Workflow in scalable engine
  * find tasks to run
  * find workflow tasks
 * Added unit tests
 * Fix events test

Change-Id: Ib3a97f976b101a68cbbde9d2117f5b2ca5eab5cd
This commit is contained in:
Nikolay Mahotkin 2013-12-12 16:19:39 +04:00
parent bb6dd45ec4
commit 41e21c7616
7 changed files with 195 additions and 60 deletions

View File

@ -31,8 +31,14 @@ class Parser(object):
raise RuntimeError("Definition could not be parsed: %s\n" raise RuntimeError("Definition could not be parsed: %s\n"
% exc.message) % exc.message)
def get_service(self): def get_services(self):
return self.doc["Service"] services = []
for service_name in self.doc["Services"]:
services.append(self.doc["Services"][service_name])
return services
def get_service(self, service_name):
return self.doc["Services"][service_name]
def get_events(self): def get_events(self):
events_from_doc = self.doc["Workflow"]["events"] events_from_doc = self.doc["Workflow"]["events"]
@ -46,12 +52,20 @@ class Parser(object):
def get_tasks(self): def get_tasks(self):
return self.doc["Workflow"]["tasks"] return self.doc["Workflow"]["tasks"]
def get_action(self, action_name): def get_action(self, task_action_name):
# TODO(rakhmerov): it needs to return action definition as a dict service_name = task_action_name.split(':')[0]
pass action_name = task_action_name.split(':')[1]
action = self.get_service(service_name)['actions'][action_name]
return action
def get_service_name(self): def get_actions(self, service_name):
return self.doc['Service']['name'] return self.get_service(service_name)['actions']
def get_service_names(self):
names = []
for name in self.doc['Services']:
names.append(name)
return names
def get_event_task_name(self, event_name): def get_event_task_name(self, event_name):
return self.doc["Workflow"]["events"][event_name]['tasks'] return self.doc["Workflow"]["events"][event_name]['tasks']

View File

@ -14,18 +14,72 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import networkx as nx
from networkx.algorithms import traversal
from mistral.engine import states
def find_workflow_tasks(wb_dsl, target_task_name): def find_workflow_tasks(wb_dsl, target_task_name):
# TODO(rakhmerov): implement using networkX dsl_tasks = wb_dsl.get_tasks()
return None full_graph = nx.DiGraph()
for t in dsl_tasks:
full_graph.add_node(t)
def find_tasks_to_start(tasks): _update_dependencies(dsl_tasks, full_graph)
# TODO(rakhmerov): implement using networkX graph = _get_subgraph(full_graph, target_task_name)
# We need to analyse graph and see which tasks are ready to start tasks = []
for node in graph:
task = {'name': node}
task.update(dsl_tasks[node])
tasks.append(task)
return tasks return tasks
def find_tasks_to_start(tasks):
# We need to analyse graph and see which tasks are ready to start
return _get_resolved_tasks(tasks)
def is_finished(tasks): def is_finished(tasks):
# TODO(rakhmerov): implement for task in tasks:
if not states.is_finished(task['state']):
return False return False
return True
def _get_subgraph(full_graph, task_name):
nodes_set = traversal.dfs_predecessors(full_graph.reverse(),
task_name).keys()
nodes_set.append(task_name)
return full_graph.subgraph(nodes_set)
def _get_dependency_tasks(tasks, task):
if 'dependsOn' not in tasks[task]:
return []
deps = set()
for t in tasks:
for dep in tasks[task]['dependsOn']:
if dep == t:
deps.add(t)
return deps
def _update_dependencies(tasks, graph):
for task in tasks:
for dep in _get_dependency_tasks(tasks, task):
graph.add_edge(dep, task)
def _get_resolved_tasks(tasks):
resolved_tasks = []
allows = []
for t in tasks:
if t['state'] == states.SUCCESS:
allows += t['dependencies']
allow_set = set(allows)
for t in tasks:
if len(allow_set - set(t['dependencies'])) == 0:
if t['state'] == states.IDLE:
resolved_tasks.append(t)
return resolved_tasks

View File

@ -1,5 +1,5 @@
Service: Services:
name: MyRest MyRest:
type: REST_API type: REST_API
parameters: parameters:
baseUrl: http://some_host baseUrl: http://some_host

View File

@ -24,7 +24,8 @@ SAMPLE_EVENT = {
"id": "123", "id": "123",
"name": "test_event", "name": "test_event",
"pattern": "* *", "pattern": "* *",
"next_execution_time": timeutils.utcnow() "next_execution_time": timeutils.utcnow(),
'workbook_name': 'wb_name'
} }

View File

@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pkg_resources as pkg
from mistral import dsl
from mistral import version
from mistral.tests.unit import base
from mistral.engine import states
from mistral.engine.scalable import workflow
TASKS = [
{
'dependencies': [],
'name': 'backup-vms',
'state': states.IDLE
},
{
'dependencies': [],
'name': 'create-vms',
'state': states.RUNNING
},
{
'dependencies': ['create-vms'],
'name': 'attach-volume',
'state': states.IDLE
}
]
class WorkflowTest(base.DbTestCase):
def setUp(self):
super(WorkflowTest, self).setUp()
self.doc = open(pkg.resource_filename(
version.version_info.package,
"tests/resources/test_rest.yaml")).read()
self.parser = dsl.Parser(self.doc)
def test_find_workflow_tasks(self):
tasks = workflow.find_workflow_tasks(self.parser, "attach-volumes")
self.assertEqual(tasks[1]['name'], 'create-vms')
def test_tasks_to_start(self):
tasks_to_start = workflow.find_tasks_to_start(TASKS)
self.assertEqual(len(tasks_to_start), 2)

View File

@ -27,11 +27,14 @@ class DSLParserTest(unittest2.TestCase):
"tests/resources/test_rest.yaml")).read() "tests/resources/test_rest.yaml")).read()
self.dsl = dsl.Parser(doc) self.dsl = dsl.Parser(doc)
def test_service(self): def test_services(self):
service = self.dsl.get_service() service = self.dsl.get_service("MyRest")
self.assertEqual(service["name"], "MyRest")
self.assertEqual(service["type"], "REST_API") self.assertEqual(service["type"], "REST_API")
self.assertIn("baseUrl", service["parameters"]) self.assertIn("baseUrl", service["parameters"])
services = self.dsl.get_services()
self.assertEqual(len(services), 1)
service_names = self.dsl.get_service_names()
self.assertEqual(service_names[0], "MyRest")
def test_events(self): def test_events(self):
events = self.dsl.get_events() events = self.dsl.get_events()
@ -44,6 +47,12 @@ class DSLParserTest(unittest2.TestCase):
self.assertEqual(tasks["backup-vms"]["action"], self.assertEqual(tasks["backup-vms"]["action"],
"Nova:backup-vm") "Nova:backup-vm")
def test_actions(self):
action = self.dsl.get_action("MyRest:attach-volume")
self.assertIn("method", action["parameters"])
actions = self.dsl.get_actions("MyRest")
self.assertIn("task-parameters", actions["attach-volume"])
def test_broken_definition(self): def test_broken_definition(self):
broken_yaml = """ broken_yaml = """
Workflow: Workflow: