diff --git a/examples/websocket.html b/examples/websocket.html
index 37134554c..7216a2570 100644
--- a/examples/websocket.html
+++ b/examples/websocket.html
@@ -6,6 +6,11 @@
+
+
diff --git a/zaqar/tests/etc/websocket_mongodb_subscriptions.conf b/zaqar/tests/etc/websocket_mongodb_subscriptions.conf
new file mode 100644
index 000000000..56a97899d
--- /dev/null
+++ b/zaqar/tests/etc/websocket_mongodb_subscriptions.conf
@@ -0,0 +1,23 @@
+[DEFAULT]
+unreliable = True
+
+[drivers]
+
+# Transport driver to use (string value)
+transport = websocket
+
+# Storage driver to use (string value)
+message_store = mongodb
+
+[drivers:management_store:mongodb]
+
+# Mongodb Connection URI
+uri = mongodb://127.0.0.1:27017
+
+[drivers:message_store:mongodb]
+
+# Mongodb Connection URI
+uri = mongodb://127.0.0.1:27017
+
+[storage]
+message_pipeline = zaqar.notification.notifier
\ No newline at end of file
diff --git a/zaqar/tests/unit/transport/websocket/test_protocol.py b/zaqar/tests/unit/transport/websocket/test_protocol.py
index f601965f5..18dde3c8e 100644
--- a/zaqar/tests/unit/transport/websocket/test_protocol.py
+++ b/zaqar/tests/unit/transport/websocket/test_protocol.py
@@ -14,22 +14,27 @@
# under the License.
import json
+import uuid
+import ddt
import mock
from zaqar.tests.unit.transport.websocket import base
+from zaqar.tests.unit.transport.websocket import utils as test_utils
+@ddt.ddt
class TestMessagingProtocol(base.TestBase):
config_file = "websocket_mongodb.conf"
def setUp(self):
super(TestMessagingProtocol, self).setUp()
self.protocol = self.transport.factory()
- self.defaults = self.api.get_defaults()
-
- def tearDown(self):
- super(TestMessagingProtocol, self).tearDown()
+ self.project_id = 'protocol-test'
+ self.headers = {
+ 'Client-ID': str(uuid.uuid4()),
+ 'X-Project-ID': self.project_id
+ }
def test_on_message_with_invalid_input(self):
payload = u'\ufeff'
@@ -45,3 +50,45 @@ class TestMessagingProtocol(base.TestBase):
self.protocol.onMessage(payload, False)
resp = json.loads(send_mock.call_args[0][0])
self.assertEqual(400, resp['headers']['status'])
+
+ def test_on_message_with_invalid_input_binary(self):
+ dumps, loads, create_req = test_utils.get_pack_tools(binary=True)
+ send_mock = mock.Mock()
+ self.protocol.sendMessage = send_mock
+
+ # Test error response, when the request can't be deserialized.
+ req = "123"
+ self.protocol.onMessage(req, True)
+ resp = loads(send_mock.call_args[0][0])
+ self.assertEqual(400, resp['headers']['status'])
+ self.assertIn('Can\'t decode binary', resp['body']['error'])
+
+ # Test error response, when request body is not a dictionary.
+ req = dumps("Apparently, I'm not a dictionary")
+ self.protocol.onMessage(req, True)
+ resp = loads(send_mock.call_args[0][0])
+ self.assertEqual(400, resp['headers']['status'])
+ self.assertIn('Unexpected body type. Expected dict',
+ resp['body']['error'])
+
+ # Test error response, when validation fails.
+ action = 'queue_glorify'
+ body = {}
+ req = create_req(action, body, self.headers)
+ self.protocol.onMessage(req, True)
+ resp = loads(send_mock.call_args[0][0])
+ self.assertEqual(400, resp['headers']['status'])
+ self.assertEqual('queue_glorify is not a valid action',
+ resp['body']['error'])
+
+ @ddt.data(True, False)
+ def test_on_message_with_input_in_different_format(self, in_binary):
+ dumps, loads, create_req = test_utils.get_pack_tools(binary=in_binary)
+ action = 'queue_get'
+ body = {'queue_name': 'beautiful-non-existing-queue'}
+ req = create_req(action, body, self.headers)
+ send_mock = mock.Mock()
+ self.protocol.sendMessage = send_mock
+ self.protocol.onMessage(req, in_binary)
+ resp = loads(send_mock.call_args[0][0])
+ self.assertEqual(200, resp['headers']['status'])
diff --git a/zaqar/tests/unit/transport/websocket/utils.py b/zaqar/tests/unit/transport/websocket/utils.py
index 14f1a08fc..0bfa51156 100644
--- a/zaqar/tests/unit/transport/websocket/utils.py
+++ b/zaqar/tests/unit/transport/websocket/utils.py
@@ -12,8 +12,37 @@
# License for the specific language governing permissions and limitations under
# the License.
+import functools
import json
+import msgpack
def create_request(action, body, headers):
return json.dumps({"action": action, "body": body, "headers": headers})
+
+
+def create_binary_request(action, body, headers):
+ return msgpack.packb({"action": action, "body": body, "headers": headers})
+
+
+def get_pack_tools(binary=None):
+ """Get serialization tools for testing websocket transport.
+
+ :param bool binary: type of serialization tools.
+ True: binary (MessagePack) tools.
+ False: text (JSON) tools.
+ :returns: set of serialization tools needed for testing websocket
+ transport: (dumps, loads, create_request_function)
+ :rtype: tuple
+ """
+ if binary is None:
+ raise Exception("binary param is unspecified")
+ if binary:
+ dumps = msgpack.Packer(encoding='utf-8', use_bin_type=False).pack
+ loads = functools.partial(msgpack.unpackb, encoding='utf-8')
+ create_request_function = create_binary_request
+ else:
+ dumps = json.dumps
+ loads = json.loads
+ create_request_function = create_request
+ return dumps, loads, create_request_function
diff --git a/zaqar/tests/unit/transport/websocket/v2/test_auth.py b/zaqar/tests/unit/transport/websocket/v2/test_auth.py
index 5bf824603..a2c7fa3ec 100644
--- a/zaqar/tests/unit/transport/websocket/v2/test_auth.py
+++ b/zaqar/tests/unit/transport/websocket/v2/test_auth.py
@@ -16,6 +16,7 @@
import json
import uuid
+import ddt
from keystonemiddleware import auth_token
import mock
@@ -24,8 +25,8 @@ from zaqar.tests.unit.transport.websocket import base
from zaqar.tests.unit.transport.websocket import utils as test_utils
+@ddt.ddt
class AuthTest(base.V2Base):
-
config_file = "websocket_mongodb_keystone_auth.conf"
def setUp(self):
@@ -87,6 +88,7 @@ class AuthTest(base.V2Base):
msg_mock = mock.patch.object(self.protocol, 'sendMessage')
self.addCleanup(msg_mock.stop)
msg_mock = msg_mock.start()
+ self.protocol._auth_in_binary = False
self.protocol._auth_response('401 error', 'Failed')
self.assertEqual(1, msg_mock.call_count)
resp = json.loads(msg_mock.call_args[0][0])
@@ -122,6 +124,25 @@ class AuthTest(base.V2Base):
self.assertIn('cancelled', repr(handle))
self.assertNotIn('cancelled', repr(self.protocol._deauth_handle))
+ @ddt.data(True, False)
+ def test_auth_response_serialization_format(self, in_binary):
+ dumps, loads, create_req = test_utils.get_pack_tools(binary=in_binary)
+ headers = self.headers.copy()
+ headers['X-Auth-Token'] = 'mytoken1'
+ req = create_req("authenticate", {}, headers)
+
+ msg_mock = mock.patch.object(self.protocol, 'sendMessage')
+ self.addCleanup(msg_mock.stop)
+ msg_mock = msg_mock.start()
+ # Depending on onMessage method's second argument, auth response should
+ # be in binary or text format.
+ self.protocol.onMessage(req, in_binary)
+ self.assertEqual(in_binary, self.protocol._auth_in_binary)
+ self.protocol._auth_response('401 error', 'Failed')
+ self.assertEqual(1, msg_mock.call_count)
+ resp = loads(msg_mock.call_args[0][0])
+ self.assertEqual(401, resp['headers']['status'])
+
def test_signed_url(self):
send_mock = mock.Mock()
self.protocol.sendMessage = send_mock
diff --git a/zaqar/tests/unit/transport/websocket/v2/test_messages.py b/zaqar/tests/unit/transport/websocket/v2/test_messages.py
index 86639d3e6..ef82661a9 100644
--- a/zaqar/tests/unit/transport/websocket/v2/test_messages.py
+++ b/zaqar/tests/unit/transport/websocket/v2/test_messages.py
@@ -69,7 +69,7 @@ class MessagesBaseTest(base.V2Base):
resp = json.loads(send_mock.call_args[0][0])
self.assertEqual(204, resp['headers']['status'])
- def _test_post(self, sample_messages):
+ def _test_post(self, sample_messages, in_binary=False):
action = "message_post"
body = {"queue_name": "kitkat",
"messages": sample_messages}
@@ -77,11 +77,13 @@ class MessagesBaseTest(base.V2Base):
send_mock = mock.Mock()
self.protocol.sendMessage = send_mock
- req = test_utils.create_request(action, body, self.headers)
+ dumps, loads, create_req = test_utils.get_pack_tools(binary=in_binary)
- self.protocol.onMessage(req, False)
+ req = create_req(action, body, self.headers)
- resp = json.loads(send_mock.call_args[0][0])
+ self.protocol.onMessage(req, in_binary)
+
+ resp = loads(send_mock.call_args[0][0])
self.assertEqual(201, resp['headers']['status'])
self.msg_ids = resp['body']['message_ids']
self.assertEqual(len(sample_messages), len(self.msg_ids))
@@ -102,19 +104,19 @@ class MessagesBaseTest(base.V2Base):
body = {"queue_name": "kitkat",
"message_id": msg_id}
- req = test_utils.create_request(action, body, headers)
+ req = create_req(action, body, headers)
- self.protocol.onMessage(req, False)
+ self.protocol.onMessage(req, in_binary)
- resp = json.loads(send_mock.call_args[0][0])
+ resp = loads(send_mock.call_args[0][0])
self.assertEqual(404, resp['headers']['status'])
# Correct project ID
- req = test_utils.create_request(action, body, self.headers)
+ req = create_req(action, body, self.headers)
- self.protocol.onMessage(req, False)
+ self.protocol.onMessage(req, in_binary)
- resp = json.loads(send_mock.call_args[0][0])
+ resp = loads(send_mock.call_args[0][0])
self.assertEqual(200, resp['headers']['status'])
# Check message properties
@@ -132,11 +134,11 @@ class MessagesBaseTest(base.V2Base):
action = "message_get_many"
body = {"queue_name": "kitkat",
"message_ids": self.msg_ids}
- req = test_utils.create_request(action, body, self.headers)
+ req = create_req(action, body, self.headers)
- self.protocol.onMessage(req, False)
+ self.protocol.onMessage(req, in_binary)
- resp = json.loads(send_mock.call_args[0][0])
+ resp = loads(send_mock.call_args[0][0])
self.assertEqual(200, resp['headers']['status'])
expected_ttls = set(m['ttl'] for m in sample_messages)
actual_ttls = set(m['ttl'] for m in resp['body']['messages'])
@@ -181,21 +183,23 @@ class MessagesBaseTest(base.V2Base):
resp = json.loads(send_mock.call_args[0][0])
self.assertEqual(400, resp['headers']['status'])
- def test_post_single(self):
+ @ddt.data(True, False)
+ def test_post_single(self, in_binary):
sample_messages = [
{'body': {'key': 'value'}, 'ttl': 200},
]
- self._test_post(sample_messages)
+ self._test_post(sample_messages, in_binary=in_binary)
- def test_post_multiple(self):
+ @ddt.data(True, False)
+ def test_post_multiple(self, in_binary):
sample_messages = [
{'body': 239, 'ttl': 100},
{'body': {'key': 'value'}, 'ttl': 200},
{'body': [1, 3], 'ttl': 300},
]
- self._test_post(sample_messages)
+ self._test_post(sample_messages, in_binary=in_binary)
def test_post_optional_ttl(self):
messages = [{'body': 239},
diff --git a/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py b/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py
index cd0eb3cba..dac740df7 100644
--- a/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py
+++ b/zaqar/tests/unit/transport/websocket/v2/test_subscriptions.py
@@ -17,6 +17,7 @@ import json
import uuid
import mock
+import msgpack
from zaqar.storage import errors as storage_errors
from zaqar.tests.unit.transport.websocket import base
@@ -26,7 +27,7 @@ from zaqar.transport.websocket import factory
class SubscriptionTest(base.V1_1Base):
- config_file = 'websocket_mongodb.conf'
+ config_file = 'websocket_mongodb_subscriptions.conf'
def setUp(self):
super(SubscriptionTest, self).setUp()
@@ -222,6 +223,69 @@ class SubscriptionTest(base.V1_1Base):
self.assertEqual(1, sender.call_count)
self.assertEqual(response, json.loads(sender.call_args[0][0]))
+ def test_subscription_sustainable_notifications_format(self):
+ # NOTE(Eva-i): The websocket subscription's notifications must be
+ # sent in the same format, binary or text, as the format of the
+ # subscription creation request.
+ # This test checks that notifications keep their encoding format, even
+ # if the client suddenly starts sending requests in another format.
+
+ # Create a subscription in binary format
+ action = 'subscription_create'
+ body = {'queue_name': 'kitkat', 'ttl': 600}
+
+ send_mock = mock.patch.object(self.protocol, 'sendMessage')
+ self.addCleanup(send_mock.stop)
+ sender = send_mock.start()
+
+ subscription_factory = factory.NotificationFactory(
+ self.transport.factory)
+ subscription_factory.set_subscription_url('http://localhost:1234/')
+ self.protocol._handler.set_subscription_factory(subscription_factory)
+
+ req = test_utils.create_binary_request(action, body, self.headers)
+ self.protocol.onMessage(req, True)
+ self.assertTrue(self.protocol.notify_in_binary)
+
+ [subscriber] = list(
+ next(
+ self.boot.storage.subscription_controller.list(
+ 'kitkat', self.project_id)))
+ self.addCleanup(
+ self.boot.storage.subscription_controller.delete, 'kitkat',
+ subscriber['id'], project=self.project_id)
+
+ # Send a message in text format
+ webhook_notification_send_mock = mock.patch('requests.post')
+ self.addCleanup(webhook_notification_send_mock.stop)
+ webhook_notification_sender = webhook_notification_send_mock.start()
+
+ action = "message_post"
+ body = {"queue_name": "kitkat",
+ "messages": [{'body': {'status': 'disco queen'}, 'ttl': 60}]}
+ req = test_utils.create_request(action, body, self.headers)
+ self.protocol.onMessage(req, False)
+ self.assertTrue(self.protocol.notify_in_binary)
+
+ # Check that the server responded in text format to the message
+ # creation request
+ message_create_response = json.loads(sender.call_args_list[1][0][0])
+ self.assertEqual(201, message_create_response['headers']['status'])
+
+ # Fetch webhook notification that was intended to arrive to
+ # notification protocol's listen address. Make subscription factory
+ # send it as websocket notification to the client
+ wh_notification = webhook_notification_sender.call_args[1]['data']
+ subscription_factory.send_data(wh_notification, self.protocol.proto_id)
+
+ # Check that the server sent the websocket notification in binary
+ # format
+ self.assertEqual(3, sender.call_count)
+ ws_notification = msgpack.unpackb(sender.call_args_list[2][0][0],
+ encoding='utf-8')
+ self.assertEqual({'body': {'status': 'disco queen'}, 'ttl': 60,
+ 'queue_name': 'kitkat'}, ws_notification)
+
def test_list_returns_503_on_nopoolfound_exception(self):
sub = self.boot.storage.subscription_controller.create(
'kitkat', '', 600, {}, project=self.project_id)
diff --git a/zaqar/transport/websocket/factory.py b/zaqar/transport/websocket/factory.py
index 293026291..4ec036d26 100644
--- a/zaqar/transport/websocket/factory.py
+++ b/zaqar/transport/websocket/factory.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import uuid
from autobahn.asyncio import websocket
+import msgpack
from zaqar.transport.websocket import protocol
@@ -59,7 +61,11 @@ class NotificationFactory(object):
def send_data(self, data, proto_id):
instance = self.message_factory._protos.get(proto_id)
if instance:
- instance.sendMessage(data, False)
+ # NOTE(Eva-i): incoming data is encoded in JSON, let's convert it
+ # to MsgPack, if notification should be encoded in binary format.
+ if instance.notify_in_binary:
+ data = msgpack.packb(json.loads(data))
+ instance.sendMessage(data, instance.notify_in_binary)
def __call__(self):
return self.protocol(self)
diff --git a/zaqar/transport/websocket/protocol.py b/zaqar/transport/websocket/protocol.py
index 0a1e137a9..233d14921 100644
--- a/zaqar/transport/websocket/protocol.py
+++ b/zaqar/transport/websocket/protocol.py
@@ -16,8 +16,10 @@
import datetime
import io
import json
+import sys
from autobahn.asyncio import websocket
+import msgpack
from oslo_log import log as logging
from oslo_utils import timeutils
import pytz
@@ -35,6 +37,8 @@ except ImportError:
from email.mime import message
Message = message.MIMEMessage
+from zaqar.i18n import _LI
+
LOG = logging.getLogger(__name__)
@@ -60,38 +64,46 @@ class MessagingProtocol(websocket.WebSocketServerProtocol):
self._loop = loop
self._authentified = False
self._auth_app = None
+ self._auth_in_binary = None
self._deauth_handle = None
+ self.notify_in_binary = None
def onConnect(self, request):
- print("Client connecting: {0}".format(request.peer))
+ LOG.info(_LI("Client connecting: %s"), request.peer)
def onOpen(self):
- print("WebSocket connection open.")
+ LOG.info(_LI("WebSocket connection open."))
def onMessage(self, payload, isBinary):
- if isBinary:
- # TODO(vkmc): Binary support will be added in the next cycle
- # For now, we are returning an invalid request response
- print("Binary message received: {0} bytes".format(len(payload)))
- body = {'error': 'Schema validation failed.'}
- resp = self._handler.create_response(400, body)
- return self._send_response(resp)
+ # Deserialize the request
try:
- print("Text message received: {0}".format(payload))
- payload = json.loads(payload)
- except ValueError as ex:
- LOG.exception(ex)
- body = {'error': str(ex)}
+ if isBinary:
+ payload = msgpack.unpackb(payload, encoding='utf-8')
+ else:
+ payload = json.loads(payload)
+ except Exception:
+ if isBinary:
+ pack_name = 'binary (MessagePack)'
+ else:
+ pack_name = 'text (JSON)'
+ ex_type, ex_value = sys.exc_info()[:2]
+ ex_name = ex_type.__name__
+ msg = 'Can\'t decode {0} request. {1}: {2}'.format(
+ pack_name, ex_name, ex_value)
+ LOG.debug(msg)
+ body = {'error': msg}
resp = self._handler.create_response(400, body)
- return self._send_response(resp)
+ return self._send_response(resp, isBinary)
+ # Check if the request is dict
if not isinstance(payload, dict):
body = {
- 'error': "Unexpected body type. Expected dict or dict like"
+ 'error': 'Unexpected body type. Expected dict or dict like.'
}
resp = self._handler.create_response(400, body)
- return self._send_response(resp)
-
+ return self._send_response(resp, isBinary)
+ # Parse the request
req = self._handler.create_request(payload)
+ # Validate and process the request
resp = self._handler.validate_request(payload, req)
if resp is None:
if self._auth_strategy and not self._authentified:
@@ -108,17 +120,28 @@ class MessagingProtocol(websocket.WebSocketServerProtocol):
body = {'error': 'Not authentified.'}
resp = self._handler.create_response(403, body, req)
else:
- return self._authenticate(payload)
+ return self._authenticate(payload, isBinary)
elif payload.get('action') == 'authenticate':
- return self._authenticate(payload)
+ return self._authenticate(payload, isBinary)
else:
resp = self._handler.process_request(req, self)
- return self._send_response(resp)
+ if payload.get('action') == 'subscription_create':
+ # NOTE(Eva-i): this will make further websocket
+ # notifications encoded in the same format as the last
+ # successful websocket subscription create request.
+ if resp._headers['status'] == 201:
+ subscriber = payload['body'].get('subscriber')
+ # If there is no subscriber, the user has created websocket
+ # subscription.
+ if not subscriber:
+ self.notify_in_binary = isBinary
+ return self._send_response(resp, isBinary)
def onClose(self, wasClean, code, reason):
- print("WebSocket connection closed: {0}".format(reason))
+ LOG.info(_LI("WebSocket connection closed: %s"), reason)
- def _authenticate(self, payload):
+ def _authenticate(self, payload, in_binary):
+ self._auth_in_binary = in_binary
self._auth_app = self._auth_strategy(self._auth_start)
env = self._fake_env.copy()
env.update(
@@ -150,18 +173,33 @@ class MessagingProtocol(websocket.WebSocketServerProtocol):
if code != 200:
body = {'error': 'Authentication failed.'}
resp = self._handler.create_response(code, body, req)
- self._send_response(resp)
+ self._send_response(resp, self._auth_in_binary)
else:
body = {'message': 'Authentified.'}
resp = self._handler.create_response(200, body, req)
- self._send_response(resp)
+ self._send_response(resp, self._auth_in_binary)
def _header_to_env_var(self, key):
return 'HTTP_%s' % key.replace('-', '_').upper()
- def _send_response(self, resp):
- resp_json = json.dumps(resp.get_response())
- self.sendMessage(resp_json, False)
+ def _send_response(self, resp, in_binary):
+ if in_binary:
+ pack_name = 'bin'
+ self.sendMessage(msgpack.packb(resp.get_response()), True)
+ else:
+ pack_name = 'txt'
+ self.sendMessage(json.dumps(resp.get_response()), False)
+ if LOG.isEnabledFor(logging.INFO):
+ api = resp._request._api
+ status = resp._headers['status']
+ action = resp._request._action
+ # Dump to JSON to print body without unicode prefixes on Python 2
+ body = json.dumps(resp._request._body)
+ var_dict = {'api': api, 'pack_name': pack_name, 'status':
+ status, 'action': action, 'body': body}
+ LOG.info(_LI('Response: API %(api)s %(pack_name)s, %(status)s. '
+ 'Request: action "%(action)s", body %(body)s.'),
+ var_dict)
class NotificationProtocol(asyncio.Protocol):