fix(Request): Make wsgi.input wrapper more robust

Fix two bugs with the wsgi.input wrapper that is used to normalize
behavior between wsgiref and production-class web servers:

1. Do not hang when reading with a size < Content-Length but the
   stream has already been consumed.
2. If Content-Length is invalid or missing, assume no content and
   still wrap the stream, rather than sometimes wrapping and sometimes
   not.

Also, improve the tests to more cleanly shut down the wsgiref server
between tests so that the above can be tested in a more modular
fashion.
This commit is contained in:
Kurt Griffiths
2015-10-02 16:23:57 -05:00
parent 0fa3c5771a
commit 0666ec8b8f
4 changed files with 96 additions and 36 deletions

View File

@@ -1055,15 +1055,14 @@ class Request(object):
def _wrap_stream(self): # pragma nocover
try:
# NOTE(kgriffs): We can only add the wrapper if the
# content-length header was provided.
if self.content_length is not None:
self.stream = helpers.Body(self.stream, self.content_length)
content_length = self.content_length or 0
except HTTPInvalidHeader:
# NOTE(kgriffs): The content-length header was specified,
# but it had an invalid value.
pass
# but it had an invalid value. Assume no content.
content_length = 0
self.stream = helpers.Body(self.stream, content_length)
def _parse_form_urlencoded(self):
# NOTE(kgriffs): This assumes self.stream has been patched

View File

@@ -58,6 +58,8 @@ class Body(object):
self.stream = stream
self.stream_len = stream_len
self._bytes_remaining = self.stream_len
def __iter__(self):
return self
@@ -83,9 +85,13 @@ class Body(object):
"""
if size is None or size == -1 or size > self.stream_len:
size = self.stream_len
# NOTE(kgriffs): Default to reading all remaining bytes if the
# size is not specified or is out of bounds. This behaves
# similarly to the IO streams passed in by non-wsgiref servers.
if (size is None or size == -1 or size > self._bytes_remaining):
size = self._bytes_remaining
self._bytes_remaining -= size
return target(size)
def read(self, size=None):

View File

@@ -127,6 +127,14 @@ class TestRequestBody(testing.TestBase):
body = request_helpers.Body(stream, expected_len)
self.assertEqual(body.read(expected_len + 1), expected_body)
# NOTE(kgriffs): Test that reading past the end does not
# hang, but returns the empty string.
stream = io.BytesIO(expected_body)
body = request_helpers.Body(stream, expected_len)
for i in range(expected_len + 1):
expected_value = expected_body[i:i + 1] if i < expected_len else b''
self.assertEqual(body.read(1), expected_value)
stream = io.BytesIO(expected_body)
body = request_helpers.Body(stream, expected_len)
self.assertEqual(body.readline(), expected_lines[0])

View File

@@ -1,15 +1,20 @@
import sys
import threading
import time
from wsgiref.simple_server import make_server
if 'java' not in sys.platform:
import multiprocessing
import requests
from six.moves import http_client
import testtools
from testtools.matchers import Equals, MatchesRegex
import falcon
import falcon.testing as testing
_SERVER_HOST = 'localhost'
_SERVER_PORT = 9809
_SERVER_BASE_URL = 'http://{0}:{1}/'.format(_SERVER_HOST, _SERVER_PORT)
def _is_iterable(thing):
try:
@@ -21,23 +26,32 @@ def _is_iterable(thing):
return False
def _run_server():
def _run_server(stop_event):
class Things(object):
def on_get(self, req, resp):
pass
def on_post(self, req, resp):
resp.body = req.stream.read(1000)
def on_put(self, req, resp):
pass
# NOTE(kgriffs): Test that reading past the end does
# not hang.
req_body = (req.stream.read(1)
for i in range(req.content_length + 1))
resp.body = b''.join(req_body)
api = application = falcon.API()
api.add_route('/', Things())
from wsgiref.simple_server import make_server
server = make_server('localhost', 9803, application)
server.serve_forever()
server = make_server(_SERVER_HOST, _SERVER_PORT, application)
while not stop_event.is_set():
server.handle_request()
class TestWsgi(testtools.TestCase):
class TestWSGIInterface(testing.TestBase):
def test_srmock(self):
mock = testing.StartResponseMock()
@@ -76,30 +90,63 @@ class TestWsgi(testtools.TestCase):
self.assertTrue(isinstance(header[0], str))
self.assertTrue(isinstance(header[1], str))
def test_wsgiref(self):
thread = threading.Thread(target=_run_server)
thread.daemon = True
thread.start()
# Wait a moment for the thread to start up
class TestWSGIReference(testing.TestBase):
def before(self):
if 'java' in sys.platform:
# NOTE(kgriffs): Jython does not support the multiprocessing
# module. We could alternatively implement these tests
# using threads, but then we have to force a garbage
# collection in between each test in order to make
# the server relinquish its socket, and the gc module
# doesn't appear to do anything under Jython.
self.skip('Incompatible with Jython')
self._stop_event = multiprocessing.Event()
self._process = multiprocessing.Process(target=_run_server,
args=(self._stop_event,))
self._process.start()
# NOTE(kgriffs): Let the server start up
time.sleep(0.2)
resp = requests.get('http://localhost:9803')
def after(self):
self._stop_event.set()
# NOTE(kgriffs): Pump the request handler loop in case execution
# made it to the next server.handle_request() before we sent the
# event.
try:
requests.get(_SERVER_BASE_URL)
except Exception:
pass # Thread already exited
self._process.join()
def test_wsgiref_get(self):
resp = requests.get(_SERVER_BASE_URL)
self.assertEqual(resp.status_code, 200)
# NOTE(kgriffs): This will cause a different path to
# be taken in req._wrap_stream. Have to use httplib
# to prevent the invalid header value from being
# forced to "0".
conn = http_client.HTTPConnection('localhost', 9803)
headers = {'Content-Length': 'invalid'}
conn.request('PUT', '/', headers=headers)
resp = conn.getresponse()
self.assertEqual(resp.status, 200)
headers = {'Content-Length': '0'}
resp = requests.put('http://localhost:9803', headers=headers)
def test_wsgiref_put(self):
body = '{}'
resp = requests.put(_SERVER_BASE_URL, data=body)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, '{}')
resp = requests.post('http://localhost:9803')
def test_wsgiref_head_405(self):
body = '{}'
resp = requests.head(_SERVER_BASE_URL, data=body)
self.assertEqual(resp.status_code, 405)
def test_wsgiref_post(self):
body = '{}'
resp = requests.post(_SERVER_BASE_URL, data=body)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, '{}')
def test_wsgiref_post_invalid_content_length(self):
headers = {'Content-Length': 'invalid'}
resp = requests.post(_SERVER_BASE_URL, headers=headers)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, '')