Add LengthWrapper in put_object to honor content_length param
Closes-Bug:#1284360 Change-Id: Iec63a3fde77bb8195bfe46c764403b367999ff43
This commit is contained in:
parent
f4e057923c
commit
3d35a3e989
@ -30,6 +30,7 @@ from urlparse import urlparse, urlunparse
|
||||
from time import sleep, time
|
||||
|
||||
from swiftclient.exceptions import ClientException, InvalidHeadersException
|
||||
from swiftclient.utils import LengthWrapper
|
||||
|
||||
try:
|
||||
from logging import NullHandler
|
||||
@ -935,11 +936,8 @@ def put_object(url, token=None, container=None, name=None, contents=None,
|
||||
conn.putrequest(path, headers=headers, data=chunk_reader())
|
||||
else:
|
||||
# Fixes https://github.com/kennethreitz/requests/issues/1648
|
||||
try:
|
||||
contents.len = content_length
|
||||
except AttributeError:
|
||||
pass
|
||||
conn.putrequest(path, headers=headers, data=contents)
|
||||
data = LengthWrapper(contents, content_length)
|
||||
conn.putrequest(path, headers=headers, data=data)
|
||||
else:
|
||||
if chunk_size is not None:
|
||||
warn_msg = '%s object has no \"read\" method, ignoring chunk_size'\
|
||||
|
@ -55,3 +55,21 @@ def prt_bytes(bytes, human_flag):
|
||||
bytes = '%12s' % bytes
|
||||
|
||||
return(bytes)
|
||||
|
||||
|
||||
class LengthWrapper(object):
|
||||
|
||||
def __init__(self, readable, length):
|
||||
self._length = self._remaining = length
|
||||
self._readable = readable
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
if self._remaining <= 0:
|
||||
return ''
|
||||
chunk = self._readable.read(
|
||||
*args, **kwargs)[:self._remaining]
|
||||
self._remaining -= len(chunk)
|
||||
return chunk
|
||||
|
@ -15,6 +15,9 @@
|
||||
|
||||
import testtools
|
||||
|
||||
from StringIO import StringIO
|
||||
import tempfile
|
||||
|
||||
from swiftclient import utils as u
|
||||
|
||||
|
||||
@ -117,3 +120,41 @@ class TestPrtBytes(testtools.TestCase):
|
||||
def test_overflow(self):
|
||||
bytes_ = 2 ** 90
|
||||
self.assertEqual('1024Y', u.prt_bytes(bytes_, True).lstrip())
|
||||
|
||||
|
||||
class TestLengthWrapper(testtools.TestCase):
|
||||
|
||||
def test_stringio(self):
|
||||
contents = StringIO('a' * 100)
|
||||
data = u.LengthWrapper(contents, 42)
|
||||
self.assertEqual(42, len(data))
|
||||
read_data = ''.join(iter(data.read, ''))
|
||||
self.assertEqual(42, len(read_data))
|
||||
self.assertEqual('a' * 42, read_data)
|
||||
|
||||
def test_tempfile(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.write('a' * 100)
|
||||
f.flush()
|
||||
contents = open(f.name)
|
||||
data = u.LengthWrapper(contents, 42)
|
||||
self.assertEqual(42, len(data))
|
||||
read_data = ''.join(iter(data.read, ''))
|
||||
self.assertEqual(42, len(read_data))
|
||||
self.assertEqual('a' * 42, read_data)
|
||||
|
||||
def test_segmented_file(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
segment_length = 1024
|
||||
segments = ('a', 'b', 'c', 'd')
|
||||
for c in segments:
|
||||
f.write(c * segment_length)
|
||||
f.flush()
|
||||
for i, c in enumerate(segments):
|
||||
contents = open(f.name)
|
||||
contents.seek(i * segment_length)
|
||||
data = u.LengthWrapper(contents, segment_length)
|
||||
self.assertEqual(segment_length, len(data))
|
||||
read_data = ''.join(iter(data.read, ''))
|
||||
self.assertEqual(segment_length, len(read_data))
|
||||
self.assertEqual(c * segment_length, read_data)
|
||||
|
Loading…
Reference in New Issue
Block a user