Introduce remote tasks cache for worker-executor
Updated WorkerTaskExecutor to use a cache for remote tasks.
Change-Id: I4572052f63647472367cb69fc02911bbec2bd4cc
This commit is contained in:
57
taskflow/engines/worker_based/cache.py
Normal file
57
taskflow/engines/worker_based/cache.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from taskflow.utils import lock_utils as lu
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(object):
    """Key/value store whose accesses are guarded by a reader-writer lock.

    Lookups take the read side of the lock; every mutation (``set``,
    ``delete``, ``cleanup``) takes the write side. Values passed through
    ``cleanup`` must expose an ``expired`` attribute.
    """

    def __init__(self):
        # Plain dict holding the cached entries; only touched under the lock.
        self._data = {}
        self._lock = lu.ReaderWriterLock()

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if absent."""
        with self._lock.read_lock():
            found = self._data.get(key)
            return found

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any previous entry."""
        with self._lock.write_lock():
            self._data[key] = value
            # Report the entry count while still holding the write lock so
            # the logged number matches the state produced by this call.
            entry_count = len(self._data)
            LOG.debug("Cache updated. Capacity: %s", entry_count)

    def delete(self, key):
        """Remove ``key`` from the cache; a missing key is ignored."""
        with self._lock.write_lock():
            self._data.pop(key, None)

    def cleanup(self, on_expired_callback=None):
        """Evict every entry whose value reports itself as expired.

        :param on_expired_callback: optional callable invoked with each
            expired value just before that value is removed. It is called
            while the write lock is held.
        """
        with self._lock.write_lock():
            # Snapshot the expired pairs first so the dict is not mutated
            # while it is being iterated.
            stale = []
            for key, value in six.iteritems(self._data):
                if value.expired:
                    stale.append((key, value))
            for key, value in stale:
                if on_expired_callback:
                    on_expired_callback(value)
                self._data.pop(key, None)
|
||||||
@@ -15,12 +15,14 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import six
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from kombu import exceptions as kombu_exc
|
from kombu import exceptions as kombu_exc
|
||||||
|
|
||||||
from taskflow.engines.action_engine import executor
|
from taskflow.engines.action_engine import executor
|
||||||
|
from taskflow.engines.worker_based import cache
|
||||||
from taskflow.engines.worker_based import protocol as pr
|
from taskflow.engines.worker_based import protocol as pr
|
||||||
from taskflow.engines.worker_based import proxy
|
from taskflow.engines.worker_based import proxy
|
||||||
from taskflow.engines.worker_based import remote_task as rt
|
from taskflow.engines.worker_based import remote_task as rt
|
||||||
@@ -40,7 +42,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
|
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
|
||||||
self._on_wait, **kwargs)
|
self._on_wait, **kwargs)
|
||||||
self._proxy_thread = None
|
self._proxy_thread = None
|
||||||
self._remote_tasks = {}
|
self._remote_tasks_cache = cache.Cache()
|
||||||
|
|
||||||
# TODO(skudriashev): This data should be collected from workers
|
# TODO(skudriashev): This data should be collected from workers
|
||||||
# using broadcast messages directly.
|
# using broadcast messages directly.
|
||||||
@@ -78,55 +80,41 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
|
|
||||||
def _process_response(self, task_uuid, response):
|
def _process_response(self, task_uuid, response):
|
||||||
"""Process response from remote side."""
|
"""Process response from remote side."""
|
||||||
try:
|
remote_task = self._remote_tasks_cache.get(task_uuid)
|
||||||
task = self._remote_tasks[task_uuid]
|
if remote_task is not None:
|
||||||
except KeyError:
|
|
||||||
LOG.debug("Task with id='%s' not found.", task_uuid)
|
|
||||||
else:
|
|
||||||
state = response.pop('state')
|
state = response.pop('state')
|
||||||
if state == pr.RUNNING:
|
if state == pr.RUNNING:
|
||||||
task.set_running()
|
remote_task.set_running()
|
||||||
elif state == pr.PROGRESS:
|
elif state == pr.PROGRESS:
|
||||||
task.on_progress(**response)
|
remote_task.on_progress(**response)
|
||||||
elif state == pr.FAILURE:
|
elif state == pr.FAILURE:
|
||||||
response['result'] = pu.failure_from_dict(response['result'])
|
response['result'] = pu.failure_from_dict(response['result'])
|
||||||
task.set_result(**response)
|
remote_task.set_result(**response)
|
||||||
self._remove_remote_task(task)
|
self._remote_tasks_cache.delete(remote_task.uuid)
|
||||||
elif state == pr.SUCCESS:
|
elif state == pr.SUCCESS:
|
||||||
task.set_result(**response)
|
remote_task.set_result(**response)
|
||||||
self._remove_remote_task(task)
|
self._remote_tasks_cache.delete(remote_task.uuid)
|
||||||
else:
|
else:
|
||||||
LOG.warning("Unexpected response status: '%s'", state)
|
LOG.warning("Unexpected response status: '%s'", state)
|
||||||
|
else:
|
||||||
|
LOG.debug("Remote task with id='%s' not found.", task_uuid)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _handle_expired_remote_task(task):
|
||||||
|
LOG.debug("Remote task '%r' has expired.", task)
|
||||||
|
task.set_result(misc.Failure.from_exception(
|
||||||
|
exc.Timeout("Remote task '%r' has expired" % task)))
|
||||||
|
|
||||||
def _on_wait(self):
|
def _on_wait(self):
|
||||||
"""This function is called cyclically between draining events
|
"""This function is called cyclically between draining events."""
|
||||||
iterations to clean-up expired task requests.
|
self._remote_tasks_cache.cleanup(self._handle_expired_remote_task)
|
||||||
"""
|
|
||||||
expired_tasks = [task for task in six.itervalues(self._remote_tasks)
|
|
||||||
if task.expired]
|
|
||||||
for task in expired_tasks:
|
|
||||||
LOG.debug("Task request '%s' has expired.", task)
|
|
||||||
task.set_result(misc.Failure.from_exception(
|
|
||||||
exc.Timeout("Task request '%s' has expired" % task)))
|
|
||||||
del self._remote_tasks[task.uuid]
|
|
||||||
|
|
||||||
def _store_remote_task(self, task):
|
|
||||||
"""Store task in the remote tasks map."""
|
|
||||||
self._remote_tasks[task.uuid] = task
|
|
||||||
return task
|
|
||||||
|
|
||||||
def _remove_remote_task(self, task):
|
|
||||||
"""Remove remote task from the tasks map."""
|
|
||||||
if task.uuid in self._remote_tasks:
|
|
||||||
del self._remote_tasks[task.uuid]
|
|
||||||
|
|
||||||
def _submit_task(self, task, task_uuid, action, arguments,
|
def _submit_task(self, task, task_uuid, action, arguments,
|
||||||
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
|
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
|
||||||
"""Submit task request to workers."""
|
"""Submit task request to workers."""
|
||||||
remote_task = self._store_remote_task(
|
remote_task = rt.RemoteTask(task, task_uuid, action, arguments,
|
||||||
rt.RemoteTask(task, task_uuid, action, arguments,
|
progress_callback, timeout, **kwargs)
|
||||||
progress_callback, timeout, **kwargs)
|
self._remote_tasks_cache.set(remote_task.uuid, remote_task)
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
# get task's workers topic to send request to
|
# get task's workers topic to send request to
|
||||||
try:
|
try:
|
||||||
@@ -143,7 +131,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
except Exception:
|
except Exception:
|
||||||
with misc.capture_failure() as failure:
|
with misc.capture_failure() as failure:
|
||||||
LOG.exception("Failed to submit the '%s' task", remote_task)
|
LOG.exception("Failed to submit the '%s' task", remote_task)
|
||||||
self._remove_remote_task(remote_task)
|
self._remote_tasks_cache.delete(remote_task.uuid)
|
||||||
remote_task.set_result(failure)
|
remote_task.set_result(failure)
|
||||||
return remote_task.result
|
return remote_task.result
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
def test_on_message_state_running(self):
|
def test_on_message_state_running(self):
|
||||||
response = dict(state=pr.RUNNING)
|
response = dict(state=pr.RUNNING)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||||
@@ -115,7 +115,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
def test_on_message_state_progress(self):
|
def test_on_message_state_progress(self):
|
||||||
response = dict(state=pr.PROGRESS, progress=1.0)
|
response = dict(state=pr.PROGRESS, progress=1.0)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||||
@@ -127,10 +127,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
failure_dict = pu.failure_to_dict(failure)
|
failure_dict = pu.failure_to_dict(failure)
|
||||||
response = dict(state=pr.FAILURE, result=failure_dict)
|
response = dict(state=pr.FAILURE, result=failure_dict)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(len(ex._remote_tasks), 0)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls, [
|
self.assertEqual(self.remote_task_mock.mock_calls, [
|
||||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||||
])
|
])
|
||||||
@@ -140,7 +140,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
response = dict(state=pr.SUCCESS, result=self.task_result,
|
response = dict(state=pr.SUCCESS, result=self.task_result,
|
||||||
event='executed')
|
event='executed')
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||||
@@ -151,7 +151,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
def test_on_message_unknown_state(self):
|
def test_on_message_unknown_state(self):
|
||||||
response = dict(state='unknown')
|
response = dict(state='unknown')
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||||
@@ -161,7 +161,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.message_mock.properties = {'correlation_id': 'non-existent'}
|
self.message_mock.properties = {'correlation_id': 'non-existent'}
|
||||||
response = dict(state=pr.RUNNING)
|
response = dict(state=pr.RUNNING)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||||
@@ -171,7 +171,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.message_mock.properties = {}
|
self.message_mock.properties = {}
|
||||||
response = dict(state=pr.RUNNING)
|
response = dict(state=pr.RUNNING)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task_mock)
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
|
||||||
ex._on_message(response, self.message_mock)
|
ex._on_message(response, self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||||
@@ -184,37 +184,37 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.assertTrue(mocked_exception.called)
|
self.assertTrue(mocked_exception.called)
|
||||||
|
|
||||||
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
|
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
|
||||||
def test_on_wait_task_not_expired(self, mock_time):
|
def test_on_wait_task_not_expired(self, mocked_time):
|
||||||
mock_time.side_effect = [1, self.timeout]
|
mocked_time.side_effect = [1, self.timeout]
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task())
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
|
||||||
|
|
||||||
self.assertEqual(len(ex._remote_tasks), 1)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||||
ex._on_wait()
|
ex._on_wait()
|
||||||
self.assertEqual(len(ex._remote_tasks), 1)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||||
|
|
||||||
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
|
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
|
||||||
def test_on_wait_task_expired(self, mock_time):
|
def test_on_wait_task_expired(self, mocked_time):
|
||||||
mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
|
mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(self.remote_task())
|
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
|
||||||
|
|
||||||
self.assertEqual(len(ex._remote_tasks), 1)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||||
ex._on_wait()
|
ex._on_wait()
|
||||||
self.assertEqual(len(ex._remote_tasks), 0)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||||
|
|
||||||
def test_remove_task_non_existent(self):
|
def test_remove_task_non_existent(self):
|
||||||
task = self.remote_task()
|
task = self.remote_task()
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._store_remote_task(task)
|
ex._remote_tasks_cache.set(self.task_uuid, task)
|
||||||
|
|
||||||
self.assertEqual(len(ex._remote_tasks), 1)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
|
||||||
ex._remove_remote_task(task)
|
ex._remote_tasks_cache.delete(self.task_uuid)
|
||||||
self.assertEqual(len(ex._remote_tasks), 0)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||||
|
|
||||||
# remove non-existent
|
# remove non-existent
|
||||||
ex._remove_remote_task(task)
|
ex._remote_tasks_cache.delete(self.task_uuid)
|
||||||
self.assertEqual(len(ex._remote_tasks), 0)
|
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
|
||||||
|
|
||||||
def test_execute_task(self):
|
def test_execute_task(self):
|
||||||
request = self.request(action='execute')
|
request = self.request(action='execute')
|
||||||
|
|||||||
Reference in New Issue
Block a user