Merge "Correct read function in _UWSGIChunkFile not to exceed length"

This commit is contained in:
Zuul
2025-12-03 08:54:30 +00:00
committed by Gerrit Code Review
2 changed files with 81 additions and 11 deletions

View File

@@ -38,6 +38,7 @@ from oslo_log import log as logging
from oslo_serialization import jsonutils
from oslo_utils import encodeutils
from oslo_utils import strutils
from oslo_utils import units
from osprofiler import opts as profiler_opts
import routes.middleware
import webob.dec
@@ -894,28 +895,40 @@ class Router(object):
class _UWSGIChunkFile(object):
"""
A file-like object for reading uWSGI chunked requests, with internal
buffering/caching of excess data for subsequent reads.
"""
def __init__(self):
# Buffer to cache data read in excess of the requested length
self._buffer = b""
def read(self, length=None):
position = 0
"""
Reads up to 'length' bytes from the chunked request stream.
Caches any excess data internally.
"""
if length == 0:
return b""
# If length is negative, treat it as reading until the end of the file.
if length and length < 0:
length = None
response = []
while True:
# If no length is provided, choose some sane minimum default
length = length if length is not None else 1 * units.Mi
while len(self._buffer) < length:
data = uwsgi.chunked_read()
# Return everything if we reached the end of the file
if not data:
break
response.append(data)
# Return the data if we've reached the length
if length is not None:
position += len(data)
if position >= length:
break
return b''.join(response)
# append the buffer
self._buffer += data
chunk = self._buffer[:length]
self._buffer = self._buffer[length:]
return chunk
class Request(webob.Request):

View File

@@ -833,3 +833,60 @@ class Test_UwsgiChunkedFile(test_utils.BaseTestCase):
wsgi.uwsgi.chunked_read = fake_read
out = reader.read(length=-2)
self.assertEqual(out, b'abc')
def test_read_data_length_with_overshoot(self):
reader = wsgi._UWSGIChunkFile()
wsgi.uwsgi = mock.MagicMock()
self.addCleanup(_cleanup_uwsgi)
values_read_count = 0
values_read_count_prev = 0
values = iter([b'a', b'bcd', b'e', b'fg', b'h', None, None])
def fake_read():
nonlocal values_read_count
values_read_count += 1
return next(values)
def values_read_count_get():
nonlocal values_read_count, values_read_count_prev
res = values_read_count - values_read_count_prev
values_read_count_prev = values_read_count
return res
wsgi.uwsgi.chunked_read = fake_read
out = reader.read(length=2)
# empty buffer case
self.assertEqual(out, b'ab')
self.assertEqual(values_read_count_get(), 2)
# buffer contains more - no extra read
out = reader.read(length=1)
self.assertEqual(out, b'c')
self.assertEqual(values_read_count_get(), 0)
# buffer contains exactly what we need - no extra read
out = reader.read(length=1)
self.assertEqual(out, b'd')
self.assertEqual(values_read_count_get(), 0)
# buffer is empty + 1st read returns less - must read twice
out = reader.read(length=2)
self.assertEqual(out, b'ef')
self.assertEqual(values_read_count_get(), 2)
# buffer isn't empty, but not enough - must read
out = reader.read(length=2)
self.assertEqual(out, b'gh')
self.assertEqual(values_read_count_get(), 1)
# eof case
out = reader.read(length=2)
self.assertEqual(out, b'')
self.assertEqual(values_read_count_get(), 1)
# eof case when request till eof
out = reader.read(length=-2)
self.assertEqual(out, b'')
self.assertEqual(values_read_count_get(), 1)