diff --git a/eventlet/wsgi.py b/eventlet/wsgi.py index b77797f..85fe4c9 100644 --- a/eventlet/wsgi.py +++ b/eventlet/wsgi.py @@ -52,7 +52,8 @@ def format_date_time(timestamp): class Input(object): - def __init__(self, rfile, content_length, wfile=None, wfile_line=None): + def __init__(self, rfile, content_length, wfile=None, wfile_line=None, + chunked_input=False): self.rfile = rfile if content_length is not None: content_length = int(content_length) @@ -62,6 +63,8 @@ class Input(object): self.wfile_line = wfile_line self.position = 0 + self.chunked_input = chunked_input + self.chunk_length = -1 def _do_read(self, reader, length=None): if self.wfile is not None: @@ -80,7 +83,38 @@ class Input(object): self.position += len(read) return read + def _chunked_read(self, rfile, length=None): + if self.wfile is not None: + ## 100 Continue + self.wfile.write(self.wfile_line) + self.wfile = None + self.wfile_line = None + + response = [] + if length is None: + if self.chunk_length > self.position: + response.append(rfile.read(self.chunk_length - self.position)) + while self.chunk_length != 0: + self.chunk_length = int(rfile.readline(), 16) + response.append(rfile.read(self.chunk_length)) + rfile.readline() + else: + while length > 0 and self.chunk_length != 0: + if self.chunk_length > self.position: + response.append(rfile.read( + min(self.chunk_length - self.position, length))) + length -= len(response[-1]) + self.position += len(response[-1]) + if self.chunk_length == self.position: + rfile.readline() + else: + self.chunk_length = int(rfile.readline(), 16) + self.position = 0 + return ''.join(response) + def read(self, length=None): + if self.chunked_input: + return self._chunked_read(self.rfile, length) return self._do_read(self.rfile.read, length) def readline(self): @@ -317,8 +351,10 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler): else: wfile = None wfile_line = None + chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked' env['wsgi.input'] = env['eventlet.input'] = Input( - self.rfile, length, wfile=wfile, wfile_line=wfile_line) + self.rfile, length, wfile=wfile, wfile_line=wfile_line, + chunked_input=chunked) return env diff --git a/greentest/wsgi_test.py b/greentest/wsgi_test.py index 7845a3c..5a3ce78 100644 --- a/greentest/wsgi_test.py +++ b/greentest/wsgi_test.py @@ -57,6 +57,14 @@ def big_chunks(env, start_response): for x in range(10): yield line +def chunked_post(env, start_response): + start_response('200 OK', [('Content-type', 'text/plain')]) + if env['PATH_INFO'] == '/a': + return [env['wsgi.input'].read()] + elif env['PATH_INFO'] == '/b': + return [x for x in iter(lambda: env['wsgi.input'].read(4096), '')] + elif env['PATH_INFO'] == '/c': + return [x for x in iter(lambda: env['wsgi.input'].read(1), '')] class Site(object): def __init__(self): @@ -291,6 +299,34 @@ class TestHttpd(TestCase): res = httpc.get("https://localhost:4202/foo") self.assertEquals(res, '') + def test_014_chunked_post(self): + self.site.application = chunked_post + sock = api.connect_tcp(('127.0.0.1', 12346)) + fd = sock.makeGreenFile() + fd.write('PUT /a HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n' + 'Transfer-Encoding: chunked\r\n\r\n' + '2\r\noh\r\n4\r\n hai\r\n0\r\n\r\n') + fd.readuntil('\r\n\r\n') + response = fd.read() + self.assert_(response == 'oh hai', 'invalid response %s' % response) + + sock = api.connect_tcp(('127.0.0.1', 12346)) + fd = sock.makeGreenFile() + fd.write('PUT /b HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n' + 'Transfer-Encoding: chunked\r\n\r\n' + '2\r\noh\r\n4\r\n hai\r\n0\r\n\r\n') + fd.readuntil('\r\n\r\n') + response = fd.read() + self.assert_(response == 'oh hai', 'invalid response %s' % response) + + sock = api.connect_tcp(('127.0.0.1', 12346)) + fd = sock.makeGreenFile() + fd.write('PUT /c HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n' + 'Transfer-Encoding: chunked\r\n\r\n' + '2\r\noh\r\n4\r\n hai\r\n0\r\n\r\n') + fd.readuntil('\r\n\r\n') + response = fd.read(8192) + self.assert_(response == 'oh hai', 'invalid response %s' % response) if __name__ == '__main__': main()