Merge "Inline parse_request from cpython"
This commit is contained in:
commit
0c18b2d329
@ -16,8 +16,11 @@
|
||||
from eventlet import wsgi, websocket
|
||||
import six
|
||||
|
||||
from swift.common.swob import wsgi_quote, wsgi_unquote, \
|
||||
wsgi_quote_plus, wsgi_unquote_plus, wsgi_to_bytes, bytes_to_wsgi
|
||||
|
||||
if six.PY2:
|
||||
from eventlet.green import httplib as http_client
|
||||
else:
|
||||
from eventlet.green.http import client as http_client
|
||||
|
||||
|
||||
class SwiftHttpProtocol(wsgi.HttpProtocol):
|
||||
@ -62,44 +65,115 @@ class SwiftHttpProtocol(wsgi.HttpProtocol):
|
||||
return ''
|
||||
|
||||
def parse_request(self):
|
||||
# Need to track the bytes-on-the-wire for S3 signatures -- eventlet
|
||||
# would do it for us, but since we rewrite the path on py3, we need to
|
||||
# fix it ourselves later.
|
||||
self.__raw_path_info = None
|
||||
"""Parse a request (inlined from cpython@7e293984).
|
||||
|
||||
The request should be stored in self.raw_requestline; the results
|
||||
are in self.command, self.path, self.request_version and
|
||||
self.headers.
|
||||
|
||||
Return True for success, False for failure; on failure, any relevant
|
||||
error response has already been sent back.
|
||||
|
||||
"""
|
||||
self.command = None # set in case of error on the first line
|
||||
self.request_version = version = self.default_request_version
|
||||
self.close_connection = True
|
||||
requestline = self.raw_requestline
|
||||
if not six.PY2:
|
||||
# request lines *should* be ascii per the RFC, but historically
|
||||
# we've allowed (and even have func tests that use) arbitrary
|
||||
# bytes. This breaks on py3 (see https://bugs.python.org/issue33973
|
||||
# ) but the work-around is simple: munge the request line to be
|
||||
# properly quoted.
|
||||
if self.raw_requestline.count(b' ') >= 2:
|
||||
parts = self.raw_requestline.split(b' ', 2)
|
||||
path, q, query = parts[1].partition(b'?')
|
||||
self.__raw_path_info = path
|
||||
# unquote first, so we don't over-quote something
|
||||
# that was *correctly* quoted
|
||||
path = wsgi_to_bytes(wsgi_quote(wsgi_unquote(
|
||||
bytes_to_wsgi(path))))
|
||||
query = b'&'.join(
|
||||
sep.join([
|
||||
wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus(
|
||||
bytes_to_wsgi(key)))),
|
||||
wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus(
|
||||
bytes_to_wsgi(val))))
|
||||
])
|
||||
for part in query.split(b'&')
|
||||
for key, sep, val in (part.partition(b'='), ))
|
||||
parts[1] = path + q + query
|
||||
self.raw_requestline = b' '.join(parts)
|
||||
# else, mangled protocol, most likely; let base class deal with it
|
||||
return wsgi.HttpProtocol.parse_request(self)
|
||||
requestline = requestline.decode('iso-8859-1')
|
||||
requestline = requestline.rstrip('\r\n')
|
||||
self.requestline = requestline
|
||||
# Split off \x20 explicitly (see https://bugs.python.org/issue33973)
|
||||
words = requestline.split(' ')
|
||||
if len(words) == 0:
|
||||
return False
|
||||
|
||||
if len(words) >= 3: # Enough to determine protocol version
|
||||
version = words[-1]
|
||||
try:
|
||||
if not version.startswith('HTTP/'):
|
||||
raise ValueError
|
||||
base_version_number = version.split('/', 1)[1]
|
||||
version_number = base_version_number.split(".")
|
||||
# RFC 2145 section 3.1 says there can be only one "." and
|
||||
# - major and minor numbers MUST be treated as
|
||||
# separate integers;
|
||||
# - HTTP/2.4 is a lower version than HTTP/2.13, which in
|
||||
# turn is lower than HTTP/12.3;
|
||||
# - Leading zeros MUST be ignored by recipients.
|
||||
if len(version_number) != 2:
|
||||
raise ValueError
|
||||
version_number = int(version_number[0]), int(version_number[1])
|
||||
except (ValueError, IndexError):
|
||||
self.send_error(
|
||||
400,
|
||||
"Bad request version (%r)" % version)
|
||||
return False
|
||||
if version_number >= (1, 1) and \
|
||||
self.protocol_version >= "HTTP/1.1":
|
||||
self.close_connection = False
|
||||
if version_number >= (2, 0):
|
||||
self.send_error(
|
||||
505,
|
||||
"Invalid HTTP version (%s)" % base_version_number)
|
||||
return False
|
||||
self.request_version = version
|
||||
|
||||
if not 2 <= len(words) <= 3:
|
||||
self.send_error(
|
||||
400,
|
||||
"Bad request syntax (%r)" % requestline)
|
||||
return False
|
||||
command, path = words[:2]
|
||||
if len(words) == 2:
|
||||
self.close_connection = True
|
||||
if command != 'GET':
|
||||
self.send_error(
|
||||
400,
|
||||
"Bad HTTP/0.9 request type (%r)" % command)
|
||||
return False
|
||||
self.command, self.path = command, path
|
||||
|
||||
# Examine the headers and look for a Connection directive.
|
||||
if six.PY2:
|
||||
self.headers = self.MessageClass(self.rfile, 0)
|
||||
else:
|
||||
try:
|
||||
self.headers = http_client.parse_headers(
|
||||
self.rfile,
|
||||
_class=self.MessageClass)
|
||||
except http_client.LineTooLong as err:
|
||||
self.send_error(
|
||||
431,
|
||||
"Line too long",
|
||||
str(err))
|
||||
return False
|
||||
except http_client.HTTPException as err:
|
||||
self.send_error(
|
||||
431,
|
||||
"Too many headers",
|
||||
str(err)
|
||||
)
|
||||
return False
|
||||
|
||||
conntype = self.headers.get('Connection', "")
|
||||
if conntype.lower() == 'close':
|
||||
self.close_connection = True
|
||||
elif (conntype.lower() == 'keep-alive' and
|
||||
self.protocol_version >= "HTTP/1.1"):
|
||||
self.close_connection = False
|
||||
# Examine the headers and look for an Expect directive
|
||||
expect = self.headers.get('Expect', "")
|
||||
if (expect.lower() == "100-continue" and
|
||||
self.protocol_version >= "HTTP/1.1" and
|
||||
self.request_version >= "HTTP/1.1"):
|
||||
if not self.handle_expect_100():
|
||||
return False
|
||||
return True
|
||||
|
||||
if not six.PY2:
|
||||
def get_environ(self, *args, **kwargs):
|
||||
environ = wsgi.HttpProtocol.get_environ(self, *args, **kwargs)
|
||||
environ['RAW_PATH_INFO'] = bytes_to_wsgi(
|
||||
self.__raw_path_info)
|
||||
header_payload = self.headers.get_payload()
|
||||
if isinstance(header_payload, list) and len(header_payload) == 1:
|
||||
header_payload = header_payload[0].get_payload()
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
from argparse import Namespace
|
||||
from io import BytesIO
|
||||
import json
|
||||
import mock
|
||||
import types
|
||||
import unittest
|
||||
@ -81,36 +82,14 @@ class TestSwiftHttpProtocol(unittest.TestCase):
|
||||
], proto_obj.send_error.mock_calls)
|
||||
self.assertEqual(('a', '123'), proto_obj.client_address)
|
||||
|
||||
def test_request_line_cleanup(self):
|
||||
def do_test(line_from_socket, expected_line=None):
|
||||
if expected_line is None:
|
||||
expected_line = line_from_socket
|
||||
|
||||
proto_obj = self._proto_obj()
|
||||
proto_obj.raw_requestline = line_from_socket
|
||||
with mock.patch('swift.common.http_protocol.wsgi.HttpProtocol') \
|
||||
as mock_super:
|
||||
proto_obj.parse_request()
|
||||
|
||||
self.assertEqual([mock.call.parse_request(proto_obj)],
|
||||
mock_super.mock_calls)
|
||||
self.assertEqual(proto_obj.raw_requestline, expected_line)
|
||||
|
||||
do_test(b'GET / HTTP/1.1')
|
||||
do_test(b'GET /%FF HTTP/1.1')
|
||||
|
||||
if not six.PY2:
|
||||
do_test(b'GET /\xff HTTP/1.1', b'GET /%FF HTTP/1.1')
|
||||
do_test(b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0',
|
||||
b'PUT /Here%20Is%20A%20SnowMan%3A%E2%98%83 HTTP/1.0')
|
||||
do_test(
|
||||
b'POST /?and%20it=fixes+params&'
|
||||
b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1',
|
||||
b'POST /?and+it=fixes+params&PALMTREE=%F0%9F%8C%B4 HTTP/1.1')
|
||||
def test_bad_request_line(self):
|
||||
proto_obj = self._proto_obj()
|
||||
proto_obj.raw_requestline = b'None //'
|
||||
self.assertEqual(False, proto_obj.parse_request())
|
||||
|
||||
|
||||
class ProtocolTest(unittest.TestCase):
|
||||
def _run_bytes_through_protocol(self, bytes_from_client):
|
||||
def _run_bytes_through_protocol(self, bytes_from_client, app=None):
|
||||
rfile = BytesIO(bytes_from_client)
|
||||
wfile = BytesIO()
|
||||
|
||||
@ -153,7 +132,7 @@ class ProtocolTest(unittest.TestCase):
|
||||
with mock.patch.object(wfile, 'close', lambda: None), \
|
||||
mock.patch.object(rfile, 'close', lambda: None):
|
||||
eventlet.wsgi.server(
|
||||
fake_listen_socket, self.app,
|
||||
fake_listen_socket, app or self.app,
|
||||
protocol=self.protocol_class,
|
||||
custom_pool=FakePool(),
|
||||
log_output=False, # quiet the test run
|
||||
@ -170,37 +149,118 @@ class TestSwiftHttpProtocolSomeMore(ProtocolTest):
|
||||
return [swob.wsgi_to_bytes(env['RAW_PATH_INFO'])]
|
||||
|
||||
def test_simple(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
self.assertEqual(lines[-1], b'/someurl')
|
||||
|
||||
def test_quoted(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
b"GET /some%fFpath%D8%AA HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
self.assertEqual(lines[-1], b'/some%fFpath%D8%AA')
|
||||
|
||||
def test_messy(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
b"GET /oh\xffboy%what$now%E2%80%bd HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[-1], b'/oh\xffboy%what$now%E2%80%bd')
|
||||
|
||||
def test_bad_request(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
b"ONLY-METHOD\r\n"
|
||||
b"Server: example.com\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(
|
||||
lines[0], b"HTTP/1.1 400 Bad request syntax ('ONLY-METHOD')")
|
||||
self.assertIn(b"Bad request syntax or unsupported method.", lines[-1])
|
||||
|
||||
def test_leading_slashes(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
b"GET ///some-leading-slashes HTTP/1.0\r\n"
|
||||
b"User-Agent: blah blah blah\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[-1], b'///some-leading-slashes')
|
||||
|
||||
def test_request_lines(self):
|
||||
def app(env, start_response):
|
||||
start_response("200 OK", [])
|
||||
if six.PY2:
|
||||
return [json.dumps({
|
||||
'RAW_PATH_INFO': env['RAW_PATH_INFO'].decode('latin1'),
|
||||
'QUERY_STRING': (None if 'QUERY_STRING' not in env else
|
||||
env['QUERY_STRING'].decode('latin1')),
|
||||
}).encode('ascii')]
|
||||
return [json.dumps({
|
||||
'RAW_PATH_INFO': env['RAW_PATH_INFO'],
|
||||
'QUERY_STRING': env.get('QUERY_STRING'),
|
||||
}).encode('ascii')]
|
||||
|
||||
def do_test(request_line, expected):
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
request_line + b'\r\n\r\n',
|
||||
app,
|
||||
)
|
||||
print(bytes_out)
|
||||
resp_body = bytes_out.partition(b'\r\n\r\n')[2]
|
||||
self.assertEqual(json.loads(resp_body), expected)
|
||||
|
||||
do_test(b'GET / HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'/',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
do_test(b'GET /%FF HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'/%FF',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
|
||||
do_test(b'GET /\xff HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'/\xff',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
do_test(b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0', {
|
||||
'RAW_PATH_INFO': u'/Here%20Is%20A%20SnowMan:\xe2\x98\x83',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
do_test(
|
||||
b'POST /?and%20it=does+nothing+to+params&'
|
||||
b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'/',
|
||||
'QUERY_STRING': (u'and%20it=does+nothing+to+params'
|
||||
u'&PALMTREE=\xf0%9f\x8c%b4'),
|
||||
}
|
||||
)
|
||||
do_test(b'GET // HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'//',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
do_test(b'GET //bar HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'//bar',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
do_test(b'GET //////baz HTTP/1.1', {
|
||||
'RAW_PATH_INFO': u'//////baz',
|
||||
'QUERY_STRING': None,
|
||||
})
|
||||
|
||||
|
||||
class TestProxyProtocol(ProtocolTest):
|
||||
protocol_class = http_protocol.SwiftHttpProxiedProtocol
|
||||
@ -222,12 +282,12 @@ class TestProxyProtocol(ProtocolTest):
|
||||
return [body.encode("utf-8")]
|
||||
|
||||
def test_request_with_proxy(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 4433\r\n"
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
@ -238,12 +298,12 @@ class TestProxyProtocol(ProtocolTest):
|
||||
])
|
||||
|
||||
def test_request_with_proxy_https(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
@ -254,7 +314,7 @@ class TestProxyProtocol(ProtocolTest):
|
||||
])
|
||||
|
||||
def test_multiple_requests_with_proxy(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
|
||||
b"GET /someurl HTTP/1.1\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
@ -263,7 +323,7 @@ class TestProxyProtocol(ProtocolTest):
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"Connection: close\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = bytes_out.split(b"\r\n")
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
@ -277,12 +337,12 @@ class TestProxyProtocol(ProtocolTest):
|
||||
self.assertEqual(addr_lines, [b"https is on (scheme https)"] * 2)
|
||||
|
||||
def test_missing_proxy_line(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
# whoops, no PROXY line here
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
))
|
||||
)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertIn(b"400 Invalid PROXY line", lines[0])
|
||||
@ -303,12 +363,12 @@ class TestProxyProtocol(ProtocolTest):
|
||||
for unknown_line in [b'PROXY UNKNOWN', # minimal valid unknown
|
||||
b'PROXY UNKNOWNblahblah', # also valid
|
||||
b'PROXY UNKNOWN a b c d']:
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
bytes_out = self._run_bytes_through_protocol(
|
||||
unknown_line + (b"\r\n"
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n")
|
||||
))
|
||||
)
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertIn(b"200 OK", lines[0])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user