Refactor some file-like iters as utils.InputProxy subclasses

There are a few places where bespoke file-like wrapper classes have
been implemented. Their common methods are now inherited from
utils.InputProxy.

Make utils.FileLikeIter tolerate size=None to mean the same as size=-1
so that it is consistent with the behavior of other input streams.

Fix docstrings in FileLikeIter.

Depends-On: https://review.opendev.org/c/openstack/requirements/+/942845
Change-Id: I20741ab58b0933390dc4679c3e6b2d888857d577
Author: Alistair Coles
Date: 2025-01-23 12:17:44 +00:00
Committed by: Tim Burke
Parent: 7140633925
Commit: e4cc228ed0
6 changed files with 304 additions and 90 deletions
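For readers unfamiliar with the pattern: a wrapper now only needs to subclass
utils.InputProxy and override its chunk_update() hook; read(), readline(),
byte accounting (bytes_received) and client-disconnect tracking come from the
base class, as the utils.py hunk below shows. The sketch that follows is
illustrative only -- MD5Input is a hypothetical name, not part of this change
-- and assumes InputProxy is importable from swift.common.utils as the diffs
below indicate.

    import hashlib

    from swift.common.utils import InputProxy


    class MD5Input(InputProxy):
        """Hypothetical wrapper that hashes the request body as it is read."""

        def __init__(self, wsgi_input):
            super().__init__(wsgi_input)
            self.md5 = hashlib.md5()

        def chunk_update(self, chunk, eof, *args, **kwargs):
            # called by InputProxy.read()/readline() for every chunk read;
            # return the (possibly modified) chunk to the original caller
            self.md5.update(chunk)
            return chunk

As a separate convenience, utils.FileLikeIter(...).read(None) and
readline(None) now behave the same as read(-1) and readline(-1).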


@@ -27,7 +27,7 @@ from swift.common.request_helpers import get_object_transient_sysmeta, \
 from swift.common.swob import Request, Match, HTTPException, \
     HTTPUnprocessableEntity, wsgi_to_bytes, bytes_to_wsgi, normalize_etag
 from swift.common.utils import get_logger, config_true_value, \
-    MD5_OF_EMPTY_STRING, md5
+    MD5_OF_EMPTY_STRING, md5, InputProxy


 def encrypt_header_val(crypto, value, key):
@@ -66,11 +66,11 @@ def _hmac_etag(key, etag):
     return base64.b64encode(result).decode()


-class EncInputWrapper(object):
+class EncInputWrapper(InputProxy):
     """File-like object to be swapped in for wsgi.input."""
     def __init__(self, crypto, keys, req, logger):
+        super().__init__(req.environ['wsgi.input'])
         self.env = req.environ
-        self.wsgi_input = req.environ['wsgi.input']
         self.path = req.path
         self.crypto = crypto
         self.body_crypto_ctxt = None
@@ -180,15 +180,7 @@ class EncInputWrapper(object):
         req.environ['swift.callback.update_footers'] = footers_callback

-    def read(self, *args, **kwargs):
-        return self.readChunk(self.wsgi_input.read, *args, **kwargs)
-
-    def readline(self, *args, **kwargs):
-        return self.readChunk(self.wsgi_input.readline, *args, **kwargs)
-
-    def readChunk(self, read_method, *args, **kwargs):
-        chunk = read_method(*args, **kwargs)
+    def chunk_update(self, chunk, eof, *args, **kwargs):
         if chunk:
             self._init_encryption_context()
             self.plaintext_md5.update(chunk)


@@ -134,7 +134,7 @@ from swift.common.digest import get_allowed_digests, \
     extract_digest_and_algorithm, DEFAULT_ALLOWED_DIGESTS
 from swift.common.utils import streq_const_time, parse_content_disposition, \
     parse_mime_headers, iter_multipart_mime_documents, reiterate, \
-    closing_if_possible, get_logger
+    closing_if_possible, get_logger, InputProxy
 from swift.common.registry import register_swift_info
 from swift.common.wsgi import WSGIContext, make_pre_authed_env
 from swift.common.swob import HTTPUnauthorized, wsgi_to_str, str_to_wsgi
@@ -158,7 +158,7 @@ class FormUnauthorized(Exception):
     pass


-class _CappedFileLikeObject(object):
+class _CappedFileLikeObject(InputProxy):
     """
     A file-like object wrapping another file-like object that raises
     an EOFError if the amount of data read exceeds a given
@@ -170,26 +170,15 @@ class _CappedFileLikeObject(object):
     """

     def __init__(self, fp, max_file_size):
-        self.fp = fp
+        super().__init__(fp)
         self.max_file_size = max_file_size
-        self.amount_read = 0
         self.file_size_exceeded = False

-    def read(self, size=None):
-        ret = self.fp.read(size)
-        self.amount_read += len(ret)
-        if self.amount_read > self.max_file_size:
+    def chunk_update(self, chunk, eof, *args, **kwargs):
+        if self.bytes_received > self.max_file_size:
             self.file_size_exceeded = True
             raise EOFError('max_file_size exceeded')
-        return ret
-
-    def readline(self):
-        ret = self.fp.readline()
-        self.amount_read += len(ret)
-        if self.amount_read > self.max_file_size:
-            self.file_size_exceeded = True
-            raise EOFError('max_file_size exceeded')
-        return ret
+        return chunk


 class FormPost(object):


@@ -24,8 +24,8 @@ import re
 from urllib.parse import quote, unquote, parse_qsl
 import string

-from swift.common.utils import split_path, json, close_if_possible, md5, \
-    streq_const_time, get_policy_index
+from swift.common.utils import split_path, json, md5, streq_const_time, \
+    get_policy_index, InputProxy
 from swift.common.registry import get_swift_info
 from swift.common import swob
 from swift.common.http import HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED, \
@@ -133,41 +133,42 @@ class S3InputSHA256Mismatch(BaseException):
         self.computed = computed


-class HashingInput(object):
+class HashingInput(InputProxy):
     """
     wsgi.input wrapper to verify the SHA256 of the input as it's read.
     """
-    def __init__(self, reader, content_length, expected_hex_hash):
-        self._input = reader
-        self._to_read = content_length
+    def __init__(self, wsgi_input, content_length, expected_hex_hash):
+        super().__init__(wsgi_input)
+        self._expected_length = content_length
         self._hasher = sha256()
-        self._expected = expected_hex_hash
+        self._expected_hash = expected_hex_hash
         if content_length == 0 and \
-                self._hasher.hexdigest() != self._expected.lower():
+                self._hasher.hexdigest() != self._expected_hash.lower():
             self.close()
             raise XAmzContentSHA256Mismatch(
-                client_computed_content_s_h_a256=self._expected,
+                client_computed_content_s_h_a256=self._expected_hash,
                 s3_computed_content_s_h_a256=self._hasher.hexdigest(),
             )

-    def read(self, size=None):
-        chunk = self._input.read(size)
+    def chunk_update(self, chunk, eof, *args, **kwargs):
         self._hasher.update(chunk)
-        self._to_read -= len(chunk)
-        short_read = bool(chunk) if size is None else (len(chunk) < size)
-        if self._to_read < 0 or (short_read and self._to_read) or (
-                self._to_read == 0 and
-                self._hasher.hexdigest() != self._expected.lower()):
+        if self.bytes_received < self._expected_length:
+            error = eof
+        elif self.bytes_received == self._expected_length:
+            error = self._hasher.hexdigest() != self._expected_hash.lower()
+        else:
+            error = True
+        if error:
             self.close()
             # Since we don't return the last chunk, the PUT never completes
             raise S3InputSHA256Mismatch(
-                self._expected,
+                self._expected_hash,
                 self._hasher.hexdigest())
         return chunk

-    def close(self):
-        close_if_possible(self._input)
-

 class SigV4Mixin(object):
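To make the new three-way check in chunk_update() concrete, here is a rough
usage sketch. It assumes HashingInput and S3InputSHA256Mismatch are importable
from swift.common.middleware.s3api.s3request, as the tests further down
suggest: a body shorter than the declared length now fails on the final (eof)
chunk, while a matching length and digest passes.

    import hashlib
    from io import BytesIO

    from swift.common.middleware.s3api.s3request import (
        HashingInput, S3InputSHA256Mismatch)

    raw = b'123456789'

    # Length and hash both match: the whole body is returned.
    ok = HashingInput(BytesIO(raw), len(raw), hashlib.sha256(raw).hexdigest())
    assert ok.read() == raw

    # Declared length is longer than the body: the final chunk arrives with
    # eof=True while bytes_received < expected length, so chunk_update raises.
    short = HashingInput(
        BytesIO(raw), len(raw) + 1, hashlib.sha256(raw).hexdigest())
    try:
        short.read()
    except S3InputSHA256Mismatch:
        pass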


@@ -489,7 +489,9 @@ class FileLikeIter(object):
     def __next__(self):
         """
-        next(x) -> the next value, or raise StopIteration
+        :raise StopIteration: if there are no more values to iterate.
+        :raise ValueError: if the close() method has been called.
+        :return: the next value.
         """
         if self.closed:
             raise ValueError('I/O operation on closed file')
@@ -502,12 +504,14 @@ class FileLikeIter(object):
     def read(self, size=-1):
         """
-        read([size]) -> read at most size bytes, returned as a bytes string.
-
-        If the size argument is negative or omitted, read until EOF is reached.
-        Notice that when in non-blocking mode, less data than what was
-        requested may be returned, even if no size parameter was given.
+        :param size: (optional) the maximum number of bytes to read. The
+            default value of ``-1`` means 'unlimited' i.e. read until the
+            wrapped iterable is exhausted.
+        :raise ValueError: if the close() method has been called.
+        :return: a bytes literal; if the wrapped iterable has been exhausted
+            then a zero-length bytes literal is returned.
         """
+        size = -1 if size is None else size
         if self.closed:
             raise ValueError('I/O operation on closed file')
         if size < 0:
@@ -529,12 +533,17 @@ class FileLikeIter(object):
     def readline(self, size=-1):
         """
-        readline([size]) -> next line from the file, as a bytes string.
+        Read the next line.

-        Retain newline. A non-negative size argument limits the maximum
-        number of bytes to return (an incomplete line may be returned then).
-        Return an empty string at EOF.
+        :param size: (optional) the maximum number of bytes of the next line
+            to read. The default value of ``-1`` means 'unlimited' i.e. read
+            to the end of the line or until the wrapped iterable is
+            exhausted, whichever is first.
+        :raise ValueError: if the close() method has been called.
+        :return: a bytes literal; if the wrapped iterable has been exhausted
+            then a zero-length bytes literal is returned.
         """
+        size = -1 if size is None else size
         if self.closed:
             raise ValueError('I/O operation on closed file')
         data = b''
@@ -557,12 +566,16 @@ class FileLikeIter(object):
     def readlines(self, sizehint=-1):
         """
-        readlines([size]) -> list of bytes strings, each a line from the file.
-
         Call readline() repeatedly and return a list of the lines so read.
-        The optional size argument, if given, is an approximate bound on the
-        total number of bytes in the lines returned.
+
+        :param sizehint: (optional) an approximate bound on the total number
+            of bytes in the lines returned. Lines are read until ``sizehint``
+            has been exceeded but complete lines are always returned, so the
+            total bytes read may exceed ``sizehint``.
+        :raise ValueError: if the close() method has been called.
+        :return: a list of bytes literals, each a line from the file.
         """
+        sizehint = -1 if sizehint is None else sizehint
         if self.closed:
             raise ValueError('I/O operation on closed file')
         lines = []
@@ -579,12 +592,10 @@ class FileLikeIter(object):
     def close(self):
         """
-        close() -> None or (perhaps) an integer.  Close the file.
+        Close the iter.

-        Sets data attribute .closed to True. A closed file cannot be used for
-        further I/O operations. close() may be called more than once without
-        error. Some kinds of file objects (for example, opened by popen())
-        may return an exit status upon closing.
+        Once close() has been called the iter cannot be used for further I/O
+        operations. close() may be called more than once without error.
         """
         self.iterator = None
         self.closed = True
@@ -2515,41 +2526,79 @@ class InputProxy(object):
     """
     File-like object that counts bytes read.
     To be swapped in for wsgi.input for accounting purposes.
+
+    :param wsgi_input: file-like object to be wrapped
     """

     def __init__(self, wsgi_input):
-        """
-        :param wsgi_input: file-like object to wrap the functionality of
-        """
         self.wsgi_input = wsgi_input
+        #: total number of bytes read from the wrapped input
         self.bytes_received = 0
+        #: ``True`` if an exception is raised by ``read()`` or ``readline()``,
+        #: ``False`` otherwise
         self.client_disconnect = False

-    def read(self, *args, **kwargs):
+    def chunk_update(self, chunk, eof, *args, **kwargs):
+        """
+        Called each time a chunk of bytes is read from the wrapped input.
+
+        :param chunk: the chunk of bytes that has been read.
+        :param eof: ``True`` if there are no more bytes to read from the
+            wrapped input, ``False`` otherwise. If ``read()`` has been called
+            this will be ``True`` when the size of ``chunk`` is less than the
+            requested size or the requested size is None. If ``readline`` has
+            been called this will be ``True`` when an incomplete line is read
+            (i.e. not ending with ``b'\\n'``) whose length is less than the
+            requested size or the requested size is None. If ``read()`` or
+            ``readline()`` are called with a requested size that exactly
+            matches the number of bytes remaining in the wrapped input then
+            ``eof`` will be ``False``. A subsequent call to ``read()`` or
+            ``readline()`` with non-zero ``size`` would result in ``eof``
+            being ``True``. Alternatively, the end of the input could be
+            inferred by comparing ``bytes_received`` with the expected length
+            of the input.
+        """
+        # subclasses may override this method; either the given chunk or an
+        # alternative chunk value should be returned
+        return chunk
+
+    def read(self, size=None, *args, **kwargs):
         """
         Pass read request to the underlying file-like object and
         add bytes read to total.
+
+        :param size: (optional) maximum number of bytes to read; the default
+            ``None`` means unlimited.
         """
         try:
-            chunk = self.wsgi_input.read(*args, **kwargs)
+            chunk = self.wsgi_input.read(size, *args, **kwargs)
         except Exception:
             self.client_disconnect = True
             raise
         self.bytes_received += len(chunk)
-        return chunk
+        eof = size is None or size < 0 or len(chunk) < size
+        return self.chunk_update(chunk, eof)

-    def readline(self, *args, **kwargs):
+    def readline(self, size=None, *args, **kwargs):
         """
         Pass readline request to the underlying file-like object and
         add bytes read to total.
+
+        :param size: (optional) maximum number of bytes to read from the
+            current line; the default ``None`` means unlimited.
         """
         try:
-            line = self.wsgi_input.readline(*args, **kwargs)
+            line = self.wsgi_input.readline(size, *args, **kwargs)
         except Exception:
             self.client_disconnect = True
             raise
         self.bytes_received += len(line)
-        return line
+        eof = ((size is None or size < 0 or len(line) < size)
+               and (line[-1:] != b'\n'))
+        return self.chunk_update(line, eof)
+
+    def close(self):
+        close_if_possible(self.wsgi_input)


 class LRUCache(object):
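The eof semantics described in the chunk_update() docstring can be seen with a
tiny throwaway subclass. This is illustrative only; EchoProxy is a
hypothetical name and not part of the change. Reading exactly the number of
remaining bytes reports eof=False, and only a subsequent read reports
eof=True.

    import io

    from swift.common.utils import InputProxy


    class EchoProxy(InputProxy):
        """Hypothetical subclass that records the (chunk, eof) pairs."""

        def __init__(self, wsgi_input):
            super().__init__(wsgi_input)
            self.calls = []

        def chunk_update(self, chunk, eof, *args, **kwargs):
            self.calls.append((chunk, eof))
            return chunk


    p = EchoProxy(io.BytesIO(b'abc'))
    p.read(3)   # exact size: eof is False even though the input is drained
    p.read(1)   # empty chunk, shorter than requested: eof is True
    assert p.calls == [(b'abc', False), (b'', True)]
    assert p.bytes_received == 3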


@@ -1468,9 +1468,17 @@ class TestHashingInput(S3ApiTestCase):
         # can continue trying to read -- but it'll be empty
         self.assertEqual(b'', wrapped.read(2))

-        self.assertFalse(wrapped._input.closed)
+        self.assertFalse(wrapped.wsgi_input.closed)
         wrapped.close()
-        self.assertTrue(wrapped._input.closed)
+        self.assertTrue(wrapped.wsgi_input.closed)
+
+    def test_good_readline(self):
+        raw = b'12345\n6789'
+        wrapped = HashingInput(
+            BytesIO(raw), 10, hashlib.sha256(raw).hexdigest())
+        self.assertEqual(b'12345\n', wrapped.readline())
+        self.assertEqual(b'6789', wrapped.readline())
+        self.assertEqual(b'', wrapped.readline())

     def test_empty(self):
         wrapped = HashingInput(
@@ -1478,9 +1486,9 @@ class TestHashingInput(S3ApiTestCase):
         self.assertEqual(b'', wrapped.read(4))
         self.assertEqual(b'', wrapped.read(2))

-        self.assertFalse(wrapped._input.closed)
+        self.assertFalse(wrapped.wsgi_input.closed)
         wrapped.close()
-        self.assertTrue(wrapped._input.closed)
+        self.assertTrue(wrapped.wsgi_input.closed)

     def test_too_long(self):
         raw = b'123456789'
@@ -1495,18 +1503,26 @@ class TestHashingInput(S3ApiTestCase):
             # won't get caught by most things in a pipeline
             self.assertNotIsInstance(raised.exception, Exception)
         # the error causes us to close the input
-        self.assertTrue(wrapped._input.closed)
+        self.assertTrue(wrapped.wsgi_input.closed)

-    def test_too_short(self):
+    def test_too_short_read_piecemeal(self):
         raw = b'123456789'
         wrapped = HashingInput(
             BytesIO(raw), 10, hashlib.sha256(raw).hexdigest())
         self.assertEqual(b'1234', wrapped.read(4))
-        self.assertEqual(b'56', wrapped.read(2))
-        # even though the hash matches, there was more data than we expected
+        self.assertEqual(b'56789', wrapped.read(5))
+        # even though the hash matches, there was less data than we expected
         with self.assertRaises(S3InputSHA256Mismatch):
-            wrapped.read(4)
-        self.assertTrue(wrapped._input.closed)
+            wrapped.read(1)
+        self.assertTrue(wrapped.wsgi_input.closed)
+
+    def test_too_short_read_all(self):
+        raw = b'123456789'
+        wrapped = HashingInput(
+            BytesIO(raw), 10, hashlib.sha256(raw).hexdigest())
+        with self.assertRaises(S3InputSHA256Mismatch):
+            wrapped.read()
+        self.assertTrue(wrapped.wsgi_input.closed)

     def test_bad_hash(self):
         raw = b'123456789'
@@ -1516,7 +1532,7 @@ class TestHashingInput(S3ApiTestCase):
         self.assertEqual(b'5678', wrapped.read(4))
         with self.assertRaises(S3InputSHA256Mismatch):
             wrapped.read(4)
-        self.assertTrue(wrapped._input.closed)
+        self.assertTrue(wrapped.wsgi_input.closed)

     def test_empty_bad_hash(self):
         _input = BytesIO(b'')
@@ -1526,6 +1542,14 @@ class TestHashingInput(S3ApiTestCase):
             HashingInput(_input, 0, 'nope')
         self.assertTrue(_input.closed)

+    def test_bad_hash_readline(self):
+        raw = b'12345\n6789'
+        wrapped = HashingInput(
+            BytesIO(raw), 10, hashlib.sha256(raw[:-3]).hexdigest())
+        self.assertEqual(b'12345\n', wrapped.readline())
+        with self.assertRaises(S3InputSHA256Mismatch):
+            self.assertEqual(b'6789', wrapped.readline())
+

 if __name__ == '__main__':
     unittest.main()


@@ -18,6 +18,7 @@ from __future__ import print_function
 import argparse
 import hashlib
+import io
 import itertools

 from swift.common.statsd_client import StatsdClient
@@ -2961,8 +2962,16 @@ class TestFileLikeIter(unittest.TestCase):
     def test_read(self):
         in_iter = [b'abc', b'de', b'fghijk', b'l']
-        iter_file = utils.FileLikeIter(in_iter)
-        self.assertEqual(iter_file.read(), b''.join(in_iter))
+        expected = b''.join(in_iter)
+        self.assertEqual(utils.FileLikeIter(in_iter).read(), expected)
+        self.assertEqual(utils.FileLikeIter(in_iter).read(-1), expected)
+        self.assertEqual(utils.FileLikeIter(in_iter).read(None), expected)
+
+    def test_read_empty(self):
+        in_iter = [b'abc']
+        ip = utils.FileLikeIter(in_iter)
+        self.assertEqual(b'abc', ip.read())
+        self.assertEqual(b'', ip.read())

     def test_read_with_size(self):
         in_iter = [b'abc', b'de', b'fghijk', b'l']
@@ -2995,6 +3004,15 @@ class TestFileLikeIter(unittest.TestCase):
             [v if v == b'trailing.' else v + b'\n'
              for v in b''.join(in_iter).split(b'\n')])

+    def test_readline_size_unlimited(self):
+        in_iter = [b'abc', b'd\nef']
+        self.assertEqual(
+            utils.FileLikeIter(in_iter).readline(-1),
+            b'abcd\n')
+        self.assertEqual(
+            utils.FileLikeIter(in_iter).readline(None),
+            b'abcd\n')
+
     def test_readline2(self):
         self.assertEqual(
             utils.FileLikeIter([b'abc', b'def\n']).readline(4),
@@ -3029,6 +3047,16 @@ class TestFileLikeIter(unittest.TestCase):
             lines,
             [v if v == b'trailing.' else v + b'\n'
              for v in b''.join(in_iter).split(b'\n')])
+        lines = utils.FileLikeIter(in_iter).readlines(sizehint=-1)
+        self.assertEqual(
+            lines,
+            [v if v == b'trailing.' else v + b'\n'
+             for v in b''.join(in_iter).split(b'\n')])
+        lines = utils.FileLikeIter(in_iter).readlines(sizehint=None)
+        self.assertEqual(
+            lines,
+            [v if v == b'trailing.' else v + b'\n'
+             for v in b''.join(in_iter).split(b'\n')])

     def test_readlines_with_size(self):
         in_iter = [b'abc\n', b'd', b'\nef', b'g\nh', b'\nij\n\nk\n',
@@ -3092,6 +3120,137 @@
         self.assertEqual(utils.get_hub(), 'selects')


+class TestInputProxy(unittest.TestCase):
+    def test_read_all(self):
+        self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(), b'abc')
+        self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(-1), b'abc')
+        self.assertEqual(
+            utils.InputProxy(io.BytesIO(b'abc')).read(None), b'abc')
+
+    def test_read_size(self):
+        self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(0), b'')
+        self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(2), b'ab')
+        self.assertEqual(utils.InputProxy(io.BytesIO(b'abc')).read(4), b'abc')
+
+    def test_readline(self):
+        ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
+        self.assertEqual(ip.readline(), b'ab\n')
+        self.assertFalse(ip.client_disconnect)
+
+    def test_bytes_received(self):
+        ip = utils.InputProxy(io.BytesIO(b'ab\ncdef'))
+        ip.readline()
+        self.assertEqual(3, ip.bytes_received)
+        ip.read(2)
+        self.assertEqual(5, ip.bytes_received)
+        ip.read(99)
+        self.assertEqual(7, ip.bytes_received)
+
+    def test_close(self):
+        utils.InputProxy(object()).close()  # safe
+        fake = mock.MagicMock()
+        fake.close = mock.MagicMock()
+        ip = utils.InputProxy(fake)
+        ip.close()
+        self.assertEqual([mock.call()], fake.close.call_args_list)
+        self.assertFalse(ip.client_disconnect)
+
+    def test_read_piecemeal_chunk_update(self):
+        ip = utils.InputProxy(io.BytesIO(b'abc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.read(1)
+            ip.read(2)
+            ip.read(1)
+            ip.read(1)
+        self.assertEqual([mock.call(b'a', False),
+                          mock.call(b'bc', False),
+                          mock.call(b'', True),
+                          mock.call(b'', True)], mocked.call_args_list)
+
+    def test_read_unlimited_chunk_update(self):
+        ip = utils.InputProxy(io.BytesIO(b'abc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.read()
+            ip.read()
+        self.assertEqual([mock.call(b'abc', True),
+                          mock.call(b'', True)], mocked.call_args_list)
+
+        ip = utils.InputProxy(io.BytesIO(b'abc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.read(None)
+            ip.read(None)
+        self.assertEqual([mock.call(b'abc', True),
+                          mock.call(b'', True)], mocked.call_args_list)
+
+        ip = utils.InputProxy(io.BytesIO(b'abc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.read(-1)
+            ip.read(-1)
+        self.assertEqual([mock.call(b'abc', True),
+                          mock.call(b'', True)], mocked.call_args_list)
+
+    def test_readline_piecemeal_chunk_update(self):
+        ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.readline(3)
+            ip.readline(1)  # read to exact length
+            ip.readline(1)
+        self.assertEqual([mock.call(b'ab\n', False),
+                          mock.call(b'c', False),
+                          mock.call(b'', True)], mocked.call_args_list)
+
+        ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.readline(3)
+            ip.readline(2)  # read beyond exact length
+            ip.readline(1)
+        self.assertEqual([mock.call(b'ab\n', False),
+                          mock.call(b'c', True),
+                          mock.call(b'', True)], mocked.call_args_list)
+
+    def test_readline_unlimited_chunk_update(self):
+        ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.readline()
+            ip.readline()
+        self.assertEqual([mock.call(b'ab\n', False),
+                          mock.call(b'c', True)], mocked.call_args_list)
+
+        ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.readline(None)
+            ip.readline(None)
+        self.assertEqual([mock.call(b'ab\n', False),
+                          mock.call(b'c', True)], mocked.call_args_list)
+
+        ip = utils.InputProxy(io.BytesIO(b'ab\nc'))
+        with mock.patch.object(ip, 'chunk_update') as mocked:
+            ip.readline(-1)
+            ip.readline(-1)
+        self.assertEqual([mock.call(b'ab\n', False),
+                          mock.call(b'c', True)], mocked.call_args_list)
+
+    def test_chunk_update_modifies_chunk(self):
+        ip = utils.InputProxy(io.BytesIO(b'abc'))
+        with mock.patch.object(ip, 'chunk_update', return_value='modified'):
+            actual = ip.read()
+        self.assertEqual('modified', actual)
+
+    def test_read_client_disconnect(self):
+        fake = mock.MagicMock()
+        fake.read = mock.MagicMock(side_effect=ValueError('boom'))
+        ip = utils.InputProxy(fake)
+        with self.assertRaises(ValueError) as cm:
+            ip.read()
+        self.assertTrue(ip.client_disconnect)
+        self.assertEqual('boom', str(cm.exception))
+
+    def test_readline_client_disconnect(self):
+        fake = mock.MagicMock()
+        fake.readline = mock.MagicMock(side_effect=ValueError('boom'))
+        ip = utils.InputProxy(fake)
+        with self.assertRaises(ValueError) as cm:
+            ip.readline()
+        self.assertTrue(ip.client_disconnect)
+        self.assertEqual('boom', str(cm.exception))
+
+
 class UnsafeXrange(object):
     """
     Like range(limit), but with extra context switching to screw things up.