Add LengthWrapper in put_object to honor content_length param

Closes-Bug:#1284360

Change-Id: Iec63a3fde77bb8195bfe46c764403b367999ff43
This commit is contained in:
Clay Gerrard 2014-02-24 22:41:45 -08:00
parent f4e057923c
commit 3d35a3e989
3 changed files with 62 additions and 5 deletions

View File

@ -30,6 +30,7 @@ from urlparse import urlparse, urlunparse
from time import sleep, time
from swiftclient.exceptions import ClientException, InvalidHeadersException
from swiftclient.utils import LengthWrapper
try:
from logging import NullHandler
@ -935,11 +936,8 @@ def put_object(url, token=None, container=None, name=None, contents=None,
conn.putrequest(path, headers=headers, data=chunk_reader())
else:
# Fixes https://github.com/kennethreitz/requests/issues/1648
try:
contents.len = content_length
except AttributeError:
pass
conn.putrequest(path, headers=headers, data=contents)
data = LengthWrapper(contents, content_length)
conn.putrequest(path, headers=headers, data=data)
else:
if chunk_size is not None:
warn_msg = '%s object has no \"read\" method, ignoring chunk_size'\

View File

@ -55,3 +55,21 @@ def prt_bytes(bytes, human_flag):
bytes = '%12s' % bytes
return(bytes)
class LengthWrapper(object):
def __init__(self, readable, length):
self._length = self._remaining = length
self._readable = readable
def __len__(self):
return self._length
def read(self, *args, **kwargs):
if self._remaining <= 0:
return ''
chunk = self._readable.read(
*args, **kwargs)[:self._remaining]
self._remaining -= len(chunk)
return chunk

View File

@ -15,6 +15,9 @@
import testtools
from StringIO import StringIO
import tempfile
from swiftclient import utils as u
@ -117,3 +120,41 @@ class TestPrtBytes(testtools.TestCase):
def test_overflow(self):
bytes_ = 2 ** 90
self.assertEqual('1024Y', u.prt_bytes(bytes_, True).lstrip())
class TestLengthWrapper(testtools.TestCase):
def test_stringio(self):
contents = StringIO('a' * 100)
data = u.LengthWrapper(contents, 42)
self.assertEqual(42, len(data))
read_data = ''.join(iter(data.read, ''))
self.assertEqual(42, len(read_data))
self.assertEqual('a' * 42, read_data)
def test_tempfile(self):
with tempfile.NamedTemporaryFile() as f:
f.write('a' * 100)
f.flush()
contents = open(f.name)
data = u.LengthWrapper(contents, 42)
self.assertEqual(42, len(data))
read_data = ''.join(iter(data.read, ''))
self.assertEqual(42, len(read_data))
self.assertEqual('a' * 42, read_data)
def test_segmented_file(self):
with tempfile.NamedTemporaryFile() as f:
segment_length = 1024
segments = ('a', 'b', 'c', 'd')
for c in segments:
f.write(c * segment_length)
f.flush()
for i, c in enumerate(segments):
contents = open(f.name)
contents.seek(i * segment_length)
data = u.LengthWrapper(contents, segment_length)
self.assertEqual(segment_length, len(data))
read_data = ''.join(iter(data.read, ''))
self.assertEqual(segment_length, len(read_data))
self.assertEqual(c * segment_length, read_data)