# taskflow/taskflow/tests/unit/worker_based/test_message_pump.py
#
# 142 lines
# 4.6 KiB
# Python

# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow.openstack.common import uuidutils
from taskflow import test
from taskflow.test import mock
from taskflow.tests import utils as test_utils
from taskflow.types import latch
from taskflow.utils import threading_utils
# Exchange and topic names handed to the proxies created in this module.
TEST_EXCHANGE = 'test-exchange'
TEST_TOPIC = 'test-topic'
# Value supplied as the transport's 'polling_interval' option.
POLLING_INTERVAL = 0.01
class TestMessagePump(test.TestCase):
    """Checks that a proxy dispatches published messages to its handlers.

    Each test builds a ``proxy.Proxy`` configured with the ``'memory'``
    transport, runs it in a daemon thread, publishes messages to the
    proxy's own topic and then verifies which handler mocks were called.
    """

    def _start_proxy(self, handlers):
        """Create a proxy for ``handlers`` and run it in a daemon thread.

        :param handlers: dict mapping message types (``pr.NOTIFY``,
                         ``pr.RESPONSE``, ``pr.REQUEST``) to callables
                         invoked as ``handler(data, message)``.
        :returns: ``(proxy, thread)`` tuple; pass both to
                  ``_stop_proxy`` once the test is done with them.
        """
        p = proxy.Proxy(TEST_TOPIC, TEST_EXCHANGE, handlers,
                        transport='memory',
                        transport_options={
                            'polling_interval': POLLING_INTERVAL,
                        })
        t = threading_utils.daemon_thread(p.start)
        t.start()
        # NOTE(review): presumably blocks until the proxy is ready to
        # receive messages -- confirm against proxy.Proxy.wait.
        p.wait()
        return p, t

    def _stop_proxy(self, p, t):
        """Stop proxy ``p`` and wait for its runner thread ``t`` to exit."""
        p.stop()
        t.join()

    def test_notify(self):
        """A published ``Notify`` reaches the ``pr.NOTIFY`` handler."""
        barrier = threading_utils.Event()

        on_notify = mock.MagicMock()
        on_notify.side_effect = lambda *args, **kwargs: barrier.set()

        p, t = self._start_proxy({pr.NOTIFY: on_notify})
        try:
            p.publish(pr.Notify(), TEST_TOPIC)
            self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
        finally:
            # Always shut the pump down, even when the wait above fails,
            # so a failing test does not leave a running proxy behind.
            self._stop_proxy(p, t)

        self.assertTrue(on_notify.called)
        on_notify.assert_called_with({}, mock.ANY)

    def test_response(self):
        """A published ``Response`` reaches the ``pr.RESPONSE`` handler."""
        barrier = threading_utils.Event()

        on_response = mock.MagicMock()
        on_response.side_effect = lambda *args, **kwargs: barrier.set()

        p, t = self._start_proxy({pr.RESPONSE: on_response})
        try:
            resp = pr.Response(pr.RUNNING)
            p.publish(resp, TEST_TOPIC)
            self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
            self.assertTrue(barrier.is_set())
        finally:
            # Always shut the pump down, even when an assertion fails.
            self._stop_proxy(p, t)

        self.assertTrue(on_response.called)
        on_response.assert_called_with(resp.to_dict(), mock.ANY)

    def test_multi_message(self):
        """Interleaved message types are each routed to their own handler."""
        message_count = 30
        # Messages are published round-robin over three types, so every
        # handler is expected to receive an equal share of them.
        per_handler_count = message_count // 3

        barrier = latch.Latch(message_count)

        def countdown(data, message):
            # Each handled message releases one unit of the latch.
            return barrier.countdown()

        on_notify = mock.MagicMock()
        on_notify.side_effect = countdown

        on_response = mock.MagicMock()
        on_response.side_effect = countdown

        on_request = mock.MagicMock()
        on_request.side_effect = countdown

        handlers = {
            pr.NOTIFY: on_notify,
            pr.RESPONSE: on_response,
            pr.REQUEST: on_request,
        }
        p, t = self._start_proxy(handlers)
        try:
            for i in range(message_count):
                j = i % 3
                if j == 0:
                    p.publish(pr.Notify(), TEST_TOPIC)
                elif j == 1:
                    p.publish(pr.Response(pr.RUNNING), TEST_TOPIC)
                else:
                    p.publish(
                        pr.Request(test_utils.DummyTask("dummy_%s" % i),
                                   uuidutils.generate_uuid(),
                                   pr.EXECUTE, [], None, None),
                        TEST_TOPIC)

            self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
            self.assertEqual(0, barrier.needed)
        finally:
            # Always shut the pump down, even when an assertion fails.
            self._stop_proxy(p, t)

        self.assertTrue(on_notify.called)
        self.assertTrue(on_response.called)
        self.assertTrue(on_request.called)

        self.assertEqual(per_handler_count, on_notify.call_count)
        self.assertEqual(per_handler_count, on_response.call_count)
        self.assertEqual(per_handler_count, on_request.call_count)

        call_count = sum([
            on_notify.call_count,
            on_response.call_count,
            on_request.call_count,
        ])
        self.assertEqual(message_count, call_count)