diff --git a/examples/websocket.html b/examples/websocket.html index 37134554c..7216a2570 100644 --- a/examples/websocket.html +++ b/examples/websocket.html @@ -6,6 +6,11 @@ + + diff --git a/zaqar/tests/etc/websocket_mongodb_subscriptions.conf b/zaqar/tests/etc/websocket_mongodb_subscriptions.conf new file mode 100644 index 000000000..56a97899d --- /dev/null +++ b/zaqar/tests/etc/websocket_mongodb_subscriptions.conf @@ -0,0 +1,23 @@ +[DEFAULT] +unreliable = True + +[drivers] + +# Transport driver to use (string value) +transport = websocket + +# Storage driver to use (string value) +message_store = mongodb + +[drivers:management_store:mongodb] + +# Mongodb Connection URI +uri = mongodb://127.0.0.1:27017 + +[drivers:message_store:mongodb] + +# Mongodb Connection URI +uri = mongodb://127.0.0.1:27017 + +[storage] +message_pipeline = zaqar.notification.notifier \ No newline at end of file diff --git a/zaqar/tests/unit/transport/websocket/test_protocol.py b/zaqar/tests/unit/transport/websocket/test_protocol.py index f601965f5..18dde3c8e 100644 --- a/zaqar/tests/unit/transport/websocket/test_protocol.py +++ b/zaqar/tests/unit/transport/websocket/test_protocol.py @@ -14,22 +14,27 @@ # under the License. import json +import uuid +import ddt import mock from zaqar.tests.unit.transport.websocket import base +from zaqar.tests.unit.transport.websocket import utils as test_utils +@ddt.ddt class TestMessagingProtocol(base.TestBase): config_file = "websocket_mongodb.conf" def setUp(self): super(TestMessagingProtocol, self).setUp() self.protocol = self.transport.factory() - self.defaults = self.api.get_defaults() - - def tearDown(self): - super(TestMessagingProtocol, self).tearDown() + self.project_id = 'protocol-test' + self.headers = { + 'Client-ID': str(uuid.uuid4()), + 'X-Project-ID': self.project_id + } def test_on_message_with_invalid_input(self): payload = u'\ufeff' @@ -45,3 +50,45 @@ class TestMessagingProtocol(base.TestBase): self.protocol.onMessage(payload, False) resp = json.loads(send_mock.call_args[0][0]) self.assertEqual(400, resp['headers']['status']) + + def test_on_message_with_invalid_input_binary(self): + dumps, loads, create_req = test_utils.get_pack_tools(binary=True) + send_mock = mock.Mock() + self.protocol.sendMessage = send_mock + + # Test error response, when the request can't be deserialized. + req = "123" + self.protocol.onMessage(req, True) + resp = loads(send_mock.call_args[0][0]) + self.assertEqual(400, resp['headers']['status']) + self.assertIn('Can\'t decode binary', resp['body']['error']) + + # Test error response, when request body is not a dictionary. + req = dumps("Apparently, I'm not a dictionary") + self.protocol.onMessage(req, True) + resp = loads(send_mock.call_args[0][0]) + self.assertEqual(400, resp['headers']['status']) + self.assertIn('Unexpected body type. Expected dict', + resp['body']['error']) + + # Test error response, when validation fails. + action = 'queue_glorify' + body = {} + req = create_req(action, body, self.headers) + self.protocol.onMessage(req, True) + resp = loads(send_mock.call_args[0][0]) + self.assertEqual(400, resp['headers']['status']) + self.assertEqual('queue_glorify is not a valid action', + resp['body']['error']) + + @ddt.data(True, False) + def test_on_message_with_input_in_different_format(self, in_binary): + dumps, loads, create_req = test_utils.get_pack_tools(binary=in_binary) + action = 'queue_get' + body = {'queue_name': 'beautiful-non-existing-queue'} + req = create_req(action, body, self.headers) + send_mock = mock.Mock() + self.protocol.sendMessage = send_mock + self.protocol.onMessage(req, in_binary) + resp = loads(send_mock.call_args[0][0]) + self.assertEqual(200, resp['headers']['status']) diff --git a/zaqar/tests/unit/transport/websocket/utils.py b/zaqar/tests/unit/transport/websocket/utils.py index 14f1a08fc..0bfa51156 100644 --- a/zaqar/tests/unit/transport/websocket/utils.py +++ b/zaqar/tests/unit/transport/websocket/utils.py @@ -12,8 +12,37 @@ # License for the specific language governing permissions and limitations under # the License. +import functools import json +import msgpack def create_request(action, body, headers): return json.dumps({"action": action, "body": body, "headers": headers}) + + +def create_binary_request(action, body, headers): + return msgpack.packb({"action": action, "body": body, "headers": headers}) + + +def get_pack_tools(binary=None): + """Get serialization tools for testing websocket transport. + + :param bool binary: type of serialization tools. + True: binary (MessagePack) tools. + False: text (JSON) tools. + :returns: set of serialization tools needed for testing websocket + transport: (dumps, loads, create_request_function) + :rtype: tuple + """ + if binary is None: + raise Exception("binary param is unspecified") + if binary: + dumps = msgpack.Packer(encoding='utf-8', use_bin_type=False).pack + loads = functools.partial(msgpack.unpackb, encoding='utf-8') + create_request_function = create_binary_request + else: + dumps = json.dumps + loads = json.loads + create_request_function = create_request + return dumps, loads, create_request_function diff --git a/zaqar/tests/unit/transport/websocket/v2/test_auth.py b/zaqar/tests/unit/transport/websocket/v2/test_auth.py index 5bf824603..a2c7fa3ec 100644 --- a/zaqar/tests/unit/transport/websocket/v2/test_auth.py +++ b/zaqar/tests/unit/transport/websocket/v2/test_auth.py @@ -16,6 +16,7 @@ import json import uuid +import ddt from keystonemiddleware import auth_token import mock @@ -24,8 +25,8 @@ from zaqar.tests.unit.transport.websocket import base from zaqar.tests.unit.transport.websocket import utils as test_utils +@ddt.ddt class AuthTest(base.V2Base): - config_file = "websocket_mongodb_keystone_auth.conf" def setUp(self): @@ -87,6 +88,7 @@ class AuthTest(base.V2Base): msg_mock = mock.patch.object(self.protocol, 'sendMessage') self.addCleanup(msg_mock.stop) msg_mock = msg_mock.start() + self.protocol._auth_in_binary = False self.protocol._auth_response('401 error', 'Failed') self.assertEqual(1, msg_mock.call_count) resp = json.loads(msg_mock.call_args[0][0]) @@ -122,6 +124,25 @@ class AuthTest(base.V2Base): self.assertIn('cancelled', repr(handle)) self.assertNotIn('cancelled', repr(self.protocol._deauth_handle)) + @ddt.data(True, False) + def test_auth_response_serialization_format(self, in_binary): + dumps, loads, create_req = test_utils.get_pack_tools(binary=in_binary) + headers = self.headers.copy() + headers['X-Auth-Token'] = 'mytoken1' + req = create_req("authenticate", {}, headers) + + msg_mock = mock.patch.object(self.protocol, 'sendMessage') + self.addCleanup(msg_mock.stop) + msg_mock = msg_mock.start() + # Depending on onMessage method's second argument, auth response should + # be in binary or text format. + self.protocol.onMessage(req, in_binary) + self.assertEqual(in_binary, self.protocol._auth_in_binary) + self.protocol._auth_response('401 error', 'Failed') + self.assertEqual(1, msg_mock.call_count) + resp = loads(msg_mock.call_args[0][0]) + self.assertEqual(401, resp['headers']['status']) + def test_signed_url(self): send_mock = mock.Mock() self.protocol.sendMessage = send_mock diff --git a/zaqar/tests/unit/transport/websocket/v2/test_messages.py b/zaqar/tests/unit/transport/websocket/v2/test_messages.py index 86639d3e6..ef82661a9 100644 --- a/zaqar/tests/unit/transport/websocket/v2/test_messages.py +++ b/zaqar/tests/unit/transport/websocket/v2/test_messages.py @@ -69,7 +69,7 @@ class MessagesBaseTest(base.V2Base): resp = json.loads(send_mock.call_args[0][0]) self.assertEqual(204, resp['headers']['status']) - def _test_post(self, sample_messages): + def _test_post(self, sample_messages, in_binary=False): action = "message_post" body = {"queue_name": "kitkat", "messages": sample_messages} @@ -77,11 +77,13 @@ class MessagesBaseTest(base.V2Base): send_mock = mock.Mock() self.protocol.sendMessage = send_mock - req = test_utils.create_request(action, body, self.headers) + dumps, loads, create_req = test_utils.get_pack_tools(binary=in_binary) - self.protocol.onMessage(req, False) + req = create_req(action, body, self.headers) - resp = json.loads(send_mock.call_args[0][0]) + self.protocol.onMessage(req, in_binary) + + resp = loads(send_mock.call_args[0][0]) self.assertEqual(201, resp['headers']['status']) self.msg_ids = resp['body']['message_ids'] self.assertEqual(len(sample_messages), len(self.msg_ids)) @@ -102,19 +104,19 @@ class MessagesBaseTest(base.V2Base): body = {"queue_name": "kitkat", "message_id": msg_id} - req = test_utils.create_request(action, body, headers) + req = create_req(action, body, headers) - self.protocol.onMessage(req, False) + self.protocol.onMessage(req, in_binary) - resp = json.loads(send_mock.call_args[0][0]) + resp = loads(send_mock.call_args[0][0]) self.assertEqual(404, resp['headers']['status']) # Correct project ID - req = test_utils.create_request(action, body, self.headers) + req = create_req(action, body, self.headers) - self.protocol.onMessage(req, False) + self.protocol.onMessage(req, in_binary) - resp = json.loads(send_mock.call_args[0][0]) + resp = loads(send_mock.call_args[0][0]) self.assertEqual(200, resp['headers']['status']) # Check message properties @@ -132,11 +134,11 @@ class MessagesBaseTest(base.V2Base): action = "message_get_many" body = {"queue_name": "kitkat", "message_ids": self.msg_ids} - req = test_utils.create_request(action, body, self.headers) + req = create_req(action, body, self.headers) - self.protocol.onMessage(req, False) + self.protocol.onMessage(req, in_binary) - resp = json.loads(send_mock.call_args[0][0]) + resp = loads(send_mock.call_args[0][0]) self.assertEqual(200, resp['headers']['status']) expected_ttls = set(m['ttl'] for m in sample_messages) actual_ttls = set(m['ttl'] for m in resp['body']['messages']) @@ -181,21 +183,23 @@ class MessagesBaseTest(base.V2Base): resp = json.loads(send_mock.call_args[0][0]) self.assertEqual(400, resp['headers']['status']) - def test_post_single(self): + @ddt.data(True, False) + def test_post_single(self, in_binary): sample_messages = [ {'body': {'key': 'value'}, 'ttl': 200}, ] - self._test_post(sample_messages) + self._test_post(sample_messages, in_binary=in_binary) - def test_post_multiple(self): + @ddt.data(True, False) + def test_post_multiple(self, in_binary): sample_messages = [ {'body': 239, 'ttl': 100}, {'body': {'key': 'value'}, 'ttl': 200}, {'body': [1, 3], 'ttl': 300}, ] - self._test_post(sample_messages) + self._test_post(sample_messages, in_binary=in_binary) def test_post_optional_ttl(self): messages = [{'body': 239}, diff --git a/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py b/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py index cd0eb3cba..dac740df7 100644 --- a/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py +++ b/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py @@ -17,6 +17,7 @@ import json import uuid import mock +import msgpack from zaqar.storage import errors as storage_errors from zaqar.tests.unit.transport.websocket import base @@ -26,7 +27,7 @@ from zaqar.transport.websocket import factory class SubscriptionTest(base.V1_1Base): - config_file = 'websocket_mongodb.conf' + config_file = 'websocket_mongodb_subscriptions.conf' def setUp(self): super(SubscriptionTest, self).setUp() @@ -222,6 +223,69 @@ class SubscriptionTest(base.V1_1Base): self.assertEqual(1, sender.call_count) self.assertEqual(response, json.loads(sender.call_args[0][0])) + def test_subscription_sustainable_notifications_format(self): + # NOTE(Eva-i): The websocket subscription's notifications must be + # sent in the same format, binary or text, as the format of the + # subscription creation request. + # This test checks that notifications keep their encoding format, even + # if the client suddenly starts sending requests in another format. + + # Create a subscription in binary format + action = 'subscription_create' + body = {'queue_name': 'kitkat', 'ttl': 600} + + send_mock = mock.patch.object(self.protocol, 'sendMessage') + self.addCleanup(send_mock.stop) + sender = send_mock.start() + + subscription_factory = factory.NotificationFactory( + self.transport.factory) + subscription_factory.set_subscription_url('http://localhost:1234/') + self.protocol._handler.set_subscription_factory(subscription_factory) + + req = test_utils.create_binary_request(action, body, self.headers) + self.protocol.onMessage(req, True) + self.assertTrue(self.protocol.notify_in_binary) + + [subscriber] = list( + next( + self.boot.storage.subscription_controller.list( + 'kitkat', self.project_id))) + self.addCleanup( + self.boot.storage.subscription_controller.delete, 'kitkat', + subscriber['id'], project=self.project_id) + + # Send a message in text format + webhook_notification_send_mock = mock.patch('requests.post') + self.addCleanup(webhook_notification_send_mock.stop) + webhook_notification_sender = webhook_notification_send_mock.start() + + action = "message_post" + body = {"queue_name": "kitkat", + "messages": [{'body': {'status': 'disco queen'}, 'ttl': 60}]} + req = test_utils.create_request(action, body, self.headers) + self.protocol.onMessage(req, False) + self.assertTrue(self.protocol.notify_in_binary) + + # Check that the server responded in text format to the message + # creation request + message_create_response = json.loads(sender.call_args_list[1][0][0]) + self.assertEqual(201, message_create_response['headers']['status']) + + # Fetch webhook notification that was intended to arrive to + # notification protocol's listen address. Make subscription factory + # send it as websocket notification to the client + wh_notification = webhook_notification_sender.call_args[1]['data'] + subscription_factory.send_data(wh_notification, self.protocol.proto_id) + + # Check that the server sent the websocket notification in binary + # format + self.assertEqual(3, sender.call_count) + ws_notification = msgpack.unpackb(sender.call_args_list[2][0][0], + encoding='utf-8') + self.assertEqual({'body': {'status': 'disco queen'}, 'ttl': 60, + 'queue_name': 'kitkat'}, ws_notification) + def test_list_returns_503_on_nopoolfound_exception(self): sub = self.boot.storage.subscription_controller.create( 'kitkat', '', 600, {}, project=self.project_id) diff --git a/zaqar/transport/websocket/factory.py b/zaqar/transport/websocket/factory.py index 293026291..4ec036d26 100644 --- a/zaqar/transport/websocket/factory.py +++ b/zaqar/transport/websocket/factory.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import uuid from autobahn.asyncio import websocket +import msgpack from zaqar.transport.websocket import protocol @@ -59,7 +61,11 @@ class NotificationFactory(object): def send_data(self, data, proto_id): instance = self.message_factory._protos.get(proto_id) if instance: - instance.sendMessage(data, False) + # NOTE(Eva-i): incoming data is encoded in JSON, let's convert it + # to MsgPack, if notification should be encoded in binary format. + if instance.notify_in_binary: + data = msgpack.packb(json.loads(data)) + instance.sendMessage(data, instance.notify_in_binary) def __call__(self): return self.protocol(self) diff --git a/zaqar/transport/websocket/protocol.py b/zaqar/transport/websocket/protocol.py index 0a1e137a9..233d14921 100644 --- a/zaqar/transport/websocket/protocol.py +++ b/zaqar/transport/websocket/protocol.py @@ -16,8 +16,10 @@ import datetime import io import json +import sys from autobahn.asyncio import websocket +import msgpack from oslo_log import log as logging from oslo_utils import timeutils import pytz @@ -35,6 +37,8 @@ except ImportError: from email.mime import message Message = message.MIMEMessage +from zaqar.i18n import _LI + LOG = logging.getLogger(__name__) @@ -60,38 +64,46 @@ class MessagingProtocol(websocket.WebSocketServerProtocol): self._loop = loop self._authentified = False self._auth_app = None + self._auth_in_binary = None self._deauth_handle = None + self.notify_in_binary = None def onConnect(self, request): - print("Client connecting: {0}".format(request.peer)) + LOG.info(_LI("Client connecting: %s"), request.peer) def onOpen(self): - print("WebSocket connection open.") + LOG.info(_LI("WebSocket connection open.")) def onMessage(self, payload, isBinary): - if isBinary: - # TODO(vkmc): Binary support will be added in the next cycle - # For now, we are returning an invalid request response - print("Binary message received: {0} bytes".format(len(payload))) - body = {'error': 'Schema validation failed.'} - resp = self._handler.create_response(400, body) - return self._send_response(resp) + # Deserialize the request try: - print("Text message received: {0}".format(payload)) - payload = json.loads(payload) - except ValueError as ex: - LOG.exception(ex) - body = {'error': str(ex)} + if isBinary: + payload = msgpack.unpackb(payload, encoding='utf-8') + else: + payload = json.loads(payload) + except Exception: + if isBinary: + pack_name = 'binary (MessagePack)' + else: + pack_name = 'text (JSON)' + ex_type, ex_value = sys.exc_info()[:2] + ex_name = ex_type.__name__ + msg = 'Can\'t decode {0} request. {1}: {2}'.format( + pack_name, ex_name, ex_value) + LOG.debug(msg) + body = {'error': msg} resp = self._handler.create_response(400, body) - return self._send_response(resp) + return self._send_response(resp, isBinary) + # Check if the request is dict if not isinstance(payload, dict): body = { - 'error': "Unexpected body type. Expected dict or dict like" + 'error': 'Unexpected body type. Expected dict or dict like.' } resp = self._handler.create_response(400, body) - return self._send_response(resp) - + return self._send_response(resp, isBinary) + # Parse the request req = self._handler.create_request(payload) + # Validate and process the request resp = self._handler.validate_request(payload, req) if resp is None: if self._auth_strategy and not self._authentified: @@ -108,17 +120,28 @@ class MessagingProtocol(websocket.WebSocketServerProtocol): body = {'error': 'Not authentified.'} resp = self._handler.create_response(403, body, req) else: - return self._authenticate(payload) + return self._authenticate(payload, isBinary) elif payload.get('action') == 'authenticate': - return self._authenticate(payload) + return self._authenticate(payload, isBinary) else: resp = self._handler.process_request(req, self) - return self._send_response(resp) + if payload.get('action') == 'subscription_create': + # NOTE(Eva-i): this will make further websocket + # notifications encoded in the same format as the last + # successful websocket subscription create request. + if resp._headers['status'] == 201: + subscriber = payload['body'].get('subscriber') + # If there is no subscriber, the user has created websocket + # subscription. + if not subscriber: + self.notify_in_binary = isBinary + return self._send_response(resp, isBinary) def onClose(self, wasClean, code, reason): - print("WebSocket connection closed: {0}".format(reason)) + LOG.info(_LI("WebSocket connection closed: %s"), reason) - def _authenticate(self, payload): + def _authenticate(self, payload, in_binary): + self._auth_in_binary = in_binary self._auth_app = self._auth_strategy(self._auth_start) env = self._fake_env.copy() env.update( @@ -150,18 +173,33 @@ class MessagingProtocol(websocket.WebSocketServerProtocol): if code != 200: body = {'error': 'Authentication failed.'} resp = self._handler.create_response(code, body, req) - self._send_response(resp) + self._send_response(resp, self._auth_in_binary) else: body = {'message': 'Authentified.'} resp = self._handler.create_response(200, body, req) - self._send_response(resp) + self._send_response(resp, self._auth_in_binary) def _header_to_env_var(self, key): return 'HTTP_%s' % key.replace('-', '_').upper() - def _send_response(self, resp): - resp_json = json.dumps(resp.get_response()) - self.sendMessage(resp_json, False) + def _send_response(self, resp, in_binary): + if in_binary: + pack_name = 'bin' + self.sendMessage(msgpack.packb(resp.get_response()), True) + else: + pack_name = 'txt' + self.sendMessage(json.dumps(resp.get_response()), False) + if LOG.isEnabledFor(logging.INFO): + api = resp._request._api + status = resp._headers['status'] + action = resp._request._action + # Dump to JSON to print body without unicode prefixes on Python 2 + body = json.dumps(resp._request._body) + var_dict = {'api': api, 'pack_name': pack_name, 'status': + status, 'action': action, 'body': body} + LOG.info(_LI('Response: API %(api)s %(pack_name)s, %(status)s. ' + 'Request: action "%(action)s", body %(body)s.'), + var_dict) class NotificationProtocol(asyncio.Protocol):