Tidy up the WBE cache (now WBE types) module
Instead of using the expiring cache type as a means to store worker information just avoid using that type since we don't support expiry in the first place on worker information and use a worker container and a worker object that we can later extend as needed. Also add on clear methods to the cache type that will be used when the WBE executor stop occurs. This ensures we clear out the worker information and any unfinished requests. Change-Id: I6a520376eff1e8a6edcef0a59f2d8b9c0eb15752
This commit is contained in:

committed by
Joshua Harlow

parent
292adc5a62
commit
45ef595fde
@@ -1,48 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import random
|
||||
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.types import cache as base
|
||||
|
||||
|
||||
class RequestsCache(base.ExpiringCache):
|
||||
"""Represents a thread-safe requests cache."""
|
||||
|
||||
def get_waiting_requests(self, tasks):
|
||||
"""Get list of waiting requests by tasks."""
|
||||
waiting_requests = []
|
||||
with self._lock:
|
||||
for request in six.itervalues(self._data):
|
||||
if request.state == pr.WAITING and request.task_cls in tasks:
|
||||
waiting_requests.append(request)
|
||||
return waiting_requests
|
||||
|
||||
|
||||
class WorkersCache(base.ExpiringCache):
|
||||
"""Represents a thread-safe workers cache."""
|
||||
|
||||
def get_topic_by_task(self, task):
|
||||
"""Get topic for a given task."""
|
||||
available_topics = []
|
||||
with self._lock:
|
||||
for topic, tasks in six.iteritems(self._data):
|
||||
if task in tasks:
|
||||
available_topics.append(topic)
|
||||
return random.choice(available_topics) if available_topics else None
|
@@ -15,15 +15,13 @@
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
import threading
|
||||
|
||||
from oslo_utils import reflection
|
||||
from oslo_utils import timeutils
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow.engines.worker_based import cache
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow.engines.worker_based import types as wt
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import logging
|
||||
from taskflow import task as task_atom
|
||||
@@ -34,35 +32,6 @@ from taskflow.utils import threading_utils as tu
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PeriodicWorker(object):
|
||||
"""Calls a set of functions when activated periodically.
|
||||
|
||||
NOTE(harlowja): the provided timeout object determines the periodicity.
|
||||
"""
|
||||
def __init__(self, timeout, functors):
|
||||
self._timeout = timeout
|
||||
self._functors = []
|
||||
for f in functors:
|
||||
self._functors.append((f, reflection.get_callable_name(f)))
|
||||
|
||||
def start(self):
|
||||
while not self._timeout.is_stopped():
|
||||
for (f, f_name) in self._functors:
|
||||
LOG.debug("Calling periodic function '%s'", f_name)
|
||||
try:
|
||||
f()
|
||||
except Exception:
|
||||
LOG.warn("Failed to call periodic function '%s'", f_name,
|
||||
exc_info=True)
|
||||
self._timeout.wait()
|
||||
|
||||
def stop(self):
|
||||
self._timeout.interrupt()
|
||||
|
||||
def reset(self):
|
||||
self._timeout.reset()
|
||||
|
||||
|
||||
class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
"""Executes tasks on remote workers."""
|
||||
|
||||
@@ -72,10 +41,9 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
retry_options=None):
|
||||
self._uuid = uuid
|
||||
self._topics = topics
|
||||
self._requests_cache = cache.RequestsCache()
|
||||
self._requests_cache = wt.RequestsCache()
|
||||
self._workers = wt.TopicWorkers()
|
||||
self._transition_timeout = transition_timeout
|
||||
self._workers_cache = cache.WorkersCache()
|
||||
self._workers_arrival = threading.Condition()
|
||||
type_handlers = {
|
||||
pr.NOTIFY: [
|
||||
self._process_notify,
|
||||
@@ -92,8 +60,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
transport_options=transport_options,
|
||||
retry_options=retry_options)
|
||||
self._proxy_thread = None
|
||||
self._periodic = PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
|
||||
[self._notify_topics])
|
||||
self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
|
||||
[self._notify_topics])
|
||||
self._periodic_thread = None
|
||||
|
||||
def _process_notify(self, notify, message):
|
||||
@@ -104,16 +72,15 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
tasks = notify['tasks']
|
||||
|
||||
# Add worker info to the cache
|
||||
LOG.debug("Received that tasks %s can be processed by topic '%s'",
|
||||
tasks, topic)
|
||||
with self._workers_arrival:
|
||||
self._workers_cache[topic] = tasks
|
||||
self._workers_arrival.notify_all()
|
||||
worker = self._workers.add(topic, tasks)
|
||||
LOG.debug("Received notification about worker '%s' (%s"
|
||||
" total workers are currently known)", worker,
|
||||
len(self._workers))
|
||||
|
||||
# Publish waiting requests
|
||||
for request in self._requests_cache.get_waiting_requests(tasks):
|
||||
for request in self._requests_cache.get_waiting_requests(worker):
|
||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||
self._publish_request(request, topic)
|
||||
self._publish_request(request, worker)
|
||||
|
||||
def _process_response(self, response, message):
|
||||
"""Process response from remote side."""
|
||||
@@ -147,7 +114,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'",
|
||||
LOG.warning("Unexpected response status '%s'",
|
||||
response.state)
|
||||
else:
|
||||
LOG.debug("Request with id='%s' not found", task_uuid)
|
||||
@@ -196,16 +163,16 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
progress_callback)
|
||||
request.result.add_done_callback(lambda fut: cleaner())
|
||||
|
||||
# Get task's topic and publish request if topic was found.
|
||||
topic = self._workers_cache.get_topic_by_task(request.task_cls)
|
||||
if topic is not None:
|
||||
# Get task's worker and publish request if worker was found.
|
||||
worker = self._workers.get_worker_for_task(task)
|
||||
if worker is not None:
|
||||
# NOTE(skudriashev): Make sure request is set to the PENDING state
|
||||
# before putting it into the requests cache to prevent the notify
|
||||
# processing thread get list of waiting requests and publish it
|
||||
# before it is published here, so it wouldn't be published twice.
|
||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||
self._requests_cache[request.uuid] = request
|
||||
self._publish_request(request, topic)
|
||||
self._publish_request(request, worker)
|
||||
else:
|
||||
LOG.debug("Delaying submission of '%s', no currently known"
|
||||
" worker/s available to process it", request)
|
||||
@@ -213,14 +180,14 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
|
||||
return request.result
|
||||
|
||||
def _publish_request(self, request, topic):
|
||||
def _publish_request(self, request, worker):
|
||||
"""Publish request to a given topic."""
|
||||
LOG.debug("Submitting execution of '%s' to topic '%s' (expecting"
|
||||
LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
|
||||
" response identified by reply_to=%s and"
|
||||
" correlation_id=%s)", request, topic, self._uuid,
|
||||
" correlation_id=%s)", request, worker, self._uuid,
|
||||
request.uuid)
|
||||
try:
|
||||
self._proxy.publish(request, topic,
|
||||
self._proxy.publish(request, worker.topic,
|
||||
reply_to=self._uuid,
|
||||
correlation_id=request.uuid)
|
||||
except Exception:
|
||||
@@ -255,20 +222,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
return how many workers are still needed, otherwise it will
|
||||
return zero.
|
||||
"""
|
||||
if workers <= 0:
|
||||
raise ValueError("Worker amount must be greater than zero")
|
||||
w = None
|
||||
if timeout is not None:
|
||||
w = tt.StopWatch(timeout).start()
|
||||
with self._workers_arrival:
|
||||
while len(self._workers_cache) < workers:
|
||||
if w is not None and w.expired():
|
||||
return workers - len(self._workers_cache)
|
||||
timeout = None
|
||||
if w is not None:
|
||||
timeout = w.leftover()
|
||||
self._workers_arrival.wait(timeout)
|
||||
return 0
|
||||
return self._workers.wait_for_workers(workers=workers, timeout=timeout)
|
||||
|
||||
def start(self):
|
||||
"""Starts proxy thread and associated topic notification thread."""
|
||||
@@ -291,3 +245,5 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
self._proxy.stop()
|
||||
self._proxy_thread.join()
|
||||
self._proxy_thread = None
|
||||
self._requests_cache.clear(self._handle_expired_request)
|
||||
self._workers.clear()
|
||||
|
@@ -221,7 +221,6 @@ class Request(Message):
|
||||
|
||||
def __init__(self, task, uuid, action, arguments, timeout, **kwargs):
|
||||
self._task = task
|
||||
self._task_cls = reflection.get_class_name(task)
|
||||
self._uuid = uuid
|
||||
self._action = action
|
||||
self._event = ACTION_TO_EVENT[action]
|
||||
@@ -248,8 +247,8 @@ class Request(Message):
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def task_cls(self):
|
||||
return self._task_cls
|
||||
def task(self):
|
||||
return self._task
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
@@ -281,9 +280,13 @@ class Request(Message):
|
||||
convert all `failure.Failure` objects into dictionaries (which will
|
||||
then be reconstituted by the receiver).
|
||||
"""
|
||||
request = dict(task_cls=self._task_cls, task_name=self._task.name,
|
||||
task_version=self._task.version, action=self._action,
|
||||
arguments=self._arguments)
|
||||
request = {
|
||||
'task_cls': reflection.get_class_name(self._task),
|
||||
'task_name': self._task.name,
|
||||
'task_version': self._task.version,
|
||||
'action': self._action,
|
||||
'arguments': self._arguments,
|
||||
}
|
||||
if 'result' in self._kwargs:
|
||||
result = self._kwargs['result']
|
||||
if isinstance(result, ft.Failure):
|
||||
|
217
taskflow/engines/worker_based/types.py
Normal file
217
taskflow/engines/worker_based/types.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
|
||||
from oslo.utils import reflection
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.types import cache as base
|
||||
from taskflow.types import timing as tt
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RequestsCache(base.ExpiringCache):
|
||||
"""Represents a thread-safe requests cache."""
|
||||
|
||||
def get_waiting_requests(self, worker):
|
||||
"""Get list of waiting requests that the given worker can satisfy."""
|
||||
waiting_requests = []
|
||||
with self._lock:
|
||||
for request in six.itervalues(self._data):
|
||||
if request.state == pr.WAITING \
|
||||
and worker.performs(request.task):
|
||||
waiting_requests.append(request)
|
||||
return waiting_requests
|
||||
|
||||
|
||||
# TODO(harlowja): this needs to be made better, once
|
||||
# https://blueprints.launchpad.net/taskflow/+spec/wbe-worker-info is finally
|
||||
# implemented we can go about using that instead.
|
||||
class TopicWorker(object):
|
||||
"""A (read-only) worker and its relevant information + useful methods."""
|
||||
|
||||
_NO_IDENTITY = object()
|
||||
|
||||
def __init__(self, topic, tasks, identity=_NO_IDENTITY):
|
||||
self.tasks = []
|
||||
for task in tasks:
|
||||
if not isinstance(task, six.string_types):
|
||||
task = reflection.get_class_name(task)
|
||||
self.tasks.append(task)
|
||||
self.topic = topic
|
||||
self.identity = identity
|
||||
|
||||
def performs(self, task):
|
||||
if not isinstance(task, six.string_types):
|
||||
task = reflection.get_class_name(task)
|
||||
return task in self.tasks
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TopicWorker):
|
||||
return NotImplemented
|
||||
if len(other.tasks) != len(self.tasks):
|
||||
return False
|
||||
if other.topic != self.topic:
|
||||
return False
|
||||
for task in other.tasks:
|
||||
if not self.performs(task):
|
||||
return False
|
||||
# If one of the identity equals _NO_IDENTITY, then allow it to match...
|
||||
if self._NO_IDENTITY in (self.identity, other.identity):
|
||||
return True
|
||||
else:
|
||||
return other.identity == self.identity
|
||||
|
||||
def __repr__(self):
|
||||
r = reflection.get_class_name(self, fully_qualified=False)
|
||||
if self.identity is not self._NO_IDENTITY:
|
||||
r += "(identity=%s, tasks=%s, topic=%s)" % (self.identity,
|
||||
self.tasks, self.topic)
|
||||
else:
|
||||
r += "(identity=*, tasks=%s, topic=%s)" % (self.tasks, self.topic)
|
||||
return r
|
||||
|
||||
|
||||
class TopicWorkers(object):
|
||||
"""A collection of topic based workers."""
|
||||
|
||||
@staticmethod
|
||||
def _match_worker(task, available_workers):
|
||||
"""Select a worker (from geq 1 workers) that can best perform the task.
|
||||
|
||||
NOTE(harlowja): this method will be activated when there exists
|
||||
one one greater than one potential workers that can perform a task,
|
||||
the arguments provided will be the potential workers located and the
|
||||
task that is being requested to perform and the result should be one
|
||||
of those workers using whatever best-fit algorithm is possible (or
|
||||
random at the least).
|
||||
"""
|
||||
if len(available_workers) == 1:
|
||||
return available_workers[0]
|
||||
else:
|
||||
return random.choice(available_workers)
|
||||
|
||||
def __init__(self):
|
||||
self._workers = {}
|
||||
self._cond = threading.Condition()
|
||||
# Used to name workers with more useful identities...
|
||||
self._counter = itertools.count()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._workers)
|
||||
|
||||
def _next_worker(self, topic, tasks, temporary=False):
|
||||
if not temporary:
|
||||
return TopicWorker(topic, tasks,
|
||||
identity=six.next(self._counter))
|
||||
else:
|
||||
return TopicWorker(topic, tasks)
|
||||
|
||||
def add(self, topic, tasks):
|
||||
"""Adds/updates a worker for the topic for the given tasks."""
|
||||
with self._cond:
|
||||
try:
|
||||
worker = self._workers[topic]
|
||||
# Check if we already have an equivalent worker, if so just
|
||||
# return it...
|
||||
if worker == self._next_worker(topic, tasks, temporary=True):
|
||||
return worker
|
||||
# This *fall through* is done so that if someone is using an
|
||||
# active worker object that already exists that we just create
|
||||
# a new one; so that the existing object doesn't get
|
||||
# affected (workers objects are supposed to be immutable).
|
||||
except KeyError:
|
||||
pass
|
||||
worker = self._next_worker(topic, tasks)
|
||||
self._workers[topic] = worker
|
||||
self._cond.notify_all()
|
||||
return worker
|
||||
|
||||
def wait_for_workers(self, workers=1, timeout=None):
|
||||
"""Waits for geq workers to notify they are ready to do work.
|
||||
|
||||
NOTE(harlowja): if a timeout is provided this function will wait
|
||||
until that timeout expires, if the amount of workers does not reach
|
||||
the desired amount of workers before the timeout expires then this will
|
||||
return how many workers are still needed, otherwise it will
|
||||
return zero.
|
||||
"""
|
||||
if workers <= 0:
|
||||
raise ValueError("Worker amount must be greater than zero")
|
||||
w = None
|
||||
if timeout is not None:
|
||||
w = tt.StopWatch(timeout).start()
|
||||
with self._cond:
|
||||
while len(self._workers) < workers:
|
||||
if w is not None and w.expired():
|
||||
return max(0, workers - len(self._workers))
|
||||
timeout = None
|
||||
if w is not None:
|
||||
timeout = w.leftover()
|
||||
self._cond.wait(timeout)
|
||||
return 0
|
||||
|
||||
def get_worker_for_task(self, task):
|
||||
"""Gets a worker that can perform a given task."""
|
||||
available_workers = []
|
||||
with self._cond:
|
||||
for worker in six.itervalues(self._workers):
|
||||
if worker.performs(task):
|
||||
available_workers.append(worker)
|
||||
if available_workers:
|
||||
return self._match_worker(task, available_workers)
|
||||
else:
|
||||
return None
|
||||
|
||||
def clear(self):
|
||||
with self._cond:
|
||||
self._workers.clear()
|
||||
self._cond.notify_all()
|
||||
|
||||
|
||||
class PeriodicWorker(object):
|
||||
"""Calls a set of functions when activated periodically.
|
||||
|
||||
NOTE(harlowja): the provided timeout object determines the periodicity.
|
||||
"""
|
||||
def __init__(self, timeout, functors):
|
||||
self._timeout = timeout
|
||||
self._functors = []
|
||||
for f in functors:
|
||||
self._functors.append((f, reflection.get_callable_name(f)))
|
||||
|
||||
def start(self):
|
||||
while not self._timeout.is_stopped():
|
||||
for (f, f_name) in self._functors:
|
||||
LOG.debug("Calling periodic function '%s'", f_name)
|
||||
try:
|
||||
f()
|
||||
except Exception:
|
||||
LOG.warn("Failed to call periodic function '%s'", f_name,
|
||||
exc_info=True)
|
||||
self._timeout.wait()
|
||||
|
||||
def stop(self):
|
||||
self._timeout.interrupt()
|
||||
|
||||
def reset(self):
|
||||
self._timeout.reset()
|
@@ -136,7 +136,7 @@ class TestProtocol(test.TestCase):
|
||||
def test_creation(self):
|
||||
request = self.request()
|
||||
self.assertEqual(request.uuid, self.task_uuid)
|
||||
self.assertEqual(request.task_cls, self.task.name)
|
||||
self.assertEqual(request.task, self.task)
|
||||
self.assertIsInstance(request.result, futures.Future)
|
||||
self.assertFalse(request.result.done())
|
||||
|
||||
|
139
taskflow/tests/unit/worker_based/test_types.py
Normal file
139
taskflow/tests/unit/worker_based/test_types.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import datetime
|
||||
import threading
|
||||
import time
|
||||
|
||||
from oslo.utils import reflection
|
||||
from oslo.utils import timeutils
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import types as worker_types
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.types import latch
|
||||
from taskflow.types import timing
|
||||
|
||||
|
||||
class TestWorkerTypes(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestWorkerTypes, self).setUp()
|
||||
self.task = utils.DummyTask()
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_action = 'execute'
|
||||
self.task_args = {'a': 'a'}
|
||||
self.timeout = 60
|
||||
|
||||
def request(self, **kwargs):
|
||||
request_kwargs = dict(task=self.task,
|
||||
uuid=self.task_uuid,
|
||||
action=self.task_action,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None,
|
||||
timeout=self.timeout)
|
||||
request_kwargs.update(kwargs)
|
||||
return pr.Request(**request_kwargs)
|
||||
|
||||
def test_requests_cache_expiry(self):
|
||||
# Mock out the calls the underlying objects will soon use to return
|
||||
# times that we can control more easily...
|
||||
now = timeutils.utcnow()
|
||||
overrides = [
|
||||
now,
|
||||
now,
|
||||
now + datetime.timedelta(seconds=1),
|
||||
now + datetime.timedelta(seconds=self.timeout + 1),
|
||||
]
|
||||
timeutils.set_time_override(overrides)
|
||||
self.addCleanup(timeutils.clear_time_override)
|
||||
|
||||
cache = worker_types.RequestsCache()
|
||||
cache[self.task_uuid] = self.request()
|
||||
cache.cleanup()
|
||||
self.assertEqual(1, len(cache))
|
||||
cache.cleanup()
|
||||
self.assertEqual(0, len(cache))
|
||||
|
||||
def test_requests_cache_match(self):
|
||||
cache = worker_types.RequestsCache()
|
||||
cache[self.task_uuid] = self.request()
|
||||
cache['task-uuid-2'] = self.request(task=utils.NastyTask(),
|
||||
uuid='task-uuid-2')
|
||||
worker = worker_types.TopicWorker("dummy-topic", [utils.DummyTask],
|
||||
identity="dummy")
|
||||
matches = cache.get_waiting_requests(worker)
|
||||
self.assertEqual(1, len(matches))
|
||||
self.assertEqual(2, len(cache))
|
||||
|
||||
def test_topic_worker(self):
|
||||
worker = worker_types.TopicWorker("dummy-topic",
|
||||
[utils.DummyTask], identity="dummy")
|
||||
self.assertTrue(worker.performs(utils.DummyTask))
|
||||
self.assertFalse(worker.performs(utils.NastyTask))
|
||||
self.assertEqual('dummy', worker.identity)
|
||||
self.assertEqual('dummy-topic', worker.topic)
|
||||
|
||||
def test_single_topic_workers(self):
|
||||
workers = worker_types.TopicWorkers()
|
||||
w = workers.add('dummy-topic', [utils.DummyTask])
|
||||
self.assertIsNotNone(w)
|
||||
self.assertEqual(1, len(workers))
|
||||
w2 = workers.get_worker_for_task(utils.DummyTask)
|
||||
self.assertEqual(w.identity, w2.identity)
|
||||
|
||||
def test_multi_same_topic_workers(self):
|
||||
workers = worker_types.TopicWorkers()
|
||||
w = workers.add('dummy-topic', [utils.DummyTask])
|
||||
self.assertIsNotNone(w)
|
||||
w2 = workers.add('dummy-topic-2', [utils.DummyTask])
|
||||
self.assertIsNotNone(w2)
|
||||
w3 = workers.get_worker_for_task(
|
||||
reflection.get_class_name(utils.DummyTask))
|
||||
self.assertIn(w3.identity, [w.identity, w2.identity])
|
||||
|
||||
def test_multi_different_topic_workers(self):
|
||||
workers = worker_types.TopicWorkers()
|
||||
added = []
|
||||
added.append(workers.add('dummy-topic', [utils.DummyTask]))
|
||||
added.append(workers.add('dummy-topic-2', [utils.DummyTask]))
|
||||
added.append(workers.add('dummy-topic-3', [utils.NastyTask]))
|
||||
self.assertEqual(3, len(workers))
|
||||
w = workers.get_worker_for_task(utils.NastyTask)
|
||||
self.assertEqual(added[-1].identity, w.identity)
|
||||
w = workers.get_worker_for_task(utils.DummyTask)
|
||||
self.assertIn(w.identity, [w_a.identity for w_a in added[0:2]])
|
||||
|
||||
def test_periodic_worker(self):
|
||||
barrier = latch.Latch(5)
|
||||
to = timing.Timeout(0.01)
|
||||
called_at = []
|
||||
|
||||
def callee():
|
||||
barrier.countdown()
|
||||
if barrier.needed == 0:
|
||||
to.interrupt()
|
||||
called_at.append(time.time())
|
||||
|
||||
w = worker_types.PeriodicWorker(to, [callee])
|
||||
t = threading.Thread(target=w.start)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
self.assertEqual(0, barrier.needed)
|
||||
self.assertEqual(5, len(called_at))
|
||||
self.assertTrue(to.is_stopped())
|
@@ -54,6 +54,21 @@ class ExpiringCache(object):
|
||||
with self._lock:
|
||||
del self._data[key]
|
||||
|
||||
def clear(self, on_cleared_callback=None):
|
||||
"""Removes all keys & values from the cache."""
|
||||
cleared_items = []
|
||||
with self._lock:
|
||||
if on_cleared_callback is not None:
|
||||
cleared_items.extend(six.iteritems(self._data))
|
||||
self._data.clear()
|
||||
if on_cleared_callback is not None:
|
||||
arg_c = len(reflection.get_callable_args(on_cleared_callback))
|
||||
for (k, v) in cleared_items:
|
||||
if arg_c == 2:
|
||||
on_cleared_callback(k, v)
|
||||
else:
|
||||
on_cleared_callback(v)
|
||||
|
||||
def cleanup(self, on_expired_callback=None):
|
||||
"""Delete out-dated keys & values from the cache."""
|
||||
with self._lock:
|
||||
|
Reference in New Issue
Block a user