Merge "Correct read function in _UWSGIChunkFile not to exceed length"
This commit is contained in:
@@ -38,6 +38,7 @@ from oslo_log import log as logging
|
||||
from oslo_serialization import jsonutils
|
||||
from oslo_utils import encodeutils
|
||||
from oslo_utils import strutils
|
||||
from oslo_utils import units
|
||||
from osprofiler import opts as profiler_opts
|
||||
import routes.middleware
|
||||
import webob.dec
|
||||
@@ -894,28 +895,40 @@ class Router(object):
|
||||
|
||||
|
||||
class _UWSGIChunkFile(object):
|
||||
"""
|
||||
A file-like object for reading uWSGI chunked requests, with internal
|
||||
buffering/caching of excess data for subsequent reads.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Buffer to cache data read in excess of the requested length
|
||||
self._buffer = b""
|
||||
|
||||
def read(self, length=None):
|
||||
position = 0
|
||||
"""
|
||||
Reads up to 'length' bytes from the chunked request stream.
|
||||
Caches any excess data internally.
|
||||
"""
|
||||
if length == 0:
|
||||
return b""
|
||||
|
||||
# If length is negative, treat it as reading until the end of the file.
|
||||
if length and length < 0:
|
||||
length = None
|
||||
|
||||
response = []
|
||||
while True:
|
||||
# If no length is provided, choose some sane minimum default
|
||||
length = length if length is not None else 1 * units.Mi
|
||||
|
||||
while len(self._buffer) < length:
|
||||
data = uwsgi.chunked_read()
|
||||
# Return everything if we reached the end of the file
|
||||
if not data:
|
||||
break
|
||||
response.append(data)
|
||||
# Return the data if we've reached the length
|
||||
if length is not None:
|
||||
position += len(data)
|
||||
if position >= length:
|
||||
break
|
||||
return b''.join(response)
|
||||
# append the buffer
|
||||
self._buffer += data
|
||||
|
||||
chunk = self._buffer[:length]
|
||||
self._buffer = self._buffer[length:]
|
||||
return chunk
|
||||
|
||||
|
||||
class Request(webob.Request):
|
||||
|
||||
@@ -833,3 +833,60 @@ class Test_UwsgiChunkedFile(test_utils.BaseTestCase):
|
||||
wsgi.uwsgi.chunked_read = fake_read
|
||||
out = reader.read(length=-2)
|
||||
self.assertEqual(out, b'abc')
|
||||
|
||||
def test_read_data_length_with_overshoot(self):
|
||||
reader = wsgi._UWSGIChunkFile()
|
||||
wsgi.uwsgi = mock.MagicMock()
|
||||
self.addCleanup(_cleanup_uwsgi)
|
||||
|
||||
values_read_count = 0
|
||||
values_read_count_prev = 0
|
||||
values = iter([b'a', b'bcd', b'e', b'fg', b'h', None, None])
|
||||
|
||||
def fake_read():
|
||||
nonlocal values_read_count
|
||||
values_read_count += 1
|
||||
return next(values)
|
||||
|
||||
def values_read_count_get():
|
||||
nonlocal values_read_count, values_read_count_prev
|
||||
res = values_read_count - values_read_count_prev
|
||||
values_read_count_prev = values_read_count
|
||||
return res
|
||||
|
||||
wsgi.uwsgi.chunked_read = fake_read
|
||||
out = reader.read(length=2)
|
||||
|
||||
# empty buffer case
|
||||
self.assertEqual(out, b'ab')
|
||||
self.assertEqual(values_read_count_get(), 2)
|
||||
|
||||
# buffer contains more - no extra read
|
||||
out = reader.read(length=1)
|
||||
self.assertEqual(out, b'c')
|
||||
self.assertEqual(values_read_count_get(), 0)
|
||||
|
||||
# buffer contains exactly what we need - no extra read
|
||||
out = reader.read(length=1)
|
||||
self.assertEqual(out, b'd')
|
||||
self.assertEqual(values_read_count_get(), 0)
|
||||
|
||||
# buffer is empty + 1st read returns less - must read twice
|
||||
out = reader.read(length=2)
|
||||
self.assertEqual(out, b'ef')
|
||||
self.assertEqual(values_read_count_get(), 2)
|
||||
|
||||
# buffer isn't empty, but not enough - must read
|
||||
out = reader.read(length=2)
|
||||
self.assertEqual(out, b'gh')
|
||||
self.assertEqual(values_read_count_get(), 1)
|
||||
|
||||
# eof case
|
||||
out = reader.read(length=2)
|
||||
self.assertEqual(out, b'')
|
||||
self.assertEqual(values_read_count_get(), 1)
|
||||
|
||||
# eof case when request till eof
|
||||
out = reader.read(length=-2)
|
||||
self.assertEqual(out, b'')
|
||||
self.assertEqual(values_read_count_get(), 1)
|
||||
|
||||
Reference in New Issue
Block a user