From 0666ec8b8fcf4cdbee840a49d0d9d1d12db287b6 Mon Sep 17 00:00:00 2001 From: Kurt Griffiths Date: Fri, 2 Oct 2015 16:23:57 -0500 Subject: [PATCH] fix(Request): Make wsgi.input wrapper more robust Fix two bugs with the wsgi.input wrapper that is used to normalize behavior between wsgiref and production-class web servers: 1. Do not hang when reading with a size < Content-Length but the stream has already been consumed. 2. If Content-Length is invalid or missing, assume no content and still wrap the stream, rather than sometimes wrapping and sometimes not. Also, improve the tests to more cleanly shut down the wsgiref server between tests so that the above can be tested in a more modular fashion. --- falcon/request.py | 11 ++-- falcon/request_helpers.py | 10 +++- tests/test_request_body.py | 8 +++ tests/test_wsgi.py | 103 +++++++++++++++++++++++++++---------- 4 files changed, 96 insertions(+), 36 deletions(-) diff --git a/falcon/request.py b/falcon/request.py index 3312667..e809bcb 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -1055,15 +1055,14 @@ class Request(object): def _wrap_stream(self): # pragma nocover try: - # NOTE(kgriffs): We can only add the wrapper if the - # content-length header was provided. - if self.content_length is not None: - self.stream = helpers.Body(self.stream, self.content_length) + content_length = self.content_length or 0 except HTTPInvalidHeader: # NOTE(kgriffs): The content-length header was specified, - # but it had an invalid value. - pass + # but it had an invalid value. Assume no content. + content_length = 0 + + self.stream = helpers.Body(self.stream, content_length) def _parse_form_urlencoded(self): # NOTE(kgriffs): This assumes self.stream has been patched diff --git a/falcon/request_helpers.py b/falcon/request_helpers.py index 3a18e2d..3c42f70 100644 --- a/falcon/request_helpers.py +++ b/falcon/request_helpers.py @@ -58,6 +58,8 @@ class Body(object): self.stream = stream self.stream_len = stream_len + self._bytes_remaining = self.stream_len + def __iter__(self): return self @@ -83,9 +85,13 @@ class Body(object): """ - if size is None or size == -1 or size > self.stream_len: - size = self.stream_len + # NOTE(kgriffs): Default to reading all remaining bytes if the + # size is not specified or is out of bounds. This behaves + # similarly to the IO streams passed in by non-wsgiref servers. + if (size is None or size == -1 or size > self._bytes_remaining): + size = self._bytes_remaining + self._bytes_remaining -= size return target(size) def read(self, size=None): diff --git a/tests/test_request_body.py b/tests/test_request_body.py index 8b47039..fef730f 100644 --- a/tests/test_request_body.py +++ b/tests/test_request_body.py @@ -127,6 +127,14 @@ class TestRequestBody(testing.TestBase): body = request_helpers.Body(stream, expected_len) self.assertEqual(body.read(expected_len + 1), expected_body) + # NOTE(kgriffs): Test that reading past the end does not + # hang, but returns the empty string. + stream = io.BytesIO(expected_body) + body = request_helpers.Body(stream, expected_len) + for i in range(expected_len + 1): + expected_value = expected_body[i:i + 1] if i < expected_len else b'' + self.assertEqual(body.read(1), expected_value) + stream = io.BytesIO(expected_body) body = request_helpers.Body(stream, expected_len) self.assertEqual(body.readline(), expected_lines[0]) diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 7cabbc8..7d6e75a 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,15 +1,20 @@ import sys -import threading import time +from wsgiref.simple_server import make_server + +if 'java' not in sys.platform: + import multiprocessing import requests -from six.moves import http_client -import testtools from testtools.matchers import Equals, MatchesRegex import falcon import falcon.testing as testing +_SERVER_HOST = 'localhost' +_SERVER_PORT = 9809 +_SERVER_BASE_URL = 'http://{0}:{1}/'.format(_SERVER_HOST, _SERVER_PORT) + def _is_iterable(thing): try: @@ -21,23 +26,32 @@ def _is_iterable(thing): return False -def _run_server(): +def _run_server(stop_event): class Things(object): def on_get(self, req, resp): pass + def on_post(self, req, resp): + resp.body = req.stream.read(1000) + def on_put(self, req, resp): - pass + # NOTE(kgriffs): Test that reading past the end does + # not hang. + req_body = (req.stream.read(1) + for i in range(req.content_length + 1)) + + resp.body = b''.join(req_body) api = application = falcon.API() api.add_route('/', Things()) - from wsgiref.simple_server import make_server - server = make_server('localhost', 9803, application) - server.serve_forever() + server = make_server(_SERVER_HOST, _SERVER_PORT, application) + + while not stop_event.is_set(): + server.handle_request() -class TestWsgi(testtools.TestCase): +class TestWSGIInterface(testing.TestBase): def test_srmock(self): mock = testing.StartResponseMock() @@ -76,30 +90,63 @@ class TestWsgi(testtools.TestCase): self.assertTrue(isinstance(header[0], str)) self.assertTrue(isinstance(header[1], str)) - def test_wsgiref(self): - thread = threading.Thread(target=_run_server) - thread.daemon = True - thread.start() - # Wait a moment for the thread to start up +class TestWSGIReference(testing.TestBase): + + def before(self): + if 'java' in sys.platform: + # NOTE(kgriffs): Jython does not support the multiprocessing + # module. We could alternatively implement these tests + # using threads, but then we have to force a garbage + # collection in between each test in order to make + # the server relinquish its socket, and the gc module + # doesn't appear to do anything under Jython. + self.skip('Incompatible with Jython') + + self._stop_event = multiprocessing.Event() + self._process = multiprocessing.Process(target=_run_server, + args=(self._stop_event,)) + self._process.start() + + # NOTE(kgriffs): Let the server start up time.sleep(0.2) - resp = requests.get('http://localhost:9803') + def after(self): + self._stop_event.set() + + # NOTE(kgriffs): Pump the request handler loop in case execution + # made it to the next server.handle_request() before we sent the + # event. + try: + requests.get(_SERVER_BASE_URL) + except Exception: + pass # Thread already exited + + self._process.join() + + def test_wsgiref_get(self): + resp = requests.get(_SERVER_BASE_URL) self.assertEqual(resp.status_code, 200) - # NOTE(kgriffs): This will cause a different path to - # be taken in req._wrap_stream. Have to use httplib - # to prevent the invalid header value from being - # forced to "0". - conn = http_client.HTTPConnection('localhost', 9803) - headers = {'Content-Length': 'invalid'} - conn.request('PUT', '/', headers=headers) - resp = conn.getresponse() - self.assertEqual(resp.status, 200) - - headers = {'Content-Length': '0'} - resp = requests.put('http://localhost:9803', headers=headers) + def test_wsgiref_put(self): + body = '{}' + resp = requests.put(_SERVER_BASE_URL, data=body) self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, '{}') - resp = requests.post('http://localhost:9803') + def test_wsgiref_head_405(self): + body = '{}' + resp = requests.head(_SERVER_BASE_URL, data=body) self.assertEqual(resp.status_code, 405) + + def test_wsgiref_post(self): + body = '{}' + resp = requests.post(_SERVER_BASE_URL, data=body) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, '{}') + + def test_wsgiref_post_invalid_content_length(self): + headers = {'Content-Length': 'invalid'} + resp = requests.post(_SERVER_BASE_URL, headers=headers) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, '')