From d4157ce5b5eeeebb3516092de995cee20025a5c1 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Wed, 23 Sep 2015 10:42:43 -0700 Subject: [PATCH] Retry file uploads via SwiftService When we introduced LengthWrapper, we neglected to make it resettable. As a result, upload failures result in errors like: put_object(...) failure and no ability to reset contents for reupload. Now, LengthWrappers will be resettable if their _readable has seek/tell. Related-Change: I6c8bc1366dfb591a26d934a30cd21c9e6b9a04ce Change-Id: I21f43f06e8c78b24d1fc081efedf2687942e042f --- swiftclient/client.py | 4 ++- swiftclient/utils.py | 32 +++++++++++++++--- tests/unit/test_swiftclient.py | 62 ++++++++++++++++++++++++++-------- tests/unit/test_utils.py | 26 +++++++++++--- 4 files changed, 101 insertions(+), 23 deletions(-) diff --git a/swiftclient/client.py b/swiftclient/client.py index aba986e6..172f5290 100644 --- a/swiftclient/client.py +++ b/swiftclient/client.py @@ -1546,10 +1546,12 @@ class Connection(object): if self.retries > 0: tell = getattr(contents, 'tell', None) seek = getattr(contents, 'seek', None) + reset = getattr(contents, 'reset', None) if tell and seek: orig_pos = tell() reset_func = lambda *a, **k: seek(orig_pos) - + elif reset: + reset_func = reset return self._retry(reset_func, put_object, container, obj, contents, content_length=content_length, etag=etag, chunk_size=chunk_size, content_type=content_type, diff --git a/swiftclient/utils.py b/swiftclient/utils.py index 9d94b6b5..ef65bbba 100644 --- a/swiftclient/utils.py +++ b/swiftclient/utils.py @@ -202,27 +202,36 @@ class LengthWrapper(object): def __init__(self, readable, length, md5=False): """ :param readable: The filelike object to read from. - :param length: The maximum amount of content to that can be read from + :param length: The maximum amount of content that can be read from the filelike object before it is simulated to be empty. :param md5: Flag to enable calculating the MD5 of the content as it is read. """ - self.md5sum = hashlib.md5() if md5 else NoopMD5() + self._md5 = md5 + self._reset_md5() self._length = self._remaining = length self._readable = readable + self._can_reset = all(hasattr(readable, attr) + for attr in ('seek', 'tell')) + if self._can_reset: + self._start = readable.tell() def __len__(self): return self._length + def _reset_md5(self): + self.md5sum = hashlib.md5() if self._md5 else NoopMD5() + def get_md5sum(self): return self.md5sum.hexdigest() - def read(self, *args, **kwargs): + def read(self, size=-1): if self._remaining <= 0: return '' - chunk = self._readable.read(*args, **kwargs)[:self._remaining] + to_read = self._remaining if size < 0 else min(size, self._remaining) + chunk = self._readable.read(to_read) self._remaining -= len(chunk) try: @@ -232,6 +241,21 @@ class LengthWrapper(object): return chunk + @property + def reset(self): + if self._can_reset: + return self._reset + raise AttributeError("%r object has no attribute 'reset'" % + type(self).__name__) + + def _reset(self, *args, **kwargs): + if not self._can_reset: + raise TypeError('%r object cannot be reset; needs both seek and ' + 'tell methods' % type(self._readable).__name__) + self._readable.seek(self._start) + self._reset_md5() + self._remaining = self._length + def iter_wrapper(iterable): for chunk in iterable: diff --git a/tests/unit/test_swiftclient.py b/tests/unit/test_swiftclient.py index 60f65c92..03f49a67 100644 --- a/tests/unit/test_swiftclient.py +++ b/tests/unit/test_swiftclient.py @@ -17,6 +17,7 @@ import logging import mock import six import socket +import string import testtools import warnings import tempfile @@ -1774,23 +1775,24 @@ class TestConnection(MockHttpTest): class LocalContents(object): def __init__(self, tell_value=0): - self.already_read = False + self.data = six.BytesIO(string.ascii_letters.encode() * 10) + self.data.seek(tell_value) + self.reads = [] self.seeks = [] - self.tell_value = tell_value + self.tells = [] def tell(self): - return self.tell_value + self.tells.append(self.data.tell()) + return self.tells[-1] - def seek(self, position): - self.seeks.append(position) - self.already_read = False + def seek(self, position, mode=0): + self.seeks.append((position, mode)) + self.data.seek(position, mode) def read(self, size=-1): - if self.already_read: - return '' - else: - self.already_read = True - return 'abcdef' + read_data = self.data.read(size) + self.reads.append((size, read_data)) + return read_data class LocalConnection(object): @@ -1801,7 +1803,7 @@ class TestConnection(MockHttpTest): self.port = parsed_url.netloc def putrequest(self, *args, **kwargs): - self.send() + self.send('PUT', *args, **kwargs) def putheader(self, *args, **kwargs): return @@ -1810,6 +1812,13 @@ class TestConnection(MockHttpTest): return def send(self, *args, **kwargs): + data = kwargs.get('data') + if data is not None: + if hasattr(data, 'read'): + data.read() + else: + for datum in data: + pass raise socket.error('oops') def request(self, *args, **kwargs): @@ -1844,7 +1853,12 @@ class TestConnection(MockHttpTest): conn.put_object('c', 'o', contents) except socket.error as err: exc = err - self.assertEqual(contents.seeks, [0]) + self.assertEqual(contents.tells, [0]) + self.assertEqual(contents.seeks, [(0, 0)]) + # four reads: two in the initial pass, two in the retry + self.assertEqual(4, len(contents.reads)) + self.assertEqual((65536, b''), contents.reads[1]) + self.assertEqual((65536, b''), contents.reads[3]) self.assertEqual(str(exc), 'oops') contents = LocalContents(tell_value=123) @@ -1853,9 +1867,29 @@ class TestConnection(MockHttpTest): conn.put_object('c', 'o', contents) except socket.error as err: exc = err - self.assertEqual(contents.seeks, [123]) + self.assertEqual(contents.tells, [123]) + self.assertEqual(contents.seeks, [(123, 0)]) + # four reads: two in the initial pass, two in the retry + self.assertEqual(4, len(contents.reads)) + self.assertEqual((65536, b''), contents.reads[1]) + self.assertEqual((65536, b''), contents.reads[3]) self.assertEqual(str(exc), 'oops') + contents = LocalContents(tell_value=123) + wrapped_contents = swiftclient.utils.LengthWrapper( + contents, 6, md5=True) + exc = None + try: + conn.put_object('c', 'o', wrapped_contents) + except socket.error as err: + exc = err + self.assertEqual(contents.tells, [123]) + self.assertEqual(contents.seeks, [(123, 0)]) + self.assertEqual(contents.reads, [(6, b'tuvwxy')] * 2) + self.assertEqual(str(exc), 'oops') + self.assertEqual(md5(b'tuvwxy').hexdigest(), + wrapped_contents.get_md5sum()) + contents = LocalContents() contents.tell = None exc = None diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 4faac6d8..3439f4a2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -219,9 +219,10 @@ class TestReadableToIterable(testtools.TestCase): class TestLengthWrapper(testtools.TestCase): def test_stringio(self): - contents = six.StringIO(u'a' * 100) + contents = six.StringIO(u'a' * 50 + u'b' * 50) + contents.seek(22) data = u.LengthWrapper(contents, 42, True) - s = u'a' * 42 + s = u'a' * 28 + u'b' * 14 read_data = u''.join(iter(data.read, '')) self.assertEqual(42, len(data)) @@ -229,10 +230,19 @@ class TestLengthWrapper(testtools.TestCase): self.assertEqual(s, read_data) self.assertEqual(md5(s.encode()).hexdigest(), data.get_md5sum()) + data.reset() + self.assertEqual(md5().hexdigest(), data.get_md5sum()) + + read_data = u''.join(iter(data.read, '')) + self.assertEqual(42, len(read_data)) + self.assertEqual(s, read_data) + self.assertEqual(md5(s.encode()).hexdigest(), data.get_md5sum()) + def test_bytesio(self): - contents = six.BytesIO(b'a' * 100) + contents = six.BytesIO(b'a' * 50 + b'b' * 50) + contents.seek(22) data = u.LengthWrapper(contents, 42, True) - s = b'a' * 42 + s = b'a' * 28 + b'b' * 14 read_data = b''.join(iter(data.read, '')) self.assertEqual(42, len(data)) @@ -272,3 +282,11 @@ class TestLengthWrapper(testtools.TestCase): self.assertEqual(segment_length, len(read_data)) self.assertEqual(s, read_data) self.assertEqual(md5(s).hexdigest(), data.get_md5sum()) + + data.reset() + self.assertEqual(md5().hexdigest(), data.get_md5sum()) + read_data = b''.join(iter(data.read, '')) + self.assertEqual(segment_length, len(data)) + self.assertEqual(segment_length, len(read_data)) + self.assertEqual(s, read_data) + self.assertEqual(md5(s).hexdigest(), data.get_md5sum())