fix(Request): Make wsgi.input wrapper more robust
Fix two bugs with the wsgi.input wrapper that is used to normalize behavior between wsgiref and production-class web servers: 1. Do not hang when reading with a size < Content-Length but the stream has already been consumed. 2. If Content-Length is invalid or missing, assume no content and still wrap the stream, rather than sometimes wrapping and sometimes not. Also, improve the tests to more cleanly shut down the wsgiref server between tests so that the above can be tested in a more modular fashion.
This commit is contained in:
@@ -1055,15 +1055,14 @@ class Request(object):
|
||||
|
||||
def _wrap_stream(self): # pragma nocover
|
||||
try:
|
||||
# NOTE(kgriffs): We can only add the wrapper if the
|
||||
# content-length header was provided.
|
||||
if self.content_length is not None:
|
||||
self.stream = helpers.Body(self.stream, self.content_length)
|
||||
content_length = self.content_length or 0
|
||||
|
||||
except HTTPInvalidHeader:
|
||||
# NOTE(kgriffs): The content-length header was specified,
|
||||
# but it had an invalid value.
|
||||
pass
|
||||
# but it had an invalid value. Assume no content.
|
||||
content_length = 0
|
||||
|
||||
self.stream = helpers.Body(self.stream, content_length)
|
||||
|
||||
def _parse_form_urlencoded(self):
|
||||
# NOTE(kgriffs): This assumes self.stream has been patched
|
||||
|
||||
@@ -58,6 +58,8 @@ class Body(object):
|
||||
self.stream = stream
|
||||
self.stream_len = stream_len
|
||||
|
||||
self._bytes_remaining = self.stream_len
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
@@ -83,9 +85,13 @@ class Body(object):
|
||||
|
||||
"""
|
||||
|
||||
if size is None or size == -1 or size > self.stream_len:
|
||||
size = self.stream_len
|
||||
# NOTE(kgriffs): Default to reading all remaining bytes if the
|
||||
# size is not specified or is out of bounds. This behaves
|
||||
# similarly to the IO streams passed in by non-wsgiref servers.
|
||||
if (size is None or size == -1 or size > self._bytes_remaining):
|
||||
size = self._bytes_remaining
|
||||
|
||||
self._bytes_remaining -= size
|
||||
return target(size)
|
||||
|
||||
def read(self, size=None):
|
||||
|
||||
@@ -127,6 +127,14 @@ class TestRequestBody(testing.TestBase):
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEqual(body.read(expected_len + 1), expected_body)
|
||||
|
||||
# NOTE(kgriffs): Test that reading past the end does not
|
||||
# hang, but returns the empty string.
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
for i in range(expected_len + 1):
|
||||
expected_value = expected_body[i:i + 1] if i < expected_len else b''
|
||||
self.assertEqual(body.read(1), expected_value)
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEqual(body.readline(), expected_lines[0])
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from wsgiref.simple_server import make_server
|
||||
|
||||
if 'java' not in sys.platform:
|
||||
import multiprocessing
|
||||
|
||||
import requests
|
||||
from six.moves import http_client
|
||||
import testtools
|
||||
from testtools.matchers import Equals, MatchesRegex
|
||||
|
||||
import falcon
|
||||
import falcon.testing as testing
|
||||
|
||||
_SERVER_HOST = 'localhost'
|
||||
_SERVER_PORT = 9809
|
||||
_SERVER_BASE_URL = 'http://{0}:{1}/'.format(_SERVER_HOST, _SERVER_PORT)
|
||||
|
||||
|
||||
def _is_iterable(thing):
|
||||
try:
|
||||
@@ -21,23 +26,32 @@ def _is_iterable(thing):
|
||||
return False
|
||||
|
||||
|
||||
def _run_server():
|
||||
def _run_server(stop_event):
|
||||
class Things(object):
|
||||
def on_get(self, req, resp):
|
||||
pass
|
||||
|
||||
def on_post(self, req, resp):
|
||||
resp.body = req.stream.read(1000)
|
||||
|
||||
def on_put(self, req, resp):
|
||||
pass
|
||||
# NOTE(kgriffs): Test that reading past the end does
|
||||
# not hang.
|
||||
req_body = (req.stream.read(1)
|
||||
for i in range(req.content_length + 1))
|
||||
|
||||
resp.body = b''.join(req_body)
|
||||
|
||||
api = application = falcon.API()
|
||||
api.add_route('/', Things())
|
||||
|
||||
from wsgiref.simple_server import make_server
|
||||
server = make_server('localhost', 9803, application)
|
||||
server.serve_forever()
|
||||
server = make_server(_SERVER_HOST, _SERVER_PORT, application)
|
||||
|
||||
while not stop_event.is_set():
|
||||
server.handle_request()
|
||||
|
||||
|
||||
class TestWsgi(testtools.TestCase):
|
||||
class TestWSGIInterface(testing.TestBase):
|
||||
|
||||
def test_srmock(self):
|
||||
mock = testing.StartResponseMock()
|
||||
@@ -76,30 +90,63 @@ class TestWsgi(testtools.TestCase):
|
||||
self.assertTrue(isinstance(header[0], str))
|
||||
self.assertTrue(isinstance(header[1], str))
|
||||
|
||||
def test_wsgiref(self):
|
||||
thread = threading.Thread(target=_run_server)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
# Wait a moment for the thread to start up
|
||||
class TestWSGIReference(testing.TestBase):
|
||||
|
||||
def before(self):
|
||||
if 'java' in sys.platform:
|
||||
# NOTE(kgriffs): Jython does not support the multiprocessing
|
||||
# module. We could alternatively implement these tests
|
||||
# using threads, but then we have to force a garbage
|
||||
# collection in between each test in order to make
|
||||
# the server relinquish its socket, and the gc module
|
||||
# doesn't appear to do anything under Jython.
|
||||
self.skip('Incompatible with Jython')
|
||||
|
||||
self._stop_event = multiprocessing.Event()
|
||||
self._process = multiprocessing.Process(target=_run_server,
|
||||
args=(self._stop_event,))
|
||||
self._process.start()
|
||||
|
||||
# NOTE(kgriffs): Let the server start up
|
||||
time.sleep(0.2)
|
||||
|
||||
resp = requests.get('http://localhost:9803')
|
||||
def after(self):
|
||||
self._stop_event.set()
|
||||
|
||||
# NOTE(kgriffs): Pump the request handler loop in case execution
|
||||
# made it to the next server.handle_request() before we sent the
|
||||
# event.
|
||||
try:
|
||||
requests.get(_SERVER_BASE_URL)
|
||||
except Exception:
|
||||
pass # Thread already exited
|
||||
|
||||
self._process.join()
|
||||
|
||||
def test_wsgiref_get(self):
|
||||
resp = requests.get(_SERVER_BASE_URL)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
|
||||
# NOTE(kgriffs): This will cause a different path to
|
||||
# be taken in req._wrap_stream. Have to use httplib
|
||||
# to prevent the invalid header value from being
|
||||
# forced to "0".
|
||||
conn = http_client.HTTPConnection('localhost', 9803)
|
||||
headers = {'Content-Length': 'invalid'}
|
||||
conn.request('PUT', '/', headers=headers)
|
||||
resp = conn.getresponse()
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
headers = {'Content-Length': '0'}
|
||||
resp = requests.put('http://localhost:9803', headers=headers)
|
||||
def test_wsgiref_put(self):
|
||||
body = '{}'
|
||||
resp = requests.put(_SERVER_BASE_URL, data=body)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(resp.text, '{}')
|
||||
|
||||
resp = requests.post('http://localhost:9803')
|
||||
def test_wsgiref_head_405(self):
|
||||
body = '{}'
|
||||
resp = requests.head(_SERVER_BASE_URL, data=body)
|
||||
self.assertEqual(resp.status_code, 405)
|
||||
|
||||
def test_wsgiref_post(self):
|
||||
body = '{}'
|
||||
resp = requests.post(_SERVER_BASE_URL, data=body)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(resp.text, '{}')
|
||||
|
||||
def test_wsgiref_post_invalid_content_length(self):
|
||||
headers = {'Content-Length': 'invalid'}
|
||||
resp = requests.post(_SERVER_BASE_URL, headers=headers)
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertEqual(resp.text, '')
|
||||
|
||||
Reference in New Issue
Block a user