142 lines
4.6 KiB
Python
142 lines
4.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
from taskflow.engines.worker_based import protocol as pr
|
|
from taskflow.engines.worker_based import proxy
|
|
from taskflow.openstack.common import uuidutils
|
|
from taskflow import test
|
|
from taskflow.test import mock
|
|
from taskflow.tests import utils as test_utils
|
|
from taskflow.types import latch
|
|
from taskflow.utils import threading_utils
|
|
|
|
# Exchange and topic names that the proxies in these tests are created with
# and that messages are published to.
TEST_EXCHANGE = 'test-exchange'
TEST_TOPIC = 'test-topic'

# Value handed to the memory transport as its 'polling_interval' option.
POLLING_INTERVAL = 0.01
|
|
|
|
|
|
class TestMessagePump(test.TestCase):
    """Check that a memory-transport ``Proxy`` dispatches to its handlers.

    Each test starts a proxy on a daemon thread, publishes one or more
    messages to ``TEST_TOPIC`` and then verifies which handler mocks were
    invoked and with what arguments.
    """

    def _start_proxy(self, handlers):
        """Create a memory-transport proxy, run it on a thread, wait for it.

        :param handlers: mapping of message type (e.g. ``pr.NOTIFY``) to the
                         callable that should be invoked for that type.
        :returns: ``(proxy, thread)``; the caller is responsible for calling
                  ``proxy.stop()`` and ``thread.join()`` when finished.
        """
        p = proxy.Proxy(TEST_TOPIC, TEST_EXCHANGE, handlers,
                        transport='memory',
                        transport_options={
                            'polling_interval': POLLING_INTERVAL,
                        })
        t = threading_utils.daemon_thread(p.start)
        t.start()
        # Wait for the proxy to signal that it has started before anything
        # is published (presumably meaning it is ready to consume).
        p.wait()
        return p, t

    def test_notify(self):
        barrier = threading_utils.Event()

        on_notify = mock.MagicMock()
        on_notify.side_effect = lambda *args, **kwargs: barrier.set()

        p, t = self._start_proxy({pr.NOTIFY: on_notify})
        try:
            p.publish(pr.Notify(), TEST_TOPIC)
            self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
        finally:
            # Shut the proxy down even when the wait assertion fails, so the
            # proxy thread is not left running after this test.
            p.stop()
            t.join()

        self.assertTrue(on_notify.called)
        on_notify.assert_called_with({}, mock.ANY)

    def test_response(self):
        barrier = threading_utils.Event()

        on_response = mock.MagicMock()
        on_response.side_effect = lambda *args, **kwargs: barrier.set()

        p, t = self._start_proxy({pr.RESPONSE: on_response})
        resp = pr.Response(pr.RUNNING)
        try:
            p.publish(resp, TEST_TOPIC)
            self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
        finally:
            # Shut the proxy down even when the wait assertion fails.
            p.stop()
            t.join()

        self.assertTrue(on_response.called)
        on_response.assert_called_with(resp.to_dict(), mock.ANY)

    def test_multi_message(self):
        # Messages are published round-robin as notify, response, request,
        # so message_count must stay a multiple of three for the exact
        # per-handler counts asserted below.
        message_count = 30
        expected_per_handler = message_count // 3
        barrier = latch.Latch(message_count)

        def countdown(data, message):
            barrier.countdown()

        on_notify = mock.MagicMock()
        on_notify.side_effect = countdown

        on_response = mock.MagicMock()
        on_response.side_effect = countdown

        on_request = mock.MagicMock()
        on_request.side_effect = countdown

        handlers = {
            pr.NOTIFY: on_notify,
            pr.RESPONSE: on_response,
            pr.REQUEST: on_request,
        }
        p, t = self._start_proxy(handlers)
        try:
            for i in range(message_count):
                kind = i % 3
                if kind == 0:
                    p.publish(pr.Notify(), TEST_TOPIC)
                elif kind == 1:
                    p.publish(pr.Response(pr.RUNNING), TEST_TOPIC)
                else:
                    request = pr.Request(test_utils.DummyTask("dummy_%s" % i),
                                         uuidutils.generate_uuid(),
                                         pr.EXECUTE, [], None, None)
                    p.publish(request, TEST_TOPIC)

            self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
            self.assertEqual(0, barrier.needed)
        finally:
            # Shut the proxy down even when an assertion above fails.
            p.stop()
            t.join()

        self.assertTrue(on_notify.called)
        self.assertTrue(on_response.called)
        self.assertTrue(on_request.called)

        self.assertEqual(expected_per_handler, on_notify.call_count)
        self.assertEqual(expected_per_handler, on_response.call_count)
        self.assertEqual(expected_per_handler, on_request.call_count)

        call_count = sum([
            on_notify.call_count,
            on_response.call_count,
            on_request.call_count,
        ])
        self.assertEqual(message_count, call_count)
|