This patch fixes downloading files to stdout.

This patch fixes downloading files to stdout and modifies
_SwiftReader to operate as an iterator that performs file
checks at the end of iteration as well as a context manager.
File verification checks have been removed from __exit__
and added to __iter__.

Change-Id: I3250bdeeef8484a9122c4b5b854756a7c8f8731e
Closes-Bug: 1395922
Closes-Bug: 1387376
This commit is contained in:
Joel Wright 2015-01-04 21:14:02 +00:00
parent 7709fea51e
commit bd42c2b00d
7 changed files with 164 additions and 97 deletions

@ -67,6 +67,17 @@ class OutputManager(object):
self.error_print_pool.__exit__(exc_type, exc_value, traceback) self.error_print_pool.__exit__(exc_type, exc_value, traceback)
self.print_pool.__exit__(exc_type, exc_value, traceback) self.print_pool.__exit__(exc_type, exc_value, traceback)
def print_raw(self, data):
self.print_pool.submit(self._write, data, self.print_stream)
def _write(self, data, stream):
if six.PY3:
stream.buffer.write(data)
stream.flush()
if six.PY2:
stream.write(data)
stream.flush()
def print_msg(self, msg, *fmt_args): def print_msg(self, msg, *fmt_args):
if fmt_args: if fmt_args:
msg = msg % fmt_args msg = msg % fmt_args

@ -315,10 +315,15 @@ class _SwiftReader(object):
except ValueError: except ValueError:
raise SwiftError('content-length header must be an integer') raise SwiftError('content-length header must be an integer')
def __enter__(self): def __iter__(self):
return self for chunk in self._body:
if self._actual_md5:
self._actual_md5.update(chunk)
self._actual_read += len(chunk)
yield chunk
self._check_contents()
def __exit__(self, exc_type, exc_val, exc_tb): def _check_contents(self):
if self._actual_md5 and self._expected_etag: if self._actual_md5 and self._expected_etag:
etag = self._actual_md5.hexdigest() etag = self._actual_md5.hexdigest()
if etag != self._expected_etag: if etag != self._expected_etag:
@ -333,13 +338,6 @@ class _SwiftReader(object):
self._path, self._actual_read, self._path, self._actual_read,
self._content_length)) self._content_length))
def buffer(self):
for chunk in self._body:
if self._actual_md5:
self._actual_md5.update(chunk)
self._actual_read += len(chunk)
yield chunk
def bytes_read(self): def bytes_read(self):
return self._actual_read return self._actual_read
@ -999,18 +997,28 @@ class SwiftService(object):
try: try:
start_time = time() start_time = time()
headers, body = \ headers, body = \
conn.get_object(container, obj, resp_chunk_size=65536, conn.get_object(container, obj, resp_chunk_size=65536,
headers=req_headers, headers=req_headers,
response_dict=results_dict) response_dict=results_dict)
headers_receipt = time() headers_receipt = time()
reader = _SwiftReader(path, body, headers) obj_body = _SwiftReader(path, body, headers)
with reader as obj_body:
no_file = options['no_download']
if out_file == "-" and not no_file:
res = {
'action': 'download_object',
'container': container,
'object': obj,
'path': path,
'pseudodir': pseudodir,
'contents': obj_body
}
return res
fp = None fp = None
try: try:
no_file = options['no_download']
content_type = headers.get('content-type') content_type = headers.get('content-type')
if (content_type and if (content_type and
content_type.split(';', 1)[0] == 'text/directory'): content_type.split(';', 1)[0] == 'text/directory'):
@ -1026,12 +1034,6 @@ class SwiftService(object):
mkdirs(dirpath) mkdirs(dirpath)
if not no_file: if not no_file:
if out_file == "-":
res = {
'path': path,
'contents': obj_body
}
return res
if out_file: if out_file:
fp = open(out_file, 'wb') fp = open(out_file, 'wb')
else: else:
@ -1040,7 +1042,7 @@ class SwiftService(object):
else: else:
pseudodir = True pseudodir = True
for chunk in obj_body.buffer(): for chunk in obj_body:
if fp is not None: if fp is not None:
fp.write(chunk) fp.write(chunk)
@ -1052,8 +1054,7 @@ class SwiftService(object):
fp.close() fp.close()
if 'x-object-meta-mtime' in headers and not no_file: if 'x-object-meta-mtime' in headers and not no_file:
mtime = float(headers['x-object-meta-mtime']) mtime = float(headers['x-object-meta-mtime'])
if options['out_file'] \ if options['out_file']:
and not options['out_file'] == "-":
utime(options['out_file'], (mtime, mtime)) utime(options['out_file'], (mtime, mtime))
else: else:
utime(path, (mtime, mtime)) utime(path, (mtime, mtime))

@ -16,9 +16,9 @@
from __future__ import print_function from __future__ import print_function
import logging
import signal import signal
import socket import socket
import logging
from optparse import OptionParser, OptionGroup, SUPPRESS_HELP from optparse import OptionParser, OptionGroup, SUPPRESS_HELP
from os import environ, walk, _exit as os_exit from os import environ, walk, _exit as os_exit
@ -261,8 +261,9 @@ def st_download(parser, args, output_manager):
for down in down_iter: for down in down_iter:
if options.out_file == '-' and 'contents' in down: if options.out_file == '-' and 'contents' in down:
for chunk in down['contents']: contents = down['contents']
output_manager.print_msg(chunk) for chunk in contents:
output_manager.print_raw(chunk)
else: else:
if down['success']: if down['success']:
if options.verbose: if options.verbose:

@ -23,6 +23,7 @@ from six.moves.queue import Queue, Empty
from time import sleep from time import sleep
from swiftclient import multithreading as mt from swiftclient import multithreading as mt
from .utils import CaptureStream
class ThreadTestCase(testtools.TestCase): class ThreadTestCase(testtools.TestCase):
@ -175,8 +176,8 @@ class TestOutputManager(testtools.TestCase):
self.assertEqual(sys.stderr, output_manager.error_stream) self.assertEqual(sys.stderr, output_manager.error_stream)
def test_printers(self): def test_printers(self):
out_stream = six.StringIO() out_stream = CaptureStream(sys.stdout)
err_stream = six.StringIO() err_stream = CaptureStream(sys.stderr)
starting_thread_count = threading.active_count() starting_thread_count = threading.active_count()
with mt.OutputManager( with mt.OutputManager(
@ -201,6 +202,8 @@ class TestOutputManager(testtools.TestCase):
thread_manager.error('one-error-argument') thread_manager.error('one-error-argument')
thread_manager.error('Sometimes\n%.1f%% just\ndoes not\nwork!', thread_manager.error('Sometimes\n%.1f%% just\ndoes not\nwork!',
3.14159) 3.14159)
thread_manager.print_raw(
u'some raw bytes: \u062A\u062A'.encode('utf-8'))
# Now we have a thread for error printing and a thread for # Now we have a thread for error printing and a thread for
# normal print messages # normal print messages
@ -210,25 +213,30 @@ class TestOutputManager(testtools.TestCase):
# The threads should have been cleaned up # The threads should have been cleaned up
self.assertEqual(starting_thread_count, threading.active_count()) self.assertEqual(starting_thread_count, threading.active_count())
out_stream.seek(0)
if six.PY3: if six.PY3:
over_the = "over the '\u062a\u062a'\n" over_the = "over the '\u062a\u062a'\n"
# The CaptureStreamBuffer just encodes all bytes written to it by
# mapping chr over the byte string to produce a str.
raw_bytes = ''.join(
map(chr, u'some raw bytes: \u062A\u062A'.encode('utf-8'))
)
else: else:
over_the = "over the u'\\u062a\\u062a'\n" over_the = "over the u'\\u062a\\u062a'\n"
self.assertEqual([ # We write to the CaptureStream so no decoding is performed
raw_bytes = 'some raw bytes: \xd8\xaa\xd8\xaa'
self.assertEqual(''.join([
'one-argument\n', 'one-argument\n',
'one fish, 88 fish\n', 'one fish, 88 fish\n',
'some\n', 'where\n', over_the, 'some\n', 'where\n', over_the, raw_bytes
], list(out_stream.readlines())) ]), out_stream.getvalue())
err_stream.seek(0)
first_item = u'I have 99 problems, but a \u062A\u062A is not one\n' first_item = u'I have 99 problems, but a \u062A\u062A is not one\n'
if six.PY2: if six.PY2:
first_item = first_item.encode('utf8') first_item = first_item.encode('utf8')
self.assertEqual([ self.assertEqual(''.join([
first_item, first_item,
'one-error-argument\n', 'one-error-argument\n',
'Sometimes\n', '3.1% just\n', 'does not\n', 'work!\n', 'Sometimes\n', '3.1% just\n', 'does not\n', 'work!\n'
], list(err_stream.readlines())) ]), err_stream.getvalue())
self.assertEqual(3, thread_manager.error_count) self.assertEqual(3, thread_manager.error_count)

@ -19,6 +19,7 @@ import testtools
from hashlib import md5 from hashlib import md5
from mock import Mock, PropertyMock from mock import Mock, PropertyMock
from six.moves.queue import Queue, Empty as QueueEmptyError from six.moves.queue import Queue, Empty as QueueEmptyError
from six import BytesIO
import swiftclient import swiftclient
from swiftclient.service import SwiftService, SwiftError from swiftclient.service import SwiftService, SwiftError
@ -101,41 +102,35 @@ class TestSwiftReader(testtools.TestCase):
self.assertRaises(SwiftError, self.sr, 'path', 'body', self.assertRaises(SwiftError, self.sr, 'path', 'body',
{'content-length': 'notanint'}) {'content-length': 'notanint'})
def test_context_usage(self): def test_iterator_usage(self):
def _context(sr): def _consume(sr):
with sr: for _ in sr:
pass pass
sr = self.sr('path', 'body', {}) sr = self.sr('path', BytesIO(b'body'), {})
_context(sr) _consume(sr)
# Check error is raised if expected etag doesnt match calculated md5. # Check error is raised if expected etag doesnt match calculated md5.
# md5 for a SwiftReader that has done nothing is # md5 for a SwiftReader that has done nothing is
# d41d8cd98f00b204e9800998ecf8427e i.e md5 of nothing # d41d8cd98f00b204e9800998ecf8427e i.e md5 of nothing
sr = self.sr('path', 'body', {'etag': 'doesntmatch'}) sr = self.sr('path', BytesIO(b'body'), {'etag': 'doesntmatch'})
self.assertRaises(SwiftError, _context, sr) self.assertRaises(SwiftError, _consume, sr)
sr = self.sr('path', 'body', sr = self.sr('path', BytesIO(b'body'),
{'etag': 'd41d8cd98f00b204e9800998ecf8427e'}) {'etag': '841a2d689ad86bd1611447453c22c6fc'})
_context(sr) _consume(sr)
# Check error is raised if SwiftReader doesnt read the same length # Check error is raised if SwiftReader doesnt read the same length
# as the content length it is created with # as the content length it is created with
sr = self.sr('path', 'body', {'content-length': 5}) sr = self.sr('path', BytesIO(b'body'), {'content-length': 5})
self.assertRaises(SwiftError, _context, sr) self.assertRaises(SwiftError, _consume, sr)
sr = self.sr('path', 'body', {'content-length': 5}) sr = self.sr('path', BytesIO(b'body'), {'content-length': 4})
sr._actual_read = 5 _consume(sr)
_context(sr)
def test_buffer(self):
# md5 = 97ac82a5b825239e782d0339e2d7b910
mock_buffer_content = ['abc'.encode()] * 3
sr = self.sr('path', mock_buffer_content, {})
for x in sr.buffer():
pass
# Check that the iterator generates expected length and etag values
sr = self.sr('path', ['abc'.encode()] * 3, {})
_consume(sr)
self.assertEqual(sr._actual_read, 9) self.assertEqual(sr._actual_read, 9)
self.assertEqual(sr._actual_md5.hexdigest(), self.assertEqual(sr._actual_md5.hexdigest(),
'97ac82a5b825239e782d0339e2d7b910') '97ac82a5b825239e782d0339e2d7b910')

@ -290,10 +290,15 @@ class TestShell(unittest.TestCase):
@mock.patch('swiftclient.service.makedirs') @mock.patch('swiftclient.service.makedirs')
@mock.patch('swiftclient.service.Connection') @mock.patch('swiftclient.service.Connection')
def test_download(self, connection, makedirs): def test_download(self, connection, makedirs):
connection.return_value.get_object.return_value = [ objcontent = six.BytesIO(b'objcontent')
{'content-type': 'text/plain', connection.return_value.get_object.side_effect = [
({'content-type': 'text/plain',
'etag': '2cbbfe139a744d6abbe695e17f3c1991'},
objcontent),
({'content-type': 'text/plain',
'etag': 'd41d8cd98f00b204e9800998ecf8427e'}, 'etag': 'd41d8cd98f00b204e9800998ecf8427e'},
''] '')
]
# Test downloading whole container # Test downloading whole container
connection.return_value.get_container.side_effect = [ connection.return_value.get_container.side_effect = [
@ -318,6 +323,12 @@ class TestShell(unittest.TestCase):
mock_open.assert_called_once_with('object', 'wb') mock_open.assert_called_once_with('object', 'wb')
# Test downloading single object # Test downloading single object
objcontent = six.BytesIO(b'objcontent')
connection.return_value.get_object.side_effect = [
({'content-type': 'text/plain',
'etag': '2cbbfe139a744d6abbe695e17f3c1991'},
objcontent)
]
with mock.patch(BUILTIN_OPEN) as mock_open: with mock.patch(BUILTIN_OPEN) as mock_open:
argv = ["", "download", "container", "object"] argv = ["", "download", "container", "object"]
swiftclient.shell.main(argv) swiftclient.shell.main(argv)
@ -326,6 +337,18 @@ class TestShell(unittest.TestCase):
response_dict={}) response_dict={})
mock_open.assert_called_with('object', 'wb') mock_open.assert_called_with('object', 'wb')
# Test downloading single object to stdout
objcontent = six.BytesIO(b'objcontent')
connection.return_value.get_object.side_effect = [
({'content-type': 'text/plain',
'etag': '2cbbfe139a744d6abbe695e17f3c1991'},
objcontent)
]
with CaptureOutput() as output:
argv = ["", "download", "--output", "-", "container", "object"]
swiftclient.shell.main(argv)
self.assertEqual('objcontent', output.out)
@mock.patch('swiftclient.service.Connection') @mock.patch('swiftclient.service.Connection')
def test_download_no_content_type(self, connection): def test_download_no_content_type(self, connection):
connection.return_value.get_object.return_value = [ connection.return_value.get_object.return_value = [

@ -340,13 +340,41 @@ class MockHttpTest(testtools.TestCase):
reload_module(c) reload_module(c)
class CaptureStreamBuffer(object):
"""
CaptureStreamBuffer is used for testing raw byte writing for PY3. Anything
written here is decoded as utf-8 and written to the parent CaptureStream
"""
def __init__(self, captured_stream):
self._captured_stream = captured_stream
def write(self, bytes_data):
# No encoding, just convert the raw bytes into a str for testing
# The below call also validates that we have a byte string.
self._captured_stream.write(
''.join(map(chr, bytes_data))
)
class CaptureStream(object): class CaptureStream(object):
def __init__(self, stream): def __init__(self, stream):
self.stream = stream self.stream = stream
self._capture = six.StringIO() self._capture = six.StringIO()
self._buffer = CaptureStreamBuffer(self)
self.streams = [self.stream, self._capture] self.streams = [self.stream, self._capture]
@property
def buffer(self):
if six.PY3:
return self._buffer
else:
raise AttributeError(
'Output stream has no attribute "buffer" in Python2')
def flush(self):
pass
def write(self, *args, **kwargs): def write(self, *args, **kwargs):
for stream in self.streams: for stream in self.streams:
stream.write(*args, **kwargs) stream.write(*args, **kwargs)