Add ClosingIterator class; be more explicit about closes

... in document_iters_to_http_response_body.

We seemed to be relying a little too heavily upon prompt garbage
collection to log client disconnects, leading to failures in
test_base.py::TestGetOrHeadHandler::test_disconnected_logging
under python 3.12.

Closes-Bug: #2046352
Co-Authored-By: Alistair Coles <alistairncoles@gmail.com>
Change-Id: I4479d2690f708312270eb92759789ddce7f7f930
This commit is contained in:
Tim Burke 2024-01-04 05:06:32 +00:00 committed by Alistair Coles
parent afe31b4c01
commit c522f5676e
4 changed files with 370 additions and 57 deletions

View File

@ -3687,27 +3687,66 @@ def csv_append(csv_string, item):
return item
class CloseableChain(object):
class ClosingIterator(object):
    """
    Iterator wrapper that closes what it wraps once iteration is over.

    The wrapped iterable, together with any extra closeables supplied, is
    closed (where possible) when fetching the next item raises - which
    includes normal exhaustion - or when ``close()`` is called explicitly.
    Since ``close()`` may be called before any item has been requested, this
    is a convenient way to make sure a generator releases its resources even
    if it was never started.

    Subclasses can customise how items are fetched by overriding
    ``_get_next_item``.

    :param iterable: the iterable to wrap and iterate over.
    :param other_closeables: optional additional resources to attempt to
        close whenever this iterator is closed.
    """
    __slots__ = ('closeables', 'wrapped_iter', 'closed')

    def __init__(self, iterable, other_closeables=None):
        # the wrapped iterable is always first in line to be closed
        self.closeables = [iterable] + list(other_closeables or ())
        # iter() usually, but not necessarily, returns ``iterable`` itself
        self.wrapped_iter = iter(iterable)
        self.closed = False

    def __iter__(self):
        return self

    def _get_next_item(self):
        # hook for subclasses: fetch one item from the wrapped iterator
        return next(self.wrapped_iter)

    def __next__(self):
        try:
            item = self._get_next_item()
        except Exception:
            # note: a wrapped generator has already exited because of the
            # exception (no GeneratorExit is raised inside it), but any
            # other closeables still have to be closed here.
            self.close()
            raise
        return item

    next = __next__  # py2

    def close(self):
        # closing is done at most once; later calls are no-ops
        if self.closed:
            return
        for closeable in self.closeables:
            close_if_possible(closeable)
        self.closed = True
class CloseableChain(ClosingIterator):
"""
Like itertools.chain, but with a close method that will attempt to invoke
its sub-iterators' close methods, if any.
"""
def __init__(self, *iterables):
self.iterables = iterables
self.chained_iter = itertools.chain(*self.iterables)
def __iter__(self):
return self
def __next__(self):
return next(self.chained_iter)
next = __next__ # py2
def close(self):
for it in self.iterables:
close_if_possible(it)
chained_iter = itertools.chain(*iterables)
super(CloseableChain, self).__init__(chained_iter, iterables)
def reiterate(iterable):
@ -4396,6 +4435,47 @@ def document_iters_to_multipart_byteranges(ranges_iter, boundary):
yield terminator
class StringAlong(ClosingIterator):
    """
    Iterate a first iterator to completion, then prod a second iterator that
    is expected to already be finished.

    "Stringing along" the second iterator like this delays its exit until
    the first iterator has stopped. That is useful when, for example, the
    second iterator has already yielded its item(s) but owns resources that
    must not be garbage collected while the first iterator is still in use.

    Once the first iterator stops, the second iterator is asked for one more
    item. It is expected to raise StopIteration; if it yields an item
    instead then ``unexpected_items_func`` is called.

    :param iterable: a first iterator that is wrapped and iterated.
    :param other_iter: a second iterator that is stopped once the first
        iterator has stopped.
    :param unexpected_items_func: a no-arg function that will be called if
        the second iterator is found to have remaining items.
    """
    __slots__ = ('other_iter', 'unexpected_items_func')

    def __init__(self, iterable, other_iter, unexpected_items_func):
        # other_iter is also registered as a closeable of the base class
        super(StringAlong, self).__init__(iterable, [other_iter])
        self.other_iter = other_iter
        self.unexpected_items_func = unexpected_items_func

    def _get_next_item(self):
        try:
            return super(StringAlong, self)._get_next_item()
        except StopIteration:
            # the first iterator is done: now let the second one finish
            try:
                has_leftovers = True
                try:
                    next(self.other_iter)
                except StopIteration:
                    has_leftovers = False
                if has_leftovers:
                    self.unexpected_items_func()
            finally:
                # always propagate the exception that is currently active
                raise
def document_iters_to_http_response_body(ranges_iter, boundary, multipart,
logger):
"""
@ -4445,20 +4525,11 @@ def document_iters_to_http_response_body(ranges_iter, boundary, multipart,
# ranges_iter has a finally block that calls close_swift_conn, and
# so if that finally block fires before we read response_body_iter,
# there's nothing there.
def string_along(useful_iter, useless_iter_iter, logger):
with closing_if_possible(useful_iter):
for x in useful_iter:
yield x
try:
next(useless_iter_iter)
except StopIteration:
pass
else:
logger.warning(
"More than one part in a single-part response?")
return string_along(response_body_iter, ranges_iter, logger)
result = StringAlong(
response_body_iter, ranges_iter,
lambda: logger.warning(
"More than one part in a single-part response?"))
return result
def multipart_byteranges_to_document_iters(input_file, boundary,
@ -6430,7 +6501,7 @@ class WatchdogTimeout(object):
self.watchdog.stop(self.key)
class CooperativeIterator(object):
class CooperativeIterator(ClosingIterator):
"""
Wrapper to make a deliberate periodic call to ``sleep()`` while iterating
over wrapped iterator, providing an opportunity to switch greenthreads.
@ -6452,24 +6523,16 @@ class CooperativeIterator(object):
:param period: number of items yielded from this iterator between calls to
``sleep()``.
"""
__slots__ = ('period', 'count', 'wrapped_iter')
__slots__ = ('period', 'count')
def __init__(self, iterable, period=5):
self.wrapped_iter = iterable
super(CooperativeIterator, self).__init__(iterable)
self.count = 0
self.period = period
def __iter__(self):
return self
def next(self):
def _get_next_item(self):
if self.count >= self.period:
self.count = 0
sleep()
self.count += 1
return next(self.wrapped_iter)
__next__ = next
def close(self):
close_if_possible(self.wrapped_iter)
return super(CooperativeIterator, self)._get_next_item()

View File

@ -53,7 +53,7 @@ from swift.common import storage_policy, swob, utils, exceptions
from swift.common.memcached import MemcacheConnectionError
from swift.common.storage_policy import (StoragePolicy, ECStoragePolicy,
VALID_EC_TYPES)
from swift.common.utils import Timestamp, md5
from swift.common.utils import Timestamp, md5, close_if_possible
from test import get_config
from test.debug_logger import FakeLogger
from swift.common.header_key_dict import HeaderKeyDict
@ -1499,6 +1499,70 @@ class FakeSource(object):
[(k, v) for k, v in self.headers.items()]
class CaptureIterator(object):
    """
    Iterator wrapper that forwards calls to the wrapped iterable and reports
    each call, by method name, through a callback.

    Instances may be used to observe garbage collection, so tests should not
    need to keep a reference to them (doing so would stop them being
    collected). That is why calls are reported via a callback instead of
    being recorded on the instance.

    :param wrapped: an iterable to wrap.
    :param call_capture_callback: a function that will be called with the
        name of each method invoked on this iterator.
    """

    def __init__(self, wrapped, call_capture_callback):
        self.wrapped_iter = wrapped
        self.call_capture_callback = call_capture_callback

    def _capture_call(self):
        # stack()[1] is the frame of whichever method called us; element 3
        # of that record is the calling function's name
        caller_record = inspect.stack()[1]
        self.call_capture_callback(caller_record[3])

    def __iter__(self):
        return self

    def next(self):
        self._capture_call()
        return next(self.wrapped_iter)

    # the function is named ``next`` so that is the name that gets captured
    __next__ = next

    def __del__(self):
        self._capture_call()

    def close(self):
        self._capture_call()
        close_if_possible(self.wrapped_iter)
class CaptureIteratorFactory(object):
    """
    Callable that wraps the result of ``wrapped`` in a ``CaptureIterator``
    and collects the calls reported by every iterator it creates.

    Captured calls are stored in ``captured_calls``, keyed by the 1-based
    number of the instance that reported them.

    :param wrapped: a callable returning the iterable to wrap.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.instance_count = 0
        self.captured_calls = defaultdict(list)

    def log_call(self, instance_number, call):
        calls_for_instance = self.captured_calls[instance_number]
        calls_for_instance.append(call)

    def __call__(self, *args, **kwargs):
        # note: the CaptureIterator is deliberately not stored anywhere on
        # the factory because a lingering reference would prevent it being
        # garbage collected
        self.instance_count += 1
        inner = self.wrapped(*args, **kwargs)
        callback = functools.partial(self.log_call, self.instance_count)
        return CaptureIterator(inner, callback)
def get_node_error_stats(proxy_app, ring_node):
    """
    Look up the error limiter stats recorded for a ring node.

    :param proxy_app: an object whose ``error_limiter`` provides
        ``node_key()`` and a ``stats`` mapping.
    :param ring_node: the node to look up.
    :returns: the stats stored under the node's key, or an empty dict if
        there is no entry (or the entry is empty).
    """
    limiter = proxy_app.error_limiter
    key = limiter.node_key(ring_node)
    stats = limiter.stats.get(key)
    return stats if stats else {}

View File

@ -6861,20 +6861,34 @@ class FakeResponse(object):
class TestDocumentItersToHTTPResponseBody(unittest.TestCase):
def test_no_parts(self):
logger = debug_logger()
body = utils.document_iters_to_http_response_body(
iter([]), 'dontcare',
multipart=False, logger=debug_logger())
iter([]), 'dontcare', multipart=False, logger=logger)
self.assertEqual(body, '')
self.assertFalse(logger.all_log_lines())
def test_single_part(self):
body = b"time flies like an arrow; fruit flies like a banana"
doc_iters = [{'part_iter': iter(BytesIO(body).read, b'')}]
logger = debug_logger()
resp_body = b''.join(
utils.document_iters_to_http_response_body(
iter(doc_iters), b'dontcare',
multipart=False, logger=debug_logger()))
iter(doc_iters), b'dontcare', multipart=False, logger=logger))
self.assertEqual(resp_body, body)
self.assertFalse(logger.all_log_lines())
def test_single_part_unexpected_ranges(self):
body = b"time flies like an arrow; fruit flies like a banana"
doc_iters = [{'part_iter': iter(BytesIO(body).read, b'')}, 'junk']
logger = debug_logger()
resp_body = b''.join(
utils.document_iters_to_http_response_body(
iter(doc_iters), b'dontcare', multipart=False, logger=logger))
self.assertEqual(resp_body, body)
self.assertEqual(['More than one part in a single-part response?'],
logger.get_lines_for_level('warning'))
def test_multiple_parts(self):
part1 = b"two peanuts were walking down a railroad track"
@ -6915,7 +6929,6 @@ class TestDocumentItersToHTTPResponseBody(unittest.TestCase):
b"--boundaryboundary--"))
def test_closed_part_iterator(self):
print('test')
useful_iter_mock = mock.MagicMock()
useful_iter_mock.__iter__.return_value = ['']
body_iter = utils.document_iters_to_http_response_body(
@ -9563,6 +9576,138 @@ class TestReiterate(unittest.TestCase):
self.assertIs(test_tuple, reiterated)
class TestClosingIterator(unittest.TestCase):
    """Tests for when ``utils.ClosingIterator`` closes what it wraps."""

    def _make_gen(self, items, captured_exit):
        # Return a generator over ``items`` that raises any item which is an
        # Exception instance, and that records in ``captured_exit`` any
        # GeneratorExit thrown into it before re-raising.
        def gen():
            try:
                for item in items:
                    if isinstance(item, Exception):
                        raise item
                    yield item
            except GeneratorExit as exit_exc:
                captured_exit.append(exit_exc)
                raise
        return gen()

    def test_close(self):
        wrapped = FakeIterable([1, 2, 3])
        # note: iter(FakeIterable) is the same object
        self.assertIs(wrapped, iter(wrapped))
        closing_iter = utils.ClosingIterator(wrapped)
        self.assertEqual([1, 2, 3], list(closing_iter))
        # exhausting the iterator closed the wrapped iterable...
        self.assertEqual(1, wrapped.close_call_count)
        closing_iter.close()
        # ...and an explicit close does not close it a second time
        self.assertEqual(1, wrapped.close_call_count)

    def test_close_others(self):
        wrapped = FakeIterable([1, 2, 3])
        others = [FakeIterable([4, 5, 6]), FakeIterable([])]
        self.assertIs(wrapped, iter(wrapped))
        closing_iter = utils.ClosingIterator(wrapped, others)
        self.assertEqual([1, 2, 3], list(closing_iter))
        everything = others + [wrapped]
        self.assertEqual([1, 1, 1],
                         [i.close_call_count for i in everything])
        closing_iter.close()
        self.assertEqual([1, 1, 1],
                         [i.close_call_count for i in everything])

    def test_close_gen(self):
        # explicitly check generator closing
        captured_exit = []
        unstarted = self._make_gen([1, 2], captured_exit)
        closing_iter = utils.ClosingIterator(unstarted)
        self.assertFalse(captured_exit)
        closing_iter.close()
        self.assertFalse(captured_exit)  # the generator didn't start

        captured_exit = []
        started = self._make_gen([1, 2], captured_exit)
        closing_iter = utils.ClosingIterator(started)
        self.assertFalse(captured_exit)
        self.assertEqual(1, next(closing_iter))  # start the generator
        closing_iter.close()
        self.assertEqual(1, len(captured_exit))

    def test_close_wrapped_is_not_same_as_iter(self):
        class AltFakeIterable(FakeIterable):
            def __iter__(self):
                return (x for x in self.values)

        wrapped = AltFakeIterable([1, 2, 3])
        # note: iter(AltFakeIterable) is a generator, not the same object
        self.assertIsNot(wrapped, iter(wrapped))
        closing_iter = utils.ClosingIterator(wrapped)
        self.assertEqual([1, 2, 3], list(closing_iter))
        self.assertEqual(1, wrapped.close_call_count)
        closing_iter.close()
        self.assertEqual(1, wrapped.close_call_count)

    def test_init_with_iterable(self):
        plain_list = [1, 2, 3]  # list is iterable but not an iterator
        closing_iter = utils.ClosingIterator(plain_list)
        self.assertEqual([1, 2, 3], list(closing_iter))
        closing_iter.close()  # safe to call even though list has no close

    def test_nested_iters(self):
        wrapped = FakeIterable([1, 2, 3])
        inner = utils.ClosingIterator(wrapped)
        outer = utils.ClosingIterator(inner)
        self.assertEqual([1, 2, 3], list(outer))
        self.assertEqual(1, wrapped.close_call_count)
        outer.close()
        self.assertEqual(1, wrapped.close_call_count)

    def test_close_on_stop_iteration(self):
        wrapped = FakeIterable([1, 2, 3])
        others = [FakeIterable([4, 5, 6]), FakeIterable([])]
        self.assertIs(wrapped, iter(wrapped))
        closing_iter = utils.ClosingIterator(wrapped, others)
        self.assertEqual([1, 2, 3], list(closing_iter))
        everything = others + [wrapped]
        self.assertEqual([1, 1, 1],
                         [i.close_call_count for i in everything])
        closing_iter.close()
        self.assertEqual([1, 1, 1],
                         [i.close_call_count for i in everything])

    def test_close_on_exception(self):
        # sanity check: generator exits on raising exception without executing
        # GeneratorExit
        captured_exit = []
        gen = self._make_gen([1, ValueError(), 2], captured_exit)
        self.assertEqual(1, next(gen))
        with self.assertRaises(ValueError):
            next(gen)
        self.assertFalse(captured_exit)
        gen.close()
        self.assertFalse(captured_exit)  # gen already exited

        captured_exit = []
        gen = self._make_gen([1, ValueError(), 2], captured_exit)
        self.assertEqual(1, next(gen))
        with self.assertRaises(ValueError):
            next(gen)
        self.assertFalse(captured_exit)
        with self.assertRaises(StopIteration):
            next(gen)  # gen already exited

        # wrapped gen does the same...
        captured_exit = []
        gen = self._make_gen([1, ValueError(), 2], captured_exit)
        others = [FakeIterable([4, 5, 6]), FakeIterable([])]
        closing_iter = utils.ClosingIterator(gen, others)
        self.assertEqual(1, next(closing_iter))
        with self.assertRaises(ValueError):
            next(closing_iter)
        self.assertFalse(captured_exit)
        # but other iters are closed :)
        self.assertEqual([1, 1], [i.close_call_count for i in others])
class TestCloseableChain(unittest.TestCase):
def test_closeable_chain_iterates(self):
test_iter1 = FakeIterable([1])
@ -9619,6 +9764,39 @@ class TestCloseableChain(unittest.TestCase):
self.assertTrue(generator_closed[0])
class TestStringAlong(unittest.TestCase):
    """Tests for ``utils.StringAlong``."""

    def _string_along(self, other_items):
        # Iterate a StringAlong over [1, 2, 3] whose second iterator holds
        # ``other_items``, checking that the second iterator is left alone
        # until the first one stops; returns the warning lines logged by
        # the unexpected-items callback.
        logger = debug_logger()
        first = FakeIterable([1, 2, 3])
        second = FakeIterable(other_items)
        string_along = utils.StringAlong(
            first, second, lambda: logger.warning('boom'))
        for expected, item in enumerate(string_along, 1):
            self.assertEqual(expected, item)
            self.assertEqual(0, second.next_call_count, item)
            self.assertEqual(0, second.close_call_count, item)
        # once the first iterator stopped, the second was asked for one
        # item and then closed
        self.assertEqual(1, second.next_call_count, item)
        self.assertEqual(1, second.close_call_count, item)
        return logger.get_lines_for_level('warning')

    def test_happy(self):
        # second iterator is already exhausted: nothing is logged
        warnings = self._string_along([])
        self.assertFalse(warnings)

    def test_unhappy(self):
        # second iterator still has an item: the callback logs a warning
        warnings = self._string_along([1])
        self.assertEqual(1, len(warnings))
        self.assertIn('boom', warnings[0])
class TestCooperativeIterator(unittest.TestCase):
def test_init(self):
wrapped = itertools.count()

View File

@ -42,7 +42,7 @@ from swift.common.storage_policy import StoragePolicy, StoragePolicyCollection
from test.debug_logger import debug_logger
from test.unit import (
fake_http_connect, FakeRing, FakeMemcache, PatchPolicies, patch_policies,
FakeSource, StubResponse)
FakeSource, StubResponse, CaptureIteratorFactory)
from swift.common.request_helpers import (
get_sys_meta_prefix, get_object_transient_sysmeta
)
@ -1706,10 +1706,14 @@ class TestGetOrHeadHandler(BaseTest):
handler.source = GetterSource(self.app, source, node)
return True
with mock.patch.object(handler, '_find_source',
mock_find_source):
resp = handler.get_working_response(req)
resp.app_iter.close()
factory = CaptureIteratorFactory(handler._iter_parts_from_response)
with mock.patch.object(handler, '_find_source', mock_find_source):
with mock.patch.object(
handler, '_iter_parts_from_response', factory):
resp = handler.get_working_response(req)
resp.app_iter.close()
# verify that iter exited
self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
self.assertEqual(["Client disconnected on read of 'some-path'"],
self.logger.get_lines_for_level('info'))
@ -1719,12 +1723,16 @@ class TestGetOrHeadHandler(BaseTest):
self.app, req, 'Object', Namespace(num_primary_nodes=1), None,
None, {})
with mock.patch.object(handler, '_find_source',
mock_find_source):
resp = handler.get_working_response(req)
next(resp.app_iter)
factory = CaptureIteratorFactory(handler._iter_parts_from_response)
with mock.patch.object(handler, '_find_source', mock_find_source):
with mock.patch.object(
handler, '_iter_parts_from_response', factory):
resp = handler.get_working_response(req)
next(resp.app_iter)
resp.app_iter.close()
self.assertEqual({1: ['next', '__del__']}, factory.captured_calls)
self.assertEqual([], self.logger.get_lines_for_level('warning'))
self.assertEqual([], self.logger.get_lines_for_level('info'))
def test_range_fast_forward(self):
req = Request.blank('/')