Introduce remote tasks cache for worker-executor
Updated WorkerTaskExecutor to use cache for remote tasks. Change-Id: I4572052f63647472367cb69fc02911bbec2bd4cc
This commit is contained in:
		
							
								
								
									
										57
									
								
								taskflow/engines/worker_based/cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								taskflow/engines/worker_based/cache.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,57 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
#    not use this file except in compliance with the License. You may obtain
 | 
			
		||||
#    a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#         http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
#    Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
#    License for the specific language governing permissions and limitations
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from taskflow.utils import lock_utils as lu
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Cache(object):
    """Represents thread-safe cache.

    Entries are kept in a plain dictionary. Lookups are performed while
    holding the read side of a reader/writer lock and every mutation is
    performed while holding the write side.
    """

    def __init__(self):
        self._data = {}
        self._lock = lu.ReaderWriterLock()

    def get(self, key):
        """Retrieve a value from the cache.

        :param key: key to look up.
        :returns: the stored value, or ``None`` when the key is absent.
        """
        with self._lock.read_lock():
            return self._data.get(key)

    def set(self, key, value):
        """Set a value in the cache.

        Any value previously stored under ``key`` is overwritten.
        """
        with self._lock.write_lock():
            self._data[key] = value
            LOG.debug("Cache updated. Capacity: %s", len(self._data))

    def delete(self, key):
        """Delete a value from the cache.

        Deleting a key that is not present is a no-op.
        """
        with self._lock.write_lock():
            self._data.pop(key, None)

    def cleanup(self, on_expired_callback=None):
        """Delete out-dated values from the cache.

        Every stored value whose ``expired`` attribute is truthy is removed.

        :param on_expired_callback: optional callable invoked once per
            removed value, with that value as its only argument.
        """
        with self._lock.write_lock():
            # Snapshot first so the dictionary is not mutated while it is
            # being iterated.
            expired_values = [(k, v) for k, v in self._data.items()
                              if v.expired]
            for k, _v in expired_values:
                self._data.pop(k, None)
        # Callbacks run only after the write lock has been released: a
        # callback may touch the cache again, and a failing callback must
        # not leave expired entries behind.
        if on_expired_callback:
            for _k, v in expired_values:
                on_expired_callback(v)
 | 
			
		||||
@@ -15,12 +15,14 @@
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import six
 | 
			
		||||
import threading
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from kombu import exceptions as kombu_exc
 | 
			
		||||
 | 
			
		||||
from taskflow.engines.action_engine import executor
 | 
			
		||||
from taskflow.engines.worker_based import cache
 | 
			
		||||
from taskflow.engines.worker_based import protocol as pr
 | 
			
		||||
from taskflow.engines.worker_based import proxy
 | 
			
		||||
from taskflow.engines.worker_based import remote_task as rt
 | 
			
		||||
@@ -40,7 +42,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
 | 
			
		||||
        self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
 | 
			
		||||
                                  self._on_wait, **kwargs)
 | 
			
		||||
        self._proxy_thread = None
 | 
			
		||||
        self._remote_tasks = {}
 | 
			
		||||
        self._remote_tasks_cache = cache.Cache()
 | 
			
		||||
 | 
			
		||||
        # TODO(skudriashev): This data should be collected from workers
 | 
			
		||||
        # using broadcast messages directly.
 | 
			
		||||
@@ -78,55 +80,41 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
 | 
			
		||||
 | 
			
		||||
    def _process_response(self, task_uuid, response):
 | 
			
		||||
        """Process response from remote side."""
 | 
			
		||||
        try:
 | 
			
		||||
            task = self._remote_tasks[task_uuid]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            LOG.debug("Task with id='%s' not found.", task_uuid)
 | 
			
		||||
        else:
 | 
			
		||||
        remote_task = self._remote_tasks_cache.get(task_uuid)
 | 
			
		||||
        if remote_task is not None:
 | 
			
		||||
            state = response.pop('state')
 | 
			
		||||
            if state == pr.RUNNING:
 | 
			
		||||
                task.set_running()
 | 
			
		||||
                remote_task.set_running()
 | 
			
		||||
            elif state == pr.PROGRESS:
 | 
			
		||||
                task.on_progress(**response)
 | 
			
		||||
                remote_task.on_progress(**response)
 | 
			
		||||
            elif state == pr.FAILURE:
 | 
			
		||||
                response['result'] = pu.failure_from_dict(response['result'])
 | 
			
		||||
                task.set_result(**response)
 | 
			
		||||
                self._remove_remote_task(task)
 | 
			
		||||
                remote_task.set_result(**response)
 | 
			
		||||
                self._remote_tasks_cache.delete(remote_task.uuid)
 | 
			
		||||
            elif state == pr.SUCCESS:
 | 
			
		||||
                task.set_result(**response)
 | 
			
		||||
                self._remove_remote_task(task)
 | 
			
		||||
                remote_task.set_result(**response)
 | 
			
		||||
                self._remote_tasks_cache.delete(remote_task.uuid)
 | 
			
		||||
            else:
 | 
			
		||||
                LOG.warning("Unexpected response status: '%s'", state)
 | 
			
		||||
        else:
 | 
			
		||||
            LOG.debug("Remote task with id='%s' not found.", task_uuid)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _handle_expired_remote_task(task):
        """Mark an expired remote task as failed with a timeout.

        Logs the expiry at debug level, then stores a failure built from
        an ``exc.Timeout`` exception as the task's result.

        :param task: the expired remote task; its ``set_result`` method
            is called with the timeout failure.
        """
        LOG.debug("Remote task '%r' has expired.", task)
        timeout_error = exc.Timeout("Remote task '%r' has expired" % task)
        timeout_failure = misc.Failure.from_exception(timeout_error)
        task.set_result(timeout_failure)
 | 
			
		||||
 | 
			
		||||
    def _on_wait(self):
 | 
			
		||||
        """This function is called cyclically between draining events
 | 
			
		||||
        iterations to clean-up expired task requests.
 | 
			
		||||
        """
 | 
			
		||||
        expired_tasks = [task for task in six.itervalues(self._remote_tasks)
 | 
			
		||||
                         if task.expired]
 | 
			
		||||
        for task in expired_tasks:
 | 
			
		||||
            LOG.debug("Task request '%s' has expired.", task)
 | 
			
		||||
            task.set_result(misc.Failure.from_exception(
 | 
			
		||||
                exc.Timeout("Task request '%s' has expired" % task)))
 | 
			
		||||
            del self._remote_tasks[task.uuid]
 | 
			
		||||
 | 
			
		||||
    def _store_remote_task(self, task):
 | 
			
		||||
        """Store task in the remote tasks map."""
 | 
			
		||||
        self._remote_tasks[task.uuid] = task
 | 
			
		||||
        return task
 | 
			
		||||
 | 
			
		||||
    def _remove_remote_task(self, task):
 | 
			
		||||
        """Remove remote task from the tasks map."""
 | 
			
		||||
        if task.uuid in self._remote_tasks:
 | 
			
		||||
            del self._remote_tasks[task.uuid]
 | 
			
		||||
        """This function is called cyclically between draining events."""
 | 
			
		||||
        self._remote_tasks_cache.cleanup(self._handle_expired_remote_task)
 | 
			
		||||
 | 
			
		||||
    def _submit_task(self, task, task_uuid, action, arguments,
 | 
			
		||||
                     progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
 | 
			
		||||
        """Submit task request to workers."""
 | 
			
		||||
        remote_task = self._store_remote_task(
 | 
			
		||||
            rt.RemoteTask(task, task_uuid, action, arguments,
 | 
			
		||||
        remote_task = rt.RemoteTask(task, task_uuid, action, arguments,
 | 
			
		||||
                                    progress_callback, timeout, **kwargs)
 | 
			
		||||
        )
 | 
			
		||||
        self._remote_tasks_cache.set(remote_task.uuid, remote_task)
 | 
			
		||||
        try:
 | 
			
		||||
            # get task's workers topic to send request to
 | 
			
		||||
            try:
 | 
			
		||||
@@ -143,7 +131,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
 | 
			
		||||
        except Exception:
 | 
			
		||||
            with misc.capture_failure() as failure:
 | 
			
		||||
                LOG.exception("Failed to submit the '%s' task", remote_task)
 | 
			
		||||
                self._remove_remote_task(remote_task)
 | 
			
		||||
                self._remote_tasks_cache.delete(remote_task.uuid)
 | 
			
		||||
                remote_task.set_result(failure)
 | 
			
		||||
        return remote_task.result
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -105,7 +105,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
    def test_on_message_state_running(self):
 | 
			
		||||
        response = dict(state=pr.RUNNING)
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls,
 | 
			
		||||
@@ -115,7 +115,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
    def test_on_message_state_progress(self):
 | 
			
		||||
        response = dict(state=pr.PROGRESS, progress=1.0)
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls,
 | 
			
		||||
@@ -127,10 +127,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
        failure_dict = pu.failure_to_dict(failure)
 | 
			
		||||
        response = dict(state=pr.FAILURE, result=failure_dict)
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 0)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls, [
 | 
			
		||||
            mock.call.set_result(result=utils.FailureMatcher(failure))
 | 
			
		||||
        ])
 | 
			
		||||
@@ -140,7 +140,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
        response = dict(state=pr.SUCCESS, result=self.task_result,
 | 
			
		||||
                        event='executed')
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls,
 | 
			
		||||
@@ -151,7 +151,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
    def test_on_message_unknown_state(self):
 | 
			
		||||
        response = dict(state='unknown')
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls, [])
 | 
			
		||||
@@ -161,7 +161,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
        self.message_mock.properties = {'correlation_id': 'non-existent'}
 | 
			
		||||
        response = dict(state=pr.RUNNING)
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls, [])
 | 
			
		||||
@@ -171,7 +171,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
        self.message_mock.properties = {}
 | 
			
		||||
        response = dict(state=pr.RUNNING)
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task_mock)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
 | 
			
		||||
        ex._on_message(response, self.message_mock)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self.remote_task_mock.mock_calls, [])
 | 
			
		||||
@@ -184,37 +184,37 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 | 
			
		||||
        self.assertTrue(mocked_exception.called)
 | 
			
		||||
 | 
			
		||||
    @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
 | 
			
		||||
    def test_on_wait_task_not_expired(self, mock_time):
 | 
			
		||||
        mock_time.side_effect = [1, self.timeout]
 | 
			
		||||
    def test_on_wait_task_not_expired(self, mocked_time):
 | 
			
		||||
        mocked_time.side_effect = [1, self.timeout]
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task())
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 1)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
 | 
			
		||||
        ex._on_wait()
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 1)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
 | 
			
		||||
 | 
			
		||||
    @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
 | 
			
		||||
    def test_on_wait_task_expired(self, mock_time):
 | 
			
		||||
        mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
 | 
			
		||||
    def test_on_wait_task_expired(self, mocked_time):
 | 
			
		||||
        mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(self.remote_task())
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 1)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
 | 
			
		||||
        ex._on_wait()
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 0)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
 | 
			
		||||
 | 
			
		||||
    def test_remove_task_non_existent(self):
 | 
			
		||||
        task = self.remote_task()
 | 
			
		||||
        ex = self.executor()
 | 
			
		||||
        ex._store_remote_task(task)
 | 
			
		||||
        ex._remote_tasks_cache.set(self.task_uuid, task)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 1)
 | 
			
		||||
        ex._remove_remote_task(task)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 0)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
 | 
			
		||||
        ex._remote_tasks_cache.delete(self.task_uuid)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
 | 
			
		||||
 | 
			
		||||
        # remove non-existent
 | 
			
		||||
        ex._remove_remote_task(task)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks), 0)
 | 
			
		||||
        ex._remote_tasks_cache.delete(self.task_uuid)
 | 
			
		||||
        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
 | 
			
		||||
 | 
			
		||||
    def test_execute_task(self):
 | 
			
		||||
        request = self.request(action='execute')
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user