Introduce remote tasks cache for worker-executor

Updated WorkerTaskExecutor to use cache for remote tasks.

Change-Id: I4572052f63647472367cb69fc02911bbec2bd4cc
This commit is contained in:
Stanislav Kudriashev
2014-03-06 18:47:35 +02:00
parent 014fc96d5a
commit ae28f0ba4b
3 changed files with 107 additions and 62 deletions

View File

@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
import six
from taskflow.utils import lock_utils as lu
LOG = logging.getLogger(__name__)
class Cache(object):
    """Represents thread-safe cache.

    Every access to the backing dictionary is performed while holding a
    reader/writer lock: lookups take the read lock, mutations take the
    write lock. Values that should be eligible for :meth:`cleanup` must
    expose an ``expired`` attribute.
    """

    def __init__(self):
        # Maps key -> cached value; only touched while holding ``_lock``.
        self._data = {}
        self._lock = lu.ReaderWriterLock()

    def get(self, key):
        """Retrieve a value from the cache.

        :param key: key to look up
        :returns: the stored value, or ``None`` when the key is absent
        """
        with self._lock.read_lock():
            return self._data.get(key)

    def set(self, key, value):
        """Set a value in the cache, replacing any value stored under key.

        :param key: key to store the value under
        :param value: value to store
        """
        with self._lock.write_lock():
            self._data[key] = value
            LOG.debug("Cache updated. Capacity: %s", len(self._data))

    def delete(self, key):
        """Delete a value from the cache.

        Deleting a key that is not present is a no-op.

        :param key: key to remove
        """
        with self._lock.write_lock():
            self._data.pop(key, None)

    def cleanup(self, on_expired_callback=None):
        """Delete out-dated values from the cache.

        Every entry whose value has a truthy ``expired`` attribute is
        removed. The removal happens under the write lock; the callback is
        only invoked afterwards, once the lock has been released, so that a
        slow callback does not block other users of the cache and so that
        an entry is already gone by the time its callback runs.

        :param on_expired_callback: optional callable invoked once per
            evicted value, with that value as its only argument. If it
            raises, the exception propagates to the caller and the
            callbacks for the remaining evicted values are skipped (the
            entries themselves have already been removed).
        """
        with self._lock.write_lock():
            # Snapshot first: the dictionary must not change size while it
            # is being iterated.
            expired_values = [(k, v) for k, v in six.iteritems(self._data)
                              if v.expired]
            for k, _v in expired_values:
                self._data.pop(k, None)
        if on_expired_callback:
            for _k, v in expired_values:
                on_expired_callback(v)

View File

@@ -15,12 +15,14 @@
# under the License.
import logging
import six
import threading
import six
from kombu import exceptions as kombu_exc
from taskflow.engines.action_engine import executor
from taskflow.engines.worker_based import cache
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow.engines.worker_based import remote_task as rt
@@ -40,7 +42,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
self._on_wait, **kwargs)
self._proxy_thread = None
self._remote_tasks = {}
self._remote_tasks_cache = cache.Cache()
# TODO(skudriashev): This data should be collected from workers
# using broadcast messages directly.
@@ -78,55 +80,41 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
def _process_response(self, task_uuid, response):
"""Process response from remote side."""
try:
task = self._remote_tasks[task_uuid]
except KeyError:
LOG.debug("Task with id='%s' not found.", task_uuid)
else:
remote_task = self._remote_tasks_cache.get(task_uuid)
if remote_task is not None:
state = response.pop('state')
if state == pr.RUNNING:
task.set_running()
remote_task.set_running()
elif state == pr.PROGRESS:
task.on_progress(**response)
remote_task.on_progress(**response)
elif state == pr.FAILURE:
response['result'] = pu.failure_from_dict(response['result'])
task.set_result(**response)
self._remove_remote_task(task)
remote_task.set_result(**response)
self._remote_tasks_cache.delete(remote_task.uuid)
elif state == pr.SUCCESS:
task.set_result(**response)
self._remove_remote_task(task)
remote_task.set_result(**response)
self._remote_tasks_cache.delete(remote_task.uuid)
else:
LOG.warning("Unexpected response status: '%s'", state)
else:
LOG.debug("Remote task with id='%s' not found.", task_uuid)
@staticmethod
def _handle_expired_remote_task(task):
    """Mark an expired remote task as failed with a timeout.

    Logs the expiry at debug level and then stores, as the task's result,
    a failure built from an ``exc.Timeout`` exception.

    :param task: the remote task that has expired; it must provide
        ``set_result``
    """
    LOG.debug("Remote task '%r' has expired.", task)
    timeout_error = exc.Timeout("Remote task '%r' has expired" % task)
    timeout_failure = misc.Failure.from_exception(timeout_error)
    task.set_result(timeout_failure)
def _on_wait(self):
"""This function is called cyclically between draining events
iterations to clean-up expired task requests.
"""
expired_tasks = [task for task in six.itervalues(self._remote_tasks)
if task.expired]
for task in expired_tasks:
LOG.debug("Task request '%s' has expired.", task)
task.set_result(misc.Failure.from_exception(
exc.Timeout("Task request '%s' has expired" % task)))
del self._remote_tasks[task.uuid]
def _store_remote_task(self, task):
"""Store task in the remote tasks map."""
self._remote_tasks[task.uuid] = task
return task
def _remove_remote_task(self, task):
"""Remove remote task from the tasks map."""
if task.uuid in self._remote_tasks:
del self._remote_tasks[task.uuid]
"""This function is called cyclically between draining events."""
self._remote_tasks_cache.cleanup(self._handle_expired_remote_task)
def _submit_task(self, task, task_uuid, action, arguments,
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
"""Submit task request to workers."""
remote_task = self._store_remote_task(
rt.RemoteTask(task, task_uuid, action, arguments,
progress_callback, timeout, **kwargs)
)
remote_task = rt.RemoteTask(task, task_uuid, action, arguments,
progress_callback, timeout, **kwargs)
self._remote_tasks_cache.set(remote_task.uuid, remote_task)
try:
# get task's workers topic to send request to
try:
@@ -143,7 +131,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
except Exception:
with misc.capture_failure() as failure:
LOG.exception("Failed to submit the '%s' task", remote_task)
self._remove_remote_task(remote_task)
self._remote_tasks_cache.delete(remote_task.uuid)
remote_task.set_result(failure)
return remote_task.result

View File

@@ -105,7 +105,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_state_running(self):
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
@@ -115,7 +115,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_state_progress(self):
response = dict(state=pr.PROGRESS, progress=1.0)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
@@ -127,10 +127,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
failure_dict = pu.failure_to_dict(failure)
response = dict(state=pr.FAILURE, result=failure_dict)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(len(ex._remote_tasks), 0)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
self.assertEqual(self.remote_task_mock.mock_calls, [
mock.call.set_result(result=utils.FailureMatcher(failure))
])
@@ -140,7 +140,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
response = dict(state=pr.SUCCESS, result=self.task_result,
event='executed')
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
@@ -151,7 +151,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_unknown_state(self):
response = dict(state='unknown')
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
@@ -161,7 +161,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.message_mock.properties = {'correlation_id': 'non-existent'}
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
@@ -171,7 +171,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.message_mock.properties = {}
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
@@ -184,37 +184,37 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertTrue(mocked_exception.called)
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
def test_on_wait_task_not_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout]
def test_on_wait_task_not_expired(self, mocked_time):
mocked_time.side_effect = [1, self.timeout]
ex = self.executor()
ex._store_remote_task(self.remote_task())
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
self.assertEqual(len(ex._remote_tasks), 1)
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
ex._on_wait()
self.assertEqual(len(ex._remote_tasks), 1)
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
def test_on_wait_task_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
def test_on_wait_task_expired(self, mocked_time):
mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
ex = self.executor()
ex._store_remote_task(self.remote_task())
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
self.assertEqual(len(ex._remote_tasks), 1)
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
ex._on_wait()
self.assertEqual(len(ex._remote_tasks), 0)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
def test_remove_task_non_existent(self):
task = self.remote_task()
ex = self.executor()
ex._store_remote_task(task)
ex._remote_tasks_cache.set(self.task_uuid, task)
self.assertEqual(len(ex._remote_tasks), 1)
ex._remove_remote_task(task)
self.assertEqual(len(ex._remote_tasks), 0)
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
ex._remote_tasks_cache.delete(self.task_uuid)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
# remove non-existent
ex._remove_remote_task(task)
self.assertEqual(len(ex._remote_tasks), 0)
ex._remote_tasks_cache.delete(self.task_uuid)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
def test_execute_task(self):
request = self.request(action='execute')