diff --git a/falcon/request.py b/falcon/request.py index 3312667..e809bcb 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -1055,15 +1055,14 @@ class Request(object): def _wrap_stream(self): # pragma nocover try: - # NOTE(kgriffs): We can only add the wrapper if the - # content-length header was provided. - if self.content_length is not None: - self.stream = helpers.Body(self.stream, self.content_length) + content_length = self.content_length or 0 except HTTPInvalidHeader: # NOTE(kgriffs): The content-length header was specified, - # but it had an invalid value. - pass + # but it had an invalid value. Assume no content. + content_length = 0 + + self.stream = helpers.Body(self.stream, content_length) def _parse_form_urlencoded(self): # NOTE(kgriffs): This assumes self.stream has been patched diff --git a/falcon/request_helpers.py b/falcon/request_helpers.py index 3a18e2d..3c42f70 100644 --- a/falcon/request_helpers.py +++ b/falcon/request_helpers.py @@ -58,6 +58,8 @@ class Body(object): self.stream = stream self.stream_len = stream_len + self._bytes_remaining = self.stream_len + def __iter__(self): return self @@ -83,9 +85,13 @@ class Body(object): """ - if size is None or size == -1 or size > self.stream_len: - size = self.stream_len + # NOTE(kgriffs): Default to reading all remaining bytes if the + # size is not specified or is out of bounds. This behaves + # similarly to the IO streams passed in by non-wsgiref servers. + if (size is None or size == -1 or size > self._bytes_remaining): + size = self._bytes_remaining + self._bytes_remaining -= size return target(size) def read(self, size=None): diff --git a/tests/test_request_body.py b/tests/test_request_body.py index 8b47039..fef730f 100644 --- a/tests/test_request_body.py +++ b/tests/test_request_body.py @@ -127,6 +127,14 @@ class TestRequestBody(testing.TestBase): body = request_helpers.Body(stream, expected_len) self.assertEqual(body.read(expected_len + 1), expected_body) + # NOTE(kgriffs): Test that reading past the end does not + # hang, but returns the empty string. + stream = io.BytesIO(expected_body) + body = request_helpers.Body(stream, expected_len) + for i in range(expected_len + 1): + expected_value = expected_body[i:i + 1] if i < expected_len else b'' + self.assertEqual(body.read(1), expected_value) + stream = io.BytesIO(expected_body) body = request_helpers.Body(stream, expected_len) self.assertEqual(body.readline(), expected_lines[0]) diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 7cabbc8..7d6e75a 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,15 +1,20 @@ import sys -import threading import time +from wsgiref.simple_server import make_server + +if 'java' not in sys.platform: + import multiprocessing import requests -from six.moves import http_client -import testtools from testtools.matchers import Equals, MatchesRegex import falcon import falcon.testing as testing +_SERVER_HOST = 'localhost' +_SERVER_PORT = 9809 +_SERVER_BASE_URL = 'http://{0}:{1}/'.format(_SERVER_HOST, _SERVER_PORT) + def _is_iterable(thing): try: @@ -21,23 +26,32 @@ def _is_iterable(thing): return False -def _run_server(): +def _run_server(stop_event): class Things(object): def on_get(self, req, resp): pass + def on_post(self, req, resp): + resp.body = req.stream.read(1000) + def on_put(self, req, resp): - pass + # NOTE(kgriffs): Test that reading past the end does + # not hang. + req_body = (req.stream.read(1) + for i in range(req.content_length + 1)) + + resp.body = b''.join(req_body) api = application = falcon.API() api.add_route('/', Things()) - from wsgiref.simple_server import make_server - server = make_server('localhost', 9803, application) - server.serve_forever() + server = make_server(_SERVER_HOST, _SERVER_PORT, application) + + while not stop_event.is_set(): + server.handle_request() -class TestWsgi(testtools.TestCase): +class TestWSGIInterface(testing.TestBase): def test_srmock(self): mock = testing.StartResponseMock() @@ -76,30 +90,63 @@ class TestWsgi(testtools.TestCase): self.assertTrue(isinstance(header[0], str)) self.assertTrue(isinstance(header[1], str)) - def test_wsgiref(self): - thread = threading.Thread(target=_run_server) - thread.daemon = True - thread.start() - # Wait a moment for the thread to start up +class TestWSGIReference(testing.TestBase): + + def before(self): + if 'java' in sys.platform: + # NOTE(kgriffs): Jython does not support the multiprocessing + # module. We could alternatively implement these tests + # using threads, but then we have to force a garbage + # collection in between each test in order to make + # the server relinquish its socket, and the gc module + # doesn't appear to do anything under Jython. + self.skip('Incompatible with Jython') + + self._stop_event = multiprocessing.Event() + self._process = multiprocessing.Process(target=_run_server, + args=(self._stop_event,)) + self._process.start() + + # NOTE(kgriffs): Let the server start up time.sleep(0.2) - resp = requests.get('http://localhost:9803') + def after(self): + self._stop_event.set() + + # NOTE(kgriffs): Pump the request handler loop in case execution + # made it to the next server.handle_request() before we sent the + # event. + try: + requests.get(_SERVER_BASE_URL) + except Exception: + pass # Thread already exited + + self._process.join() + + def test_wsgiref_get(self): + resp = requests.get(_SERVER_BASE_URL) self.assertEqual(resp.status_code, 200) - # NOTE(kgriffs): This will cause a different path to - # be taken in req._wrap_stream. Have to use httplib - # to prevent the invalid header value from being - # forced to "0". - conn = http_client.HTTPConnection('localhost', 9803) - headers = {'Content-Length': 'invalid'} - conn.request('PUT', '/', headers=headers) - resp = conn.getresponse() - self.assertEqual(resp.status, 200) - - headers = {'Content-Length': '0'} - resp = requests.put('http://localhost:9803', headers=headers) + def test_wsgiref_put(self): + body = '{}' + resp = requests.put(_SERVER_BASE_URL, data=body) self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, '{}') - resp = requests.post('http://localhost:9803') + def test_wsgiref_head_405(self): + body = '{}' + resp = requests.head(_SERVER_BASE_URL, data=body) self.assertEqual(resp.status_code, 405) + + def test_wsgiref_post(self): + body = '{}' + resp = requests.post(_SERVER_BASE_URL, data=body) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, '{}') + + def test_wsgiref_post_invalid_content_length(self): + headers = {'Content-Length': 'invalid'} + resp = requests.post(_SERVER_BASE_URL, headers=headers) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, '')