diff --git a/swiftclient/client.py b/swiftclient/client.py index c41f35a..589b728 100644 --- a/swiftclient/client.py +++ b/swiftclient/client.py @@ -30,6 +30,7 @@ from urlparse import urlparse, urlunparse from time import sleep, time from swiftclient.exceptions import ClientException, InvalidHeadersException +from swiftclient.utils import LengthWrapper try: from logging import NullHandler @@ -935,11 +936,8 @@ def put_object(url, token=None, container=None, name=None, contents=None, conn.putrequest(path, headers=headers, data=chunk_reader()) else: # Fixes https://github.com/kennethreitz/requests/issues/1648 - try: - contents.len = content_length - except AttributeError: - pass - conn.putrequest(path, headers=headers, data=contents) + data = LengthWrapper(contents, content_length) + conn.putrequest(path, headers=headers, data=data) else: if chunk_size is not None: warn_msg = '%s object has no \"read\" method, ignoring chunk_size'\ diff --git a/swiftclient/utils.py b/swiftclient/utils.py index a038dcc..5095f9d 100644 --- a/swiftclient/utils.py +++ b/swiftclient/utils.py @@ -55,3 +55,21 @@ def prt_bytes(bytes, human_flag): bytes = '%12s' % bytes return(bytes) + + +class LengthWrapper(object): + + def __init__(self, readable, length): + self._length = self._remaining = length + self._readable = readable + + def __len__(self): + return self._length + + def read(self, *args, **kwargs): + if self._remaining <= 0: + return '' + chunk = self._readable.read( + *args, **kwargs)[:self._remaining] + self._remaining -= len(chunk) + return chunk diff --git a/tests/test_utils.py b/tests/test_utils.py index 33d9467..22af6ac 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,6 +15,9 @@ import testtools +from StringIO import StringIO +import tempfile + from swiftclient import utils as u @@ -117,3 +120,41 @@ class TestPrtBytes(testtools.TestCase): def test_overflow(self): bytes_ = 2 ** 90 self.assertEqual('1024Y', u.prt_bytes(bytes_, True).lstrip()) + + +class TestLengthWrapper(testtools.TestCase): + + def test_stringio(self): + contents = StringIO('a' * 100) + data = u.LengthWrapper(contents, 42) + self.assertEqual(42, len(data)) + read_data = ''.join(iter(data.read, '')) + self.assertEqual(42, len(read_data)) + self.assertEqual('a' * 42, read_data) + + def test_tempfile(self): + with tempfile.NamedTemporaryFile() as f: + f.write('a' * 100) + f.flush() + contents = open(f.name) + data = u.LengthWrapper(contents, 42) + self.assertEqual(42, len(data)) + read_data = ''.join(iter(data.read, '')) + self.assertEqual(42, len(read_data)) + self.assertEqual('a' * 42, read_data) + + def test_segmented_file(self): + with tempfile.NamedTemporaryFile() as f: + segment_length = 1024 + segments = ('a', 'b', 'c', 'd') + for c in segments: + f.write(c * segment_length) + f.flush() + for i, c in enumerate(segments): + contents = open(f.name) + contents.seek(i * segment_length) + data = u.LengthWrapper(contents, segment_length) + self.assertEqual(segment_length, len(data)) + read_data = ''.join(iter(data.read, '')) + self.assertEqual(segment_length, len(read_data)) + self.assertEqual(c * segment_length, read_data)