Inline parse_request from cpython

Applied deltas:

- Fix http.client references
- Inline HTTPStatus codes
- Address request line splitting (https://bugs.python.org/issue33973)
- Special-case py2 header-parsing
- Address multiple leading slashes in request path
  (https://github.com/python/cpython/issues/99220)

Closes-Bug: #1999278
Change-Id: Iae28097668213aa0734837ff21aef83251167d19
(cherry picked from commit 884f5538f8)
This commit is contained in:
Tim Burke 2022-12-06 11:21:37 -08:00
parent 76cdb4d3e9
commit 28d6cf1ba8
2 changed files with 212 additions and 78 deletions

View File

@ -16,8 +16,11 @@
from eventlet import wsgi, websocket
import six
from swift.common.swob import wsgi_quote, wsgi_unquote, \
wsgi_quote_plus, wsgi_unquote_plus, wsgi_to_bytes, bytes_to_wsgi
if six.PY2:
from eventlet.green import httplib as http_client
else:
from eventlet.green.http import client as http_client
class SwiftHttpProtocol(wsgi.HttpProtocol):
@ -62,44 +65,115 @@ class SwiftHttpProtocol(wsgi.HttpProtocol):
return ''
def parse_request(self):
# NOTE(review): this listing looks like a flattened diff -- a wrapper-style
# body (the __raw_path_info bookkeeping and the `if not six.PY2:` request-line
# rewrite that ends by delegating to wsgi.HttpProtocol.parse_request) is
# interleaved with an inlined cpython-style body, and indentation is gone.
# Confirm against the real file before relying on statement order here.
# Need to track the bytes-on-the-wire for S3 signatures -- eventlet
# would do it for us, but since we rewrite the path on py3, we need to
# fix it ourselves later.
self.__raw_path_info = None
"""Parse a request (inlined from cpython@7e293984).
The request should be stored in self.raw_requestline; the results
are in self.command, self.path, self.request_version and
self.headers.
Return True for success, False for failure; on failure, any relevant
error response has already been sent back.
"""
self.command = None # set in case of error on the first line
self.request_version = version = self.default_request_version
# Default to closing the connection; only the checks below re-enable
# keep-alive.
self.close_connection = True
requestline = self.raw_requestline
if not six.PY2:
# request lines *should* be ascii per the RFC, but historically
# we've allowed (and even have func tests that use) arbitrary
# bytes. This breaks on py3 (see https://bugs.python.org/issue33973
# ) but the work-around is simple: munge the request line to be
# properly quoted.
if self.raw_requestline.count(b' ') >= 2:
parts = self.raw_requestline.split(b' ', 2)
path, q, query = parts[1].partition(b'?')
# Remember the path exactly as received, before any re-quoting.
self.__raw_path_info = path
# unquote first, so we don't over-quote something
# that was *correctly* quoted
path = wsgi_to_bytes(wsgi_quote(wsgi_unquote(
bytes_to_wsgi(path))))
query = b'&'.join(
sep.join([
wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus(
bytes_to_wsgi(key)))),
wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus(
bytes_to_wsgi(val))))
])
for part in query.split(b'&')
for key, sep, val in (part.partition(b'='), ))
parts[1] = path + q + query
self.raw_requestline = b' '.join(parts)
# else, mangled protocol, most likely; let base class deal with it
return wsgi.HttpProtocol.parse_request(self)
# iso-8859-1 maps every byte to a code point, so this decode cannot fail
# even for non-ASCII request lines.
requestline = requestline.decode('iso-8859-1')
requestline = requestline.rstrip('\r\n')
self.requestline = requestline
# Split off \x20 explicitly (see https://bugs.python.org/issue33973)
words = requestline.split(' ')
# NOTE: str.split(' ') never returns an empty list (''.split(' ') == ['']),
# so this guard cannot fire; an empty request line is instead rejected by
# the word-count check further down.
if len(words) == 0:
return False
if len(words) >= 3: # Enough to determine protocol version
version = words[-1]
try:
if not version.startswith('HTTP/'):
raise ValueError
base_version_number = version.split('/', 1)[1]
version_number = base_version_number.split(".")
# RFC 2145 section 3.1 says there can be only one "." and
# - major and minor numbers MUST be treated as
# separate integers;
# - HTTP/2.4 is a lower version than HTTP/2.13, which in
# turn is lower than HTTP/12.3;
# - Leading zeros MUST be ignored by recipients.
if len(version_number) != 2:
raise ValueError
version_number = int(version_number[0]), int(version_number[1])
except (ValueError, IndexError):
self.send_error(
400,
"Bad request version (%r)" % version)
return False
# Keep-alive only when the client's version (as an int tuple) and our own
# protocol_version (compared as a string) are both at least HTTP/1.1.
if version_number >= (1, 1) and \
self.protocol_version >= "HTTP/1.1":
self.close_connection = False
# Versions 2.0 and up are refused with a 505.
if version_number >= (2, 0):
self.send_error(
505,
"Invalid HTTP version (%s)" % base_version_number)
return False
self.request_version = version
# Anything other than "METHOD PATH" or "METHOD PATH VERSION" is a 400; this
# also covers request lines with extra spaces (more than three words).
if not 2 <= len(words) <= 3:
self.send_error(
400,
"Bad request syntax (%r)" % requestline)
return False
command, path = words[:2]
# Two words means no version token (HTTP/0.9 style): the connection is
# always closed afterwards and only GET is accepted.
if len(words) == 2:
self.close_connection = True
if command != 'GET':
self.send_error(
400,
"Bad HTTP/0.9 request type (%r)" % command)
return False
self.command, self.path = command, path
# Examine the headers and look for a Connection directive.
# py2 is special-cased: the message class is built straight from rfile
# (presumably parsing the headers itself -- confirm against MessageClass).
if six.PY2:
self.headers = self.MessageClass(self.rfile, 0)
else:
try:
self.headers = http_client.parse_headers(
self.rfile,
_class=self.MessageClass)
# An over-long header line and any other HTTPException from header parsing
# are both reported as 431.
except http_client.LineTooLong as err:
self.send_error(
431,
"Line too long",
str(err))
return False
except http_client.HTTPException as err:
self.send_error(
431,
"Too many headers",
str(err)
)
return False
# An explicit Connection header overrides the version-based default above.
conntype = self.headers.get('Connection', "")
if conntype.lower() == 'close':
self.close_connection = True
elif (conntype.lower() == 'keep-alive' and
self.protocol_version >= "HTTP/1.1"):
self.close_connection = False
# Examine the headers and look for an Expect directive
# 100-continue is honoured only when both our protocol_version and the
# request's version string compare >= "HTTP/1.1".
expect = self.headers.get('Expect', "")
if (expect.lower() == "100-continue" and
self.protocol_version >= "HTTP/1.1" and
self.request_version >= "HTTP/1.1"):
if not self.handle_expect_100():
return False
return True
if not six.PY2:
def get_environ(self, *args, **kwargs):
environ = wsgi.HttpProtocol.get_environ(self, *args, **kwargs)
environ['RAW_PATH_INFO'] = bytes_to_wsgi(
self.__raw_path_info)
header_payload = self.headers.get_payload()
if isinstance(header_payload, list) and len(header_payload) == 1:
header_payload = header_payload[0].get_payload()

View File

@ -15,6 +15,7 @@
from argparse import Namespace
from io import BytesIO
import json
import mock
import types
import unittest
@ -81,36 +82,14 @@ class TestSwiftHttpProtocol(unittest.TestCase):
], proto_obj.send_error.mock_calls)
self.assertEqual(('a', '123'), proto_obj.client_address)
def test_request_line_cleanup(self):
    """Check how parse_request rewrites the raw request line.

    Each case feeds one request line to a protocol object while
    ``wsgi.HttpProtocol`` is patched out, then verifies that the patched
    base ``parse_request`` was the only call made and that
    ``raw_requestline`` holds the expected bytes.
    """
    def check(raw_line, want=None):
        # No explicit expectation means the line must pass through as-is.
        want = raw_line if want is None else want

        proto = self._proto_obj()
        proto.raw_requestline = raw_line
        patch_target = 'swift.common.http_protocol.wsgi.HttpProtocol'
        with mock.patch(patch_target) as base_cls:
            proto.parse_request()

        self.assertEqual([mock.call.parse_request(proto)],
                         base_cls.mock_calls)
        self.assertEqual(proto.raw_requestline, want)

    # Already well-formed lines are expected back unchanged on py2 and py3.
    for clean_line in (b'GET / HTTP/1.1', b'GET /%FF HTTP/1.1'):
        check(clean_line)

    if six.PY2:
        return

    # On py3 raw bytes are expected to come back percent-encoded, with
    # query parameters re-quoted as well.
    py3_cases = (
        (b'GET /\xff HTTP/1.1',
         b'GET /%FF HTTP/1.1'),
        (b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0',
         b'PUT /Here%20Is%20A%20SnowMan%3A%E2%98%83 HTTP/1.0'),
        (b'POST /?and%20it=fixes+params&'
         b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1',
         b'POST /?and+it=fixes+params&PALMTREE=%F0%9F%8C%B4 HTTP/1.1'),
    )
    for raw_line, cleaned in py3_cases:
        check(raw_line, cleaned)
def test_bad_request_line(self):
    """parse_request must report failure for the line ``None //``."""
    proto = self._proto_obj()
    proto.raw_requestline = b'None //'
    parsed_ok = proto.parse_request()
    self.assertEqual(False, parsed_ok)
class ProtocolTest(unittest.TestCase):
def _run_bytes_through_protocol(self, bytes_from_client):
def _run_bytes_through_protocol(self, bytes_from_client, app=None):
rfile = BytesIO(bytes_from_client)
wfile = BytesIO()
@ -153,7 +132,7 @@ class ProtocolTest(unittest.TestCase):
with mock.patch.object(wfile, 'close', lambda: None), \
mock.patch.object(rfile, 'close', lambda: None):
eventlet.wsgi.server(
fake_listen_socket, self.app,
fake_listen_socket, app or self.app,
protocol=self.protocol_class,
custom_pool=FakePool(),
log_output=False, # quiet the test run
@ -170,37 +149,118 @@ class TestSwiftHttpProtocolSomeMore(ProtocolTest):
return [swob.wsgi_to_bytes(env['RAW_PATH_INFO'])]
def test_simple(self):
# A plain ASCII path is expected back unchanged as the last non-empty line
# of the response.
# NOTE(review): the doubled call/closing lines below look like the two sides
# of a diff (redundant parentheses dropped) -- confirm against the real file.
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
))
)
# Drop empty segments so lines[-1] is the final non-empty line of output.
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
self.assertEqual(lines[-1], b'/someurl')
def test_quoted(self):
# Percent-escapes are expected back exactly as sent, including the mixed
# upper/lower hex case.
# NOTE(review): the doubled call/closing lines below look like the two sides
# of a diff (redundant parentheses dropped) -- confirm against the real file.
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
b"GET /some%fFpath%D8%AA HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
))
)
# Drop empty segments so lines[-1] is the final non-empty line of output.
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
self.assertEqual(lines[-1], b'/some%fFpath%D8%AA')
def test_messy(self):
# A raw non-ASCII byte, a stray '%' and valid escapes in one path: the
# whole path is expected back byte-for-byte.
# NOTE(review): the doubled call/closing lines below look like the two sides
# of a diff (redundant parentheses dropped) -- confirm against the real file.
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
b"GET /oh\xffboy%what$now%E2%80%bd HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
))
)
# Drop empty segments so lines[-1] is the final non-empty line of output.
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[-1], b'/oh\xffboy%what$now%E2%80%bd')
def test_bad_request(self):
    """A request line consisting of a lone method gets a 400 response."""
    request = (
        b"ONLY-METHOD\r\n"
        b"Server: example.com\r\n"
        b"\r\n"
    )
    response = self._run_bytes_through_protocol(request)
    # Ignore empty segments so the first/last entries are real content.
    non_empty = [chunk for chunk in response.split(b"\r\n") if chunk]
    status_line, last_line = non_empty[0], non_empty[-1]
    self.assertEqual(
        status_line, b"HTTP/1.1 400 Bad request syntax ('ONLY-METHOD')")
    self.assertIn(b"Bad request syntax or unsupported method.", last_line)
def test_leading_slashes(self):
    """Repeated leading slashes in the path must come back intact."""
    request = (
        b"GET ///some-leading-slashes HTTP/1.0\r\n"
        b"User-Agent: blah blah blah\r\n"
        b"\r\n"
    )
    response = self._run_bytes_through_protocol(request)
    # Ignore empty segments so the last entry is the final line of output.
    non_empty = [chunk for chunk in response.split(b"\r\n") if chunk]
    self.assertEqual(non_empty[-1], b'///some-leading-slashes')
def test_request_lines(self):
    """Check RAW_PATH_INFO and QUERY_STRING for assorted request lines.

    Every request line is run through the protocol with a WSGI app that
    reports the two environ values as JSON. Both are expected to match
    what was sent on the wire -- raw non-ASCII bytes, untouched query
    quoting and repeated leading slashes included -- and QUERY_STRING is
    expected to be reported as None when the request has no query.
    """
    def app(env, start_response):
        start_response("200 OK", [])
        raw_path = env['RAW_PATH_INFO']
        query = env.get('QUERY_STRING')
        if six.PY2:
            # On py2 these are decoded as latin1 so that json.dumps can
            # represent arbitrary bytes.
            raw_path = raw_path.decode('latin1')
            if query is not None:
                query = query.decode('latin1')
        return [json.dumps({
            'RAW_PATH_INFO': raw_path,
            'QUERY_STRING': query,
        }).encode('ascii')]

    # (request line, expected RAW_PATH_INFO, expected QUERY_STRING)
    cases = [
        (b'GET / HTTP/1.1', u'/', None),
        (b'GET /%FF HTTP/1.1', u'/%FF', None),
        (b'GET /\xff HTTP/1.1', u'/\xff', None),
        (b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0',
         u'/Here%20Is%20A%20SnowMan:\xe2\x98\x83', None),
        (b'POST /?and%20it=does+nothing+to+params&'
         b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1',
         u'/',
         u'and%20it=does+nothing+to+params'
         u'&PALMTREE=\xf0%9f\x8c%b4'),
        (b'GET // HTTP/1.1', u'//', None),
        (b'GET //bar HTTP/1.1', u'//bar', None),
        (b'GET //////baz HTTP/1.1', u'//////baz', None),
    ]
    for request_line, raw_path, query in cases:
        bytes_out = self._run_bytes_through_protocol(
            request_line + b'\r\n\r\n',
            app,
        )
        # The body is everything after the blank line ending the headers.
        resp_body = bytes_out.partition(b'\r\n\r\n')[2]
        # Put the request and raw response in the failure message rather
        # than printing them unconditionally on every run.
        self.assertEqual(
            json.loads(resp_body),
            {'RAW_PATH_INFO': raw_path, 'QUERY_STRING': query},
            'request line %r produced response %r' % (
                request_line, bytes_out))
class TestProxyProtocol(ProtocolTest):
protocol_class = http_protocol.SwiftHttpProxiedProtocol
@ -222,12 +282,12 @@ class TestProxyProtocol(ProtocolTest):
return [body.encode("utf-8")]
def test_request_with_proxy(self):
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 4433\r\n"
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
))
)
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
@ -238,12 +298,12 @@ class TestProxyProtocol(ProtocolTest):
])
def test_request_with_proxy_https(self):
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
))
)
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
@ -254,7 +314,7 @@ class TestProxyProtocol(ProtocolTest):
])
def test_multiple_requests_with_proxy(self):
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
b"GET /someurl HTTP/1.1\r\n"
b"User-Agent: something or other\r\n"
@ -263,7 +323,7 @@ class TestProxyProtocol(ProtocolTest):
b"User-Agent: something or other\r\n"
b"Connection: close\r\n"
b"\r\n"
))
)
lines = bytes_out.split(b"\r\n")
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
@ -277,12 +337,12 @@ class TestProxyProtocol(ProtocolTest):
self.assertEqual(addr_lines, [b"https is on (scheme https)"] * 2)
def test_missing_proxy_line(self):
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
# whoops, no PROXY line here
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n"
))
)
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertIn(b"400 Invalid PROXY line", lines[0])
@ -303,12 +363,12 @@ class TestProxyProtocol(ProtocolTest):
for unknown_line in [b'PROXY UNKNOWN', # minimal valid unknown
b'PROXY UNKNOWNblahblah', # also valid
b'PROXY UNKNOWN a b c d']:
bytes_out = self._run_bytes_through_protocol((
bytes_out = self._run_bytes_through_protocol(
unknown_line + (b"\r\n"
b"GET /someurl HTTP/1.0\r\n"
b"User-Agent: something or other\r\n"
b"\r\n")
))
)
lines = [l for l in bytes_out.split(b"\r\n") if l]
self.assertIn(b"200 OK", lines[0])