feat(Request): Normalize wsgi.input semantics
The socket._fileobject and io.BufferedReader are sometimes used to implement wsgi.input. However, app developers are often burned by the fact that the read() method for these objects block indefinitely if either no size is passed, or a size greater than the request's content length is passed to the method. This patch makes Falcon detect when the above native stream types are used by a WSGI server, and wraps them with a simple Body object that provides more forgiving read, readline, and readlines methods than what is otherwise provided. The end result is that app developers are shielded from this silly inconsistency between WSGI servers. Fixes issue #147
This commit is contained in:
@@ -18,6 +18,18 @@ limitations under the License.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
# NOTE(kgrifs): In Python 2.6 and 2.7, socket._fileobject is a
|
||||
# standard way of exposing a socket as a file-like object, and
|
||||
# is used by wsgiref for wsgi.input.
|
||||
import socket
|
||||
NativeStream = socket._fileobject
|
||||
except AttributeError: # pragma nocover
|
||||
# NOTE(kgriffs): In Python 3.3, wsgiref implements wsgi.input
|
||||
# using _io.BufferedReader which is an alias of io.BufferedReader
|
||||
import io
|
||||
NativeStream = io.BufferedReader
|
||||
|
||||
import mimeparse
|
||||
import six
|
||||
|
||||
@@ -81,7 +93,6 @@ class Request(object):
|
||||
|
||||
self._wsgierrors = env['wsgi.errors']
|
||||
self.stream = env['wsgi.input']
|
||||
|
||||
self.method = env['REQUEST_METHOD']
|
||||
|
||||
# Normalize path
|
||||
@@ -109,6 +120,14 @@ class Request(object):
|
||||
|
||||
self._headers = helpers.parse_headers(env)
|
||||
|
||||
# NOTE(kgriffs): Wrap wsgi.input if needed to make read() more robust,
|
||||
# normalizing semantics between, e.g., gunicorn and wsgiref.
|
||||
if isinstance(self.stream, NativeStream): # pragma: nocover
|
||||
# NOTE(kgriffs): coverage can't detect that this *is* actually
|
||||
# covered since the test that does so uses multiprocessing.
|
||||
self.stream = helpers.Body(self.stream, self.content_length)
|
||||
|
||||
# TODO(kgriffs): Use the nocover pragma only for the six.PY3 if..else
|
||||
def log_error(self, message): # pragma: no cover
|
||||
"""Log an error to wsgi.error
|
||||
|
||||
|
@@ -94,3 +94,53 @@ def parse_headers(env):
|
||||
headers['HOST'] = host
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
class Body(object):
|
||||
"""Wrap wsgi.input streams to make them more robust.
|
||||
|
||||
The socket._fileobject and io.BufferedReader are sometimes used
|
||||
to implement wsgi.input. However, app developers are often burned
|
||||
by the fact that the read() method for these objects block
|
||||
indefinitely if either no size is passed, or a size greater than
|
||||
the request's content length is passed to the method.
|
||||
|
||||
This class normalizes wsgi.input behavior between WSGI servers
|
||||
by implementing non-blocking behavior for the cases mentioned
|
||||
above.
|
||||
"""
|
||||
|
||||
def __init__(self, stream, stream_len):
|
||||
"""Initialize the request body instance.
|
||||
|
||||
Args:
|
||||
stream: Instance of socket._fileobject from environ['wsgi.input']
|
||||
stream_len: Expected content length of the stream.
|
||||
"""
|
||||
|
||||
self.stream = stream
|
||||
self.stream_len = stream_len
|
||||
|
||||
def _make_stream_reader(func):
|
||||
def read(size=None):
|
||||
if size is None or size > self.stream_len:
|
||||
size = self.stream_len
|
||||
|
||||
return func(size)
|
||||
|
||||
return read
|
||||
|
||||
# NOTE(kgriffs): All of the wrapped methods take a single argument,
|
||||
# which is a size AKA length AKA limit, always in bytes/characters.
|
||||
# This is consistent with Gunicorn's "Body" class.
|
||||
for attr in ('read', 'readline', 'readlines'):
|
||||
target = getattr(self.stream, attr)
|
||||
setattr(self, attr, _make_stream_reader(target))
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return next(self.stream)
|
||||
|
||||
next = __next__
|
||||
|
@@ -11,7 +11,7 @@ def application(environ, start_response):
|
||||
|
||||
body += '}\n\n'
|
||||
|
||||
return [body]
|
||||
return [body.encode('utf-8')]
|
||||
|
||||
app = application
|
||||
|
||||
|
@@ -1,5 +1,15 @@
|
||||
import io
|
||||
import multiprocessing
|
||||
from wsgiref import simple_server
|
||||
|
||||
import requests
|
||||
|
||||
import falcon
|
||||
from falcon import request_helpers
|
||||
import falcon.testing as testing
|
||||
|
||||
SIZE_1_KB = 1024
|
||||
|
||||
|
||||
class TestRequestBody(testing.TestBase):
|
||||
|
||||
@@ -25,8 +35,17 @@ class TestRequestBody(testing.TestBase):
|
||||
stream.seek(0, 2)
|
||||
self.assertEquals(stream.tell(), 1)
|
||||
|
||||
def test_tiny_body_overflow(self):
|
||||
expected_body = '.'
|
||||
self.simulate_request('', body=expected_body)
|
||||
stream = self.resource.req.stream
|
||||
|
||||
# Read too many bytes; shouldn't block
|
||||
actual_body = stream.read(len(expected_body) + 1)
|
||||
self.assertEquals(actual_body, expected_body.encode('utf-8'))
|
||||
|
||||
def test_read_body(self):
|
||||
expected_body = testing.rand_string(2, 1 * 1024 * 1024)
|
||||
expected_body = testing.rand_string(SIZE_1_KB / 2, SIZE_1_KB)
|
||||
expected_len = len(expected_body)
|
||||
headers = {'Content-Length': str(expected_len)}
|
||||
|
||||
@@ -44,3 +63,85 @@ class TestRequestBody(testing.TestBase):
|
||||
self.assertEquals(stream.tell(), expected_len)
|
||||
|
||||
self.assertEquals(stream.tell(), expected_len)
|
||||
|
||||
def test_read_socket_body(self):
|
||||
expected_body = testing.rand_string(SIZE_1_KB / 2, SIZE_1_KB)
|
||||
|
||||
def server():
|
||||
class Echo(object):
|
||||
def on_post(self, req, resp):
|
||||
# wsgiref socket._fileobject blocks when len not given,
|
||||
# but Falcon is smarter than that. :D
|
||||
body = req.stream.read()
|
||||
resp.body = body
|
||||
|
||||
def on_put(self, req, resp):
|
||||
# wsgiref socket._fileobject blocks when len too long,
|
||||
# but Falcon should work around that for me.
|
||||
body = req.stream.read(req.content_length + 1)
|
||||
resp.body = body
|
||||
|
||||
api = falcon.API()
|
||||
api.add_route('/echo', Echo())
|
||||
|
||||
httpd = simple_server.make_server('127.0.0.1', 8989, api)
|
||||
httpd.serve_forever()
|
||||
|
||||
process = multiprocessing.Process(target=server)
|
||||
process.daemon = True
|
||||
process.start()
|
||||
|
||||
# Let it boot
|
||||
process.join(1)
|
||||
|
||||
url = 'http://127.0.0.1:8989/echo'
|
||||
resp = requests.post(url, data=expected_body)
|
||||
self.assertEquals(resp.text, expected_body)
|
||||
|
||||
resp = requests.put(url, data=expected_body)
|
||||
self.assertEquals(resp.text, expected_body)
|
||||
|
||||
process.terminate()
|
||||
|
||||
def test_body_stream_wrapper(self):
|
||||
data = testing.rand_string(SIZE_1_KB / 2, SIZE_1_KB)
|
||||
expected_body = data.encode('utf-8')
|
||||
expected_len = len(expected_body)
|
||||
|
||||
# NOTE(kgriffs): Append newline char to each line
|
||||
# to match readlines behavior
|
||||
expected_lines = [(line + '\n').encode('utf-8')
|
||||
for line in data.split('\n')]
|
||||
|
||||
# NOTE(kgriffs): Remove trailing newline to simulate
|
||||
# what readlines does
|
||||
expected_lines[-1] = expected_lines[-1][:-1]
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEquals(body.read(), expected_body)
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEquals(body.read(2), expected_body[0:2])
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEquals(body.read(expected_len + 1), expected_body)
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEquals(body.readline(), expected_lines[0])
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEquals(body.readlines(), expected_lines)
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
self.assertEquals(next(body), expected_lines[0])
|
||||
|
||||
stream = io.BytesIO(expected_body)
|
||||
body = request_helpers.Body(stream, expected_len)
|
||||
for i, line in enumerate(body):
|
||||
self.assertEquals(line, expected_lines[i])
|
||||
|
@@ -1,5 +1,6 @@
|
||||
coverage
|
||||
nose
|
||||
ordereddict
|
||||
requests
|
||||
six
|
||||
testtools
|
||||
|
Reference in New Issue
Block a user