diff --git a/shade/_adapter.py b/shade/_adapter.py index d15214a00..94dde19ed 100644 --- a/shade/_adapter.py +++ b/shade/_adapter.py @@ -140,7 +140,7 @@ class ShadeAdapter(adapter.Adapter): return meta.obj_to_dict(result, request_id=request_id) return result - def request(self, url, method, *args, **kwargs): + def request(self, url, method, run_async=False, *args, **kwargs): name_parts = extract_name(url) name = '.'.join([self.service_type, method] + name_parts) class_name = "".join([ @@ -155,6 +155,7 @@ class ShadeAdapter(adapter.Adapter): super(RequestTask, self).__init__(**kw) self.name = name self.__class__.__name__ = str(class_name) + self.run_async = run_async def main(self, client): self.args.setdefault('raise_exc', False) diff --git a/shade/task_manager.py b/shade/task_manager.py index 1c01b01ed..5144af01f 100644 --- a/shade/task_manager.py +++ b/shade/task_manager.py @@ -15,6 +15,7 @@ # limitations under the License. import abc +import concurrent.futures import sys import threading import time @@ -72,6 +73,7 @@ class BaseTask(object): self._result = None self._response = None self._finished = threading.Event() + self.run_async = False self.args = kw self.name = type(self).__name__ @@ -210,17 +212,23 @@ def generate_task_class(method, name, result_filter_cb): class TaskManager(object): log = _log.setup_logging(__name__) - def __init__(self, client, name, result_filter_cb=None): + def __init__( + self, client, name, result_filter_cb=None, workers=5, **kwargs): self.name = name self._client = client + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=workers) if not result_filter_cb: self._result_filter_cb = _result_filter_cb else: self._result_filter_cb = result_filter_cb + def set_client(self, client): + self._client = client + def stop(self): """ This is a direct action passthrough TaskManager """ - pass + self._executor.shutdown(wait=True) def run(self): """ This is a direct action passthrough TaskManager """ @@ -233,15 +241,36 @@ class TaskManager(object): :param bool raw: If True, return the raw result as received from the underlying client call. """ + return self.run_task(task=task, raw=raw) + + def _run_task_async(self, task, raw=False): + self.log.debug( + "Manager %s submitting task %s", self.name, task.name) + return self._executor.submit(self._run_task, task, raw=raw) + + def run_task(self, task, raw=False): + if hasattr(task, 'run_async') and task.run_async: + return self._run_task_async(task, raw=raw) + else: + return self._run_task(task, raw=raw) + + def _run_task(self, task, raw=False): self.log.debug( "Manager %s running task %s", self.name, task.name) start = time.time() task.run(self._client) end = time.time() + dt = end - start self.log.debug( - "Manager %s ran task %s in %ss", - self.name, task.name, (end - start)) + "Manager %s ran task %s in %ss", self.name, task.name, dt) + + self.post_run_task(dt) + return task.wait(raw) + + def post_run_task(self, elasped_time): + pass + # Backwards compatibility submitTask = submit_task @@ -257,4 +286,4 @@ class TaskManager(object): task_class = generate_task_class(method, name, result_filter_cb) - return self.manager.submit_task(task_class(**kwargs)) + return self._executor.submit_task(task_class(**kwargs)) diff --git a/shade/tests/unit/test_task_manager.py b/shade/tests/unit/test_task_manager.py index 46531c8e0..1a416ee44 100644 --- a/shade/tests/unit/test_task_manager.py +++ b/shade/tests/unit/test_task_manager.py @@ -13,6 +13,9 @@ # limitations under the License. +import concurrent.futures +import mock + from shade import task_manager from shade.tests.unit import base @@ -56,6 +59,15 @@ class TaskTestSet(task_manager.Task): return set([1, 2]) +class TaskTestAsync(task_manager.Task): + def __init__(self): + super(task_manager.Task, self).__init__() + self.run_async = True + + def main(self, client): + pass + + class TestTaskManager(base.TestCase): def setUp(self): @@ -90,3 +102,8 @@ class TestTaskManager(base.TestCase): def test_dont_munchify_set(self): ret = self.manager.submit_task(TaskTestSet()) self.assertIsInstance(ret, set) + + @mock.patch.object(concurrent.futures.ThreadPoolExecutor, 'submit') + def test_async(self, mock_submit): + self.manager.submit_task(TaskTestAsync()) + self.assertTrue(mock_submit.called)