diff --git a/test/unit/proxy/controllers/test_base.py b/test/unit/proxy/controllers/test_base.py index a2bd055427..eeecdcfdbb 100644 --- a/test/unit/proxy/controllers/test_base.py +++ b/test/unit/proxy/controllers/test_base.py @@ -183,6 +183,39 @@ class FakeCache(FakeMemcache): return self.stub or self.store.get(key) +class TestSource(object): + def __init__(self, chunks, headers=None, body=b''): + self.chunks = list(chunks) + self.headers = headers or {} + self.status = 200 + self.swift_conn = None + self.body = body + + def read(self, _read_size): + if self.chunks: + chunk = self.chunks.pop(0) + if chunk is None: + raise exceptions.ChunkReadTimeout() + else: + return chunk + else: + return self.body + + def getheader(self, header): + # content-length for the whole object is generated dynamically + # by summing non-None chunks + if header.lower() == "content-length": + if self.chunks: + return str(sum(len(c) for c in self.chunks + if c is not None)) + return len(self.read(-1)) + return self.headers.get(header.lower()) + + def getheaders(self): + return [('content-length', self.getheader('content-length'))] + \ + [(k, v) for k, v in self.headers.items()] + + class BaseTest(unittest.TestCase): def setUp(self): @@ -1272,67 +1305,21 @@ class TestFuncs(BaseTest): self.assertEqual('', dst_headers['Referer']) def test_client_chunk_size(self): - - class TestSource(object): - def __init__(self, chunks): - self.chunks = list(chunks) - self.status = 200 - - def read(self, _read_size): - if self.chunks: - return self.chunks.pop(0) - else: - return b'' - - def getheader(self, header): - if header.lower() == "content-length": - return str(sum(len(c) for c in self.chunks)) - - def getheaders(self): - return [('content-length', self.getheader('content-length'))] - source = TestSource(( b'abcd', b'1234', b'abc', b'd1', b'234abcd1234abcd1', b'2')) req = Request.blank('/v1/a/c/o') - node = {} handler = GetOrHeadHandler( self.app, req, None, Namespace(num_primary_nodes=3), None, None, {}, client_chunk_size=8) - handler.source = source - handler.node = node - app_iter = handler._make_app_iter(req) - client_chunks = list(app_iter) + with mock.patch.object(handler, '_get_source_and_node', + return_value=(source, {})): + resp = handler.get_working_response(req) + client_chunks = list(resp.app_iter) self.assertEqual(client_chunks, [ b'abcd1234', b'abcd1234', b'abcd1234', b'abcd12']) def test_client_chunk_size_resuming(self): - - class TestSource(object): - def __init__(self, chunks): - self.chunks = list(chunks) - self.status = 200 - - def read(self, _read_size): - if self.chunks: - chunk = self.chunks.pop(0) - if chunk is None: - raise exceptions.ChunkReadTimeout() - else: - return chunk - else: - return b'' - - def getheader(self, header): - # content-length for the whole object is generated dynamically - # by summing non-None chunks initialized as source1 - if header.lower() == "content-length": - return str(sum(len(c) for c in self.chunks - if c is not None)) - - def getheaders(self): - return [('content-length', self.getheader('content-length'))] - node = {'ip': '1.2.3.4', 'port': 6200, 'device': 'sda'} source1 = TestSource([b'abcd', b'1234', None, @@ -1346,93 +1333,55 @@ class TestFuncs(BaseTest): None, {}, client_chunk_size=8) range_headers = [] - sources = [(source2, node), (source3, node)] + sources = [(source1, node), (source2, node), (source3, node)] def mock_get_source_and_node(): - range_headers.append(handler.backend_headers['Range']) + range_headers.append(handler.backend_headers.get('Range')) return sources.pop(0) - handler.source = source1 - handler.node = node - app_iter = handler._make_app_iter(req) with mock.patch.object(handler, '_get_source_and_node', - side_effect=mock_get_source_and_node): - client_chunks = list(app_iter) - self.assertEqual(range_headers, ['bytes=8-27', 'bytes=16-27']) + mock_get_source_and_node): + resp = handler.get_working_response(req) + client_chunks = list(resp.app_iter) + self.assertEqual(range_headers, [None, 'bytes=8-27', 'bytes=16-27']) self.assertEqual(client_chunks, [ b'abcd1234', b'efgh5678', b'lotsmore', b'data']) def test_client_chunk_size_resuming_chunked(self): - - class TestChunkedSource(object): - def __init__(self, chunks): - self.chunks = list(chunks) - self.status = 200 - self.headers = {'transfer-encoding': 'chunked', - 'content-type': 'text/plain'} - - def read(self, _read_size): - if self.chunks: - chunk = self.chunks.pop(0) - if chunk is None: - raise exceptions.ChunkReadTimeout() - else: - return chunk - else: - return b'' - - def getheader(self, header): - return self.headers.get(header.lower()) - - def getheaders(self): - return self.headers - node = {'ip': '1.2.3.4', 'port': 6200, 'device': 'sda'} - - source1 = TestChunkedSource([b'abcd', b'1234', b'abc', None]) - source2 = TestChunkedSource([b'efgh5678']) + headers = {'transfer-encoding': 'chunked', + 'content-type': 'text/plain'} + source1 = TestSource([b'abcd', b'1234', b'abc', None], headers=headers) + source2 = TestSource([b'efgh5678'], headers=headers) + sources = [(source1, node), (source2, node)] req = Request.blank('/v1/a/c/o') handler = GetOrHeadHandler( self.app, req, 'Object', Namespace(num_primary_nodes=1), None, None, {}, client_chunk_size=8) - handler.source = source1 - handler.node = node - app_iter = handler._make_app_iter(req) + def mock_get_source_and_node(): + return sources.pop(0) + with mock.patch.object(handler, '_get_source_and_node', - lambda: (source2, node)): - client_chunks = list(app_iter) + mock_get_source_and_node): + resp = handler.get_working_response(req) + client_chunks = list(resp.app_iter) self.assertEqual(client_chunks, [b'abcd1234', b'efgh5678']) def test_disconnected_logging(self): self.app.logger = mock.Mock() req = Request.blank('/v1/a/c/o') - - class TestSource(object): - def __init__(self): - self.headers = {'content-type': 'text/plain', - 'content-length': len(self.read(-1))} - self.status = 200 - - def read(self, _read_size): - return b'the cake is a lie' - - def getheader(self, header): - return self.headers.get(header.lower()) - - def getheaders(self): - return self.headers - - source = TestSource() + headers = {'content-type': 'text/plain'} + source = TestSource([], headers=headers, body=b'the cake is a lie') node = {'ip': '1.2.3.4', 'port': 6200, 'device': 'sda'} handler = GetOrHeadHandler( self.app, req, 'Object', Namespace(num_primary_nodes=1), None, 'some-path', {}) - handler.source = source - handler.node = node - app_iter = handler._make_app_iter(req) - app_iter.close() + with mock.patch.object(handler, '_get_source_and_node', + return_value=(source, node)): + resp = handler.get_working_response(req) + resp.app_iter.close() self.app.logger.info.assert_called_once_with( 'Client disconnected on read of %r', 'some-path') @@ -1441,11 +1390,12 @@ class TestFuncs(BaseTest): handler = GetOrHeadHandler( self.app, req, 'Object', Namespace(num_primary_nodes=1), None, None, {}) - handler.source = source - handler.node = node - app_iter = handler._make_app_iter(req) - next(app_iter) - app_iter.close() + + with mock.patch.object(handler, '_get_source_and_node', + return_value=(source, node)): + resp = handler.get_working_response(req) + next(resp.app_iter) + resp.app_iter.close() self.app.logger.warning.assert_not_called() def test_bytes_to_skip(self):