diff --git a/swiftclient/client.py b/swiftclient/client.py index 0c6d8d34..87b3e5a8 100644 --- a/swiftclient/client.py +++ b/swiftclient/client.py @@ -138,6 +138,37 @@ def encode_meta_headers(headers): return ret +class _ObjectBody(object): + """ + Readable and iterable object body response wrapper. + """ + + def __init__(self, resp, chunk_size): + """ + Wrap the underlying response + + :param resp: the response to wrap + :param chunk_size: number of bytes to return each iteration/next call + """ + self.resp = resp + self.chunk_size = chunk_size + + def read(self, length=None): + return self.resp.read(length) + + def __iter__(self): + return self + + def next(self): + buf = self.resp.read(self.chunk_size) + if not buf: + raise StopIteration() + return buf + + def __next__(self): + return self.next() + + class HTTPConnection(object): def __init__(self, url, proxy=None, cacert=None, insecure=False, ssl_compression=False, default_user_agent=None): @@ -874,13 +905,7 @@ def get_object(url, token, container, name, http_conn=None, http_reason=resp.reason, http_response_content=body) if resp_chunk_size: - - def _object_body(): - buf = resp.read(resp_chunk_size) - while buf: - yield buf - buf = resp.read(resp_chunk_size) - object_body = _object_body() + object_body = _ObjectBody(resp, resp_chunk_size) else: object_body = resp.read() http_log(('%s%s' % (url.replace(parsed.path, ''), path), method,), diff --git a/tests/functional/test_swiftclient.py b/tests/functional/test_swiftclient.py index f5d14aa3..ea44169e 100644 --- a/tests/functional/test_swiftclient.py +++ b/tests/functional/test_swiftclient.py @@ -16,7 +16,6 @@ import os import testtools import time -import types from io import BytesIO from six.moves import configparser @@ -256,8 +255,24 @@ class TestFunctional(testtools.TestCase): hdrs, body = self.conn.get_object( self.containername, self.objectname, resp_chunk_size=10) - self.assertTrue(isinstance(body, types.GeneratorType)) - self.assertEqual(self.test_data, b''.join(body)) + downloaded_contents = b'' + while True: + try: + chunk = next(body) + except StopIteration: + break + downloaded_contents += chunk + self.assertEqual(self.test_data, downloaded_contents) + + # Download in chunks, should also work with read + hdrs, body = self.conn.get_object( + self.containername, self.objectname, + resp_chunk_size=10) + num_bytes = 5 + downloaded_contents = body.read(num_bytes) + self.assertEqual(num_bytes, len(downloaded_contents)) + downloaded_contents += body.read() + self.assertEqual(self.test_data, downloaded_contents) def test_post_account(self): self.conn.post_account({'x-account-meta-data': 'Something'}) diff --git a/tests/unit/test_swiftclient.py b/tests/unit/test_swiftclient.py index 9ebcff55..5719df25 100644 --- a/tests/unit/test_swiftclient.py +++ b/tests/unit/test_swiftclient.py @@ -655,6 +655,41 @@ class TestGetObject(MockHttpTest): }), ]) + def test_chunk_size_read_method(self): + conn = c.Connection('http://auth.url/', 'some_user', 'some_key') + with mock.patch('swiftclient.client.get_auth_1_0') as mock_get_auth: + mock_get_auth.return_value = ('http://auth.url/', 'tToken') + c.http_connection = self.fake_http_connection(200, body='abcde') + __, resp = conn.get_object('asdf', 'asdf', resp_chunk_size=3) + self.assertTrue(hasattr(resp, 'read')) + self.assertEquals(resp.read(3), 'abc') + self.assertEquals(resp.read(None), 'de') + self.assertEquals(resp.read(), '') + + def test_chunk_size_iter(self): + conn = c.Connection('http://auth.url/', 'some_user', 'some_key') + with mock.patch('swiftclient.client.get_auth_1_0') as mock_get_auth: + mock_get_auth.return_value = ('http://auth.url/', 'tToken') + c.http_connection = self.fake_http_connection(200, body='abcde') + __, resp = conn.get_object('asdf', 'asdf', resp_chunk_size=3) + self.assertTrue(hasattr(resp, 'next')) + self.assertEquals(next(resp), 'abc') + self.assertEquals(next(resp), 'de') + self.assertRaises(StopIteration, next, resp) + + def test_chunk_size_read_and_iter(self): + conn = c.Connection('http://auth.url/', 'some_user', 'some_key') + with mock.patch('swiftclient.client.get_auth_1_0') as mock_get_auth: + mock_get_auth.return_value = ('http://auth.url/', 'tToken') + c.http_connection = self.fake_http_connection(200, body='abcdef') + __, resp = conn.get_object('asdf', 'asdf', resp_chunk_size=2) + self.assertTrue(hasattr(resp, 'read')) + self.assertEquals(resp.read(3), 'abc') + self.assertEquals(next(resp), 'de') + self.assertEquals(resp.read(), 'f') + self.assertRaises(StopIteration, next, resp) + self.assertEquals(resp.read(), '') + class TestHeadObject(MockHttpTest): diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 88d6d129..00569931 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -154,7 +154,10 @@ def fake_http_connect(*code_iter, **kwargs): sleep(0.1) return ' ' rv = self.body[:amt] - self.body = self.body[amt:] + if amt is not None: + self.body = self.body[amt:] + else: + self.body = '' return rv def send(self, amt=None):