Introduce remote tasks cache for worker-executor
Updated WorkerTaskExecutor to use cache for remote tasks. Change-Id: I4572052f63647472367cb69fc02911bbec2bd4cc
This commit is contained in:
57
taskflow/engines/worker_based/cache.py
Normal file
57
taskflow/engines/worker_based/cache.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import six
|
||||
|
||||
from taskflow.utils import lock_utils as lu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Cache(object):
|
||||
"""Represents thread-safe cache."""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
self._lock = lu.ReaderWriterLock()
|
||||
|
||||
def get(self, key):
|
||||
"""Retrieve a value from the cache."""
|
||||
with self._lock.read_lock():
|
||||
return self._data.get(key)
|
||||
|
||||
def set(self, key, value):
|
||||
"""Set a value in the cache."""
|
||||
with self._lock.write_lock():
|
||||
self._data[key] = value
|
||||
LOG.debug("Cache updated. Capacity: %s", len(self._data))
|
||||
|
||||
def delete(self, key):
|
||||
"""Delete a value from the cache."""
|
||||
with self._lock.write_lock():
|
||||
self._data.pop(key, None)
|
||||
|
||||
def cleanup(self, on_expired_callback=None):
|
||||
"""Delete out-dated values from the cache."""
|
||||
with self._lock.write_lock():
|
||||
expired_values = [(k, v) for k, v in six.iteritems(self._data)
|
||||
if v.expired]
|
||||
for k, v in expired_values:
|
||||
if on_expired_callback:
|
||||
on_expired_callback(v)
|
||||
self._data.pop(k, None)
|
||||
@@ -15,12 +15,14 @@
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import six
|
||||
import threading
|
||||
|
||||
import six
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow.engines.worker_based import cache
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow.engines.worker_based import remote_task as rt
|
||||
@@ -40,7 +42,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
|
||||
self._on_wait, **kwargs)
|
||||
self._proxy_thread = None
|
||||
self._remote_tasks = {}
|
||||
self._remote_tasks_cache = cache.Cache()
|
||||
|
||||
# TODO(skudriashev): This data should be collected from workers
|
||||
# using broadcast messages directly.
|
||||
@@ -78,55 +80,41 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
|
||||
def _process_response(self, task_uuid, response):
|
||||
"""Process response from remote side."""
|
||||
try:
|
||||
task = self._remote_tasks[task_uuid]
|
||||
except KeyError:
|
||||
LOG.debug("Task with id='%s' not found.", task_uuid)
|
||||
else:
|
||||
remote_task = self._remote_tasks_cache.get(task_uuid)
|
||||
if remote_task is not None:
|
||||
state = response.pop('state')
|
||||
if state == pr.RUNNING:
|
||||
task.set_running()
|
||||
remote_task.set_running()
|
||||
elif state == pr.PROGRESS:
|
||||
task.on_progress(**response)
|
||||
remote_task.on_progress(**response)
|
||||
elif state == pr.FAILURE:
|
||||
response['result'] = pu.failure_from_dict(response['result'])
|
||||
task.set_result(**response)
|
||||
self._remove_remote_task(task)
|
||||
remote_task.set_result(**response)
|
||||
self._remote_tasks_cache.delete(remote_task.uuid)
|
||||
elif state == pr.SUCCESS:
|
||||
task.set_result(**response)
|
||||
self._remove_remote_task(task)
|
||||
remote_task.set_result(**response)
|
||||
self._remote_tasks_cache.delete(remote_task.uuid)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'", state)
|
||||
else:
|
||||
LOG.debug("Remote task with id='%s' not found.", task_uuid)
|
||||
|
||||
@staticmethod
|
||||
def _handle_expired_remote_task(task):
|
||||
LOG.debug("Remote task '%r' has expired.", task)
|
||||
task.set_result(misc.Failure.from_exception(
|
||||
exc.Timeout("Remote task '%r' has expired" % task)))
|
||||
|
||||
def _on_wait(self):
|
||||
"""This function is called cyclically between draining events
|
||||
iterations to clean-up expired task requests.
|
||||
"""
|
||||
expired_tasks = [task for task in six.itervalues(self._remote_tasks)
|
||||
if task.expired]
|
||||
for task in expired_tasks:
|
||||
LOG.debug("Task request '%s' has expired.", task)
|
||||
task.set_result(misc.Failure.from_exception(
|
||||
exc.Timeout("Task request '%s' has expired" % task)))
|
||||
del self._remote_tasks[task.uuid]
|
||||
|
||||
def _store_remote_task(self, task):
|
||||
"""Store task in the remote tasks map."""
|
||||
self._remote_tasks[task.uuid] = task
|
||||
return task
|
||||
|
||||
def _remove_remote_task(self, task):
|
||||
"""Remove remote task from the tasks map."""
|
||||
if task.uuid in self._remote_tasks:
|
||||
del self._remote_tasks[task.uuid]
|
||||
"""This function is called cyclically between draining events."""
|
||||
self._remote_tasks_cache.cleanup(self._handle_expired_remote_task)
|
||||
|
||||
def _submit_task(self, task, task_uuid, action, arguments,
|
||||
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
|
||||
"""Submit task request to workers."""
|
||||
remote_task = self._store_remote_task(
|
||||
rt.RemoteTask(task, task_uuid, action, arguments,
|
||||
progress_callback, timeout, **kwargs)
|
||||
)
|
||||
remote_task = rt.RemoteTask(task, task_uuid, action, arguments,
|
||||
progress_callback, timeout, **kwargs)
|
||||
self._remote_tasks_cache.set(remote_task.uuid, remote_task)
|
||||
try:
|
||||
# get task's workers topic to send request to
|
||||
try:
|
||||
@@ -143,7 +131,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
except Exception:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.exception("Failed to submit the '%s' task", remote_task)
|
||||
self._remove_remote_task(remote_task)
|
||||
self._remote_tasks_cache.delete(remote_task.uuid)
|
||||
remote_task.set_result(failure)
|
||||
return remote_task.result
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_state_running(self):
|
||||
response = dict(state=pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||
@@ -115,7 +115,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_state_progress(self):
|
||||
response = dict(state=pr.PROGRESS, progress=1.0)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||
@@ -127,10 +127,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
failure_dict = pu.failure_to_dict(failure)
|
||||
response = dict(state=pr.FAILURE, result=failure_dict)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
@@ -140,7 +140,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
response = dict(state=pr.SUCCESS, result=self.task_result,
|
||||
event='executed')
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||
@@ -151,7 +151,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_unknown_state(self):
|
||||
response = dict(state='unknown')
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||
@@ -161,7 +161,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties = {'correlation_id': 'non-existent'}
|
||||
response = dict(state=pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||
@@ -171,7 +171,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties = {}
|
||||
response = dict(state=pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||
@@ -184,37 +184,37 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertTrue(mocked_exception.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
|
||||
def test_on_wait_task_not_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout]
|
||||
def test_on_wait_task_not_expired(self, mocked_time):
|
||||
mocked_time.side_effect = [1, self.timeout]
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task())
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
|
||||
def test_on_wait_task_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
|
||||
def test_on_wait_task_expired(self, mocked_time):
|
||||
mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task())
|
||||
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||
|
||||
def test_remove_task_non_existent(self):
|
||||
task = self.remote_task()
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(task)
|
||||
ex._remote_tasks_cache.set(self.task_uuid, task)
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
ex._remove_remote_task(task)
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||
ex._remote_tasks_cache.delete(self.task_uuid)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||
|
||||
# remove non-existent
|
||||
ex._remove_remote_task(task)
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
ex._remote_tasks_cache.delete(self.task_uuid)
|
||||
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||
|
||||
def test_execute_task(self):
|
||||
request = self.request(action='execute')
|
||||
|
||||
Reference in New Issue
Block a user