Merge pull request #679 from kgriffs/polish_access_route

fix(Request): Improve access_route perf and error handling
This commit is contained in:
John Vrbanac
2016-01-04 18:52:35 -06:00
4 changed files with 50 additions and 22 deletions

View File

@@ -580,17 +580,28 @@ class Request(object):
@property
def access_route(self):
if self._cached_access_route is None:
access_route = []
# NOTE(kgriffs): Try different headers in order of
# preference; if none are found, fall back to REMOTE_ADDR.
#
# If one of these headers is present, but its value is
# malformed such that we end up with an empty list, or
# a non-empty list containing malformed values, go ahead
# and return the results as-is. The alternative would be
# to fall back to another header or to REMOTE_ADDR, but
# that only masks the problem; the operator needs to be
# aware that an upstream proxy is malfunctioning.
if 'HTTP_FORWARDED' in self.env:
access_route = self._parse_rfc_forwarded()
if not access_route and 'HTTP_X_FORWARDED_FOR' in self.env:
access_route = [ip.strip() for ip in
self.env['HTTP_X_FORWARDED_FOR'].split(',')]
if not access_route and 'HTTP_X_REAL_IP' in self.env:
access_route = [self.env['HTTP_X_REAL_IP']]
if not access_route and 'REMOTE_ADDR' in self.env:
access_route = [self.env['REMOTE_ADDR']]
self._cached_access_route = access_route
self._cached_access_route = self._parse_rfc_forwarded()
elif 'HTTP_X_FORWARDED_FOR' in self.env:
addresses = self.env['HTTP_X_FORWARDED_FOR'].split(',')
self._cached_access_route = [ip.strip() for ip in addresses]
elif 'HTTP_X_REAL_IP' in self.env:
self._cached_access_route = [self.env['HTTP_X_REAL_IP']]
elif 'REMOTE_ADDR' in self.env:
self._cached_access_route = [self.env['REMOTE_ADDR']]
else:
self._cached_access_route = []
return self._cached_access_route
@@ -1119,18 +1130,25 @@ class Request(object):
Returns:
list: addresses derived from "for" parameters.
"""
addr = []
for forwarded in self.env['HTTP_FORWARDED'].split(','):
for param in forwarded.split(';'):
param = param.strip().split('=', 1)
if len(param) == 1:
# PERF(kgriffs): partition() is faster than split().
key, _, val = param.strip().partition('=')
if not val:
# NOTE(kgriffs): Either the '=' separator was not found, or
# it was found but the value was missing.
continue
key, val = param
if key.lower() != 'for':
# we only want for params
# We only want "for" params
continue
host, _ = parse_host(unquote_string(val))
addr.append(host)
return addr

View File

@@ -386,14 +386,15 @@ def unquote_string(quoted):
Raises:
TypeError: `quoted` was not a ``str``.
"""
tmp_quoted = quoted.strip()
if len(tmp_quoted) < 2:
if len(quoted) < 2:
return quoted
elif tmp_quoted[0] != '"' or tmp_quoted[-1] != '"':
elif quoted[0] != '"' or quoted[-1] != '"':
# Return the original string unchanged to avoid side effects.
return quoted
tmp_quoted = tmp_quoted[1:-1]
tmp_quoted = quoted[1:-1]
# PERF(philiptzou): Most header strings don't contain a "quoted-pair", which
# is defined by RFC 7230. We use this little trick (a quick string search) to
# speed up string parsing by skipping unnecessary processing when possible.

View File

@@ -23,12 +23,13 @@ class TestAccessRoute(testing.TestBase):
headers={
'Forwarded': ('for=192.0.2.43,for=,'
'for="[2001:db8:cafe::17]:555",'
'for=x,'
'for="unknown", by=_hidden,for="\\"\\\\",'
'for="_don\\\"t_\\try_this\\\\at_home_\\42",'
'for="198\\.51\\.100\\.17\\:1236";'
'proto=https;host=example.com')
}))
compares = ['192.0.2.43', '', '2001:db8:cafe::17',
compares = ['192.0.2.43', '2001:db8:cafe::17', 'x',
'unknown', '"\\', '_don"t_try_this\\at_home_42',
'198.51.100.17']
self.assertEqual(req.access_route, compares)
@@ -42,9 +43,9 @@ class TestAccessRoute(testing.TestBase):
headers={
'Forwarded': 'for'
}))
self.assertEqual(req.access_route, ['127.0.0.1'])
self.assertEqual(req.access_route, [])
# test cached
self.assertEqual(req.access_route, ['127.0.0.1'])
self.assertEqual(req.access_route, [])
def test_x_forwarded_for(self):
req = Request(testing.create_environ(
@@ -72,3 +73,10 @@ class TestAccessRoute(testing.TestBase):
host='example.com',
path='/access_route'))
self.assertEqual(req.access_route, ['127.0.0.1'])
def test_remote_addr_missing(self):
env = testing.create_environ(host='example.com', path='/access_route')
del env['REMOTE_ADDR']
req = Request(env)
self.assertEqual(req.access_route, [])

View File

@@ -29,7 +29,7 @@ def _is_iterable(thing):
def _run_server(stop_event):
class Things(object):
def on_get(self, req, resp):
pass
resp.body = req.remote_addr
def on_post(self, req, resp):
resp.body = req.stream.read(1000)
@@ -127,6 +127,7 @@ class TestWSGIReference(testing.TestBase):
def test_wsgiref_get(self):
resp = requests.get(_SERVER_BASE_URL)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, '127.0.0.1')
def test_wsgiref_put(self):
body = '{}'