diff --git a/marconi/storage/base.py b/marconi/storage/base.py index a43ac3ad2..150d094d0 100644 --- a/marconi/storage/base.py +++ b/marconi/storage/base.py @@ -182,15 +182,16 @@ class MessageBase(ControllerBase): """ raise NotImplementedError - def get(self, queue, message_id, project=None): + def get(self, queue, message_ids, project=None): """Base method for getting a message. :param queue: Name of the queue to get the message from. :param project: Project id - :param message_id: Message ID + :param message_ids: One or more message IDs. Can be a single + string ID or a list of IDs. - :returns: Dictionary containing message data + :returns: An iterable, yielding dicts containing message details :raises: DoesNotExist """ raise NotImplementedError diff --git a/marconi/storage/mongodb/controllers.py b/marconi/storage/mongodb/controllers.py index eb9416e96..177f40b16 100644 --- a/marconi/storage/mongodb/controllers.py +++ b/marconi/storage/mongodb/controllers.py @@ -482,31 +482,34 @@ class MessageController(storage.MessageBase): yield utils.HookedCursor(messages, denormalizer) yield str(marker_id['next']) - def get(self, queue, message_id, project=None): - mid = utils.to_oid(message_id) + def get(self, queue, message_ids, project=None): + if not isinstance(message_ids, list): + message_ids = [message_ids] + + message_ids = [utils.to_oid(id) for id in message_ids] now = timeutils.utcnow() # Base query, always check expire time query = { 'q': self._get_queue_id(queue, project), 'e': {'$gt': now}, - '_id': mid + '_id': {'$in': message_ids}, } - message = self._col.find_one(query) + messages = self._col.find(query) - if message is None: - raise exceptions.MessageDoesNotExist(message_id, queue, project) + def denormalizer(msg): + oid = msg['_id'] + age = now - utils.oid_utc(oid) - oid = message['_id'] - age = now - utils.oid_utc(oid) + return { + 'id': str(oid), + 'age': age.seconds, + 'ttl': msg['t'], + 'body': msg['b'], + } - return { - 'id': str(oid), - 'age': age.seconds, - 'ttl': message['t'], - 'body': message['b'], - } + return utils.HookedCursor(messages, denormalizer) def post(self, queue, messages, client_uuid, project=None): now = timeutils.utcnow() diff --git a/marconi/storage/sqlite/controllers.py b/marconi/storage/sqlite/controllers.py index cf029ab4b..106f85819 100644 --- a/marconi/storage/sqlite/controllers.py +++ b/marconi/storage/sqlite/controllers.py @@ -143,26 +143,30 @@ class Message(base.MessageBase): ) ''') - def get(self, queue, message_id, project): - try: - content, ttl, age = self.driver.get(''' - select content, ttl, julianday() * 86400.0 - created - from Queues as Q join Messages as M - on qid = Q.id - where ttl > julianday() * 86400.0 - created - and M.id = ? and project = ? and name = ? - ''', _msgid_decode(message_id), project, queue) + def get(self, queue, message_ids, project): + if not isinstance(message_ids, list): + message_ids = [message_ids] - return { - 'id': message_id, + message_ids = ["'%s'" % _msgid_decode(id) for id in message_ids] + message_ids = ','.join(message_ids) + + sql = ''' + select M.id, content, ttl, julianday() * 86400.0 - created + from Queues as Q join Messages as M + on qid = Q.id + where ttl > julianday() * 86400.0 - created + and M.id in (%s) and project = ? and name = ? + ''' % message_ids + + records = self.driver.run(sql, project, queue) + for id, content, ttl, age in records: + yield { + 'id': id, 'ttl': ttl, 'age': int(age), 'body': content, } - except _NoResult: - raise exceptions.MessageDoesNotExist(message_id, queue, project) - def list(self, queue, project, marker=None, limit=10, echo=False, client_uuid=None): @@ -446,7 +450,11 @@ def _get_qid(driver, queue, project): # come with no special functionalities. def _msgid_encode(id): - return hex(id ^ 0x5c693a53)[2:] + try: + return hex(id ^ 0x5c693a53)[2:] + + except TypeError: + raise exceptions.MalformedID() def _msgid_decode(id): diff --git a/marconi/tests/storage/base.py b/marconi/tests/storage/base.py index 058e94360..89ff2293d 100644 --- a/marconi/tests/storage/base.py +++ b/marconi/tests/storage/base.py @@ -154,11 +154,10 @@ class MessageControllerTest(ControllerBaseTest): # Test Message Deletion self.controller.delete(queue_name, created[0], project=self.project) - # Test DoesNotExist - self.assertRaises(storage.exceptions.DoesNotExist, - self.controller.get, - queue_name, message_id=created[0], - project=self.project) + # Test does not exist + messages = self.controller.get(queue_name, message_ids=created, + project=self.project) + self.assertRaises(StopIteration, messages.next) def test_get_multi(self): _insert_fixtures(self.controller, self.queue_name, @@ -187,6 +186,18 @@ class MessageControllerTest(ControllerBaseTest): load_messages(5, self.queue_name, echo=True, project=self.project, marker=interaction.next(), client_uuid='my_uuid') + def test_get_multi_by_id(self): + messages_in = [{'ttl': 120, 'body': 0}, {'ttl': 240, 'body': 1}] + ids = self.controller.post(self.queue_name, messages_in, + project=self.project, + client_uuid='my_uuid') + + messages_out = self.controller.get(self.queue_name, ids, + project=self.project) + + for idx, message in enumerate(messages_out): + self.assertEquals(message['body'], idx) + def test_claim_effects(self): _insert_fixtures(self.controller, self.queue_name, project=self.project, client_uuid='my_uuid', num=12) @@ -210,9 +221,9 @@ class MessageControllerTest(ControllerBaseTest): project=self.project, claim=cid) - with testing.expect(storage.exceptions.DoesNotExist): + with testing.expect(StopIteration): self.controller.get(self.queue_name, msg1['id'], - project=self.project) + project=self.project).next() # Make sure such a deletion is idempotent self.controller.delete(self.queue_name, msg1['id'], @@ -235,9 +246,9 @@ class MessageControllerTest(ControllerBaseTest): project=self.project, client_uuid='my_uuid') - with testing.expect(storage.exceptions.DoesNotExist): + with testing.expect(StopIteration): self.controller.get(self.queue_name, msgid, - project=self.project) + project=self.project).next() countof = self.queue_controller.stats(self.queue_name, project=self.project) @@ -261,7 +272,7 @@ class MessageControllerTest(ControllerBaseTest): self.controller.delete(queue, bad_message_id, project) with testing.expect(exceptions.MalformedID): - self.controller.get(queue, bad_message_id, project) + self.controller.get(queue, bad_message_id, project).next() def test_bad_claim_id(self): self.queue_controller.upsert('unused', {}, '480924') diff --git a/marconi/tests/util/faulty_storage.py b/marconi/tests/util/faulty_storage.py index 049db7bf5..ea9d5ba5c 100644 --- a/marconi/tests/util/faulty_storage.py +++ b/marconi/tests/util/faulty_storage.py @@ -58,7 +58,7 @@ class MessageController(storage.MessageBase): def __init__(self, driver): pass - def get(self, queue, project=None, message_id=None, + def get(self, queue, project=None, message_ids=None, marker=None, echo=False, client_uuid=None): raise NotImplementedError() diff --git a/marconi/transport/wsgi/messages.py b/marconi/transport/wsgi/messages.py index 22a7a16a9..e581940a7 100644 --- a/marconi/transport/wsgi/messages.py +++ b/marconi/transport/wsgi/messages.py @@ -170,10 +170,10 @@ class ItemResource(object): try: message = self.message_controller.get( queue_name, - message_id=message_id, - project=project_id) + message_id, + project=project_id).next() - except storage_exceptions.DoesNotExist: + except StopIteration: raise falcon.HTTPNotFound() except Exception as ex: LOG.exception(ex)