From 26b20ee7296442231e74c982891a9b85c217ff79 Mon Sep 17 00:00:00 2001 From: mmcardle Date: Fri, 18 May 2018 12:04:14 +0100 Subject: [PATCH] IP Range restrictions in temp urls This patch adds an additional optional parameter to tempurl which restricts the ip's from which a temp url can be used from. Change-Id: I23fe998a980960d4a32df042b3f6a21f096c36af --- requirements.txt | 1 + swift/common/middleware/tempurl.py | 101 +++++++-- swift/common/utils.py | 30 ++- test/unit/common/middleware/test_tempurl.py | 214 +++++++++++++++++++- test/unit/common/test_utils.py | 10 + test/unit/common/test_wsgi.py | 39 ++++ 6 files changed, 362 insertions(+), 33 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4b46b0984c..3c27869c7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ six>=1.9.0 xattr>=0.4 PyECLib>=1.3.1 # BSD cryptography!=2.0,>=1.6 # BSD/Apache-2.0 +ipaddress>=1.0.17;python_version<'3.3' # PSF diff --git a/swift/common/middleware/tempurl.py b/swift/common/middleware/tempurl.py index 19605b8c87..b32ab54582 100644 --- a/swift/common/middleware/tempurl.py +++ b/swift/common/middleware/tempurl.py @@ -49,6 +49,10 @@ contain signatures which are valid for all objects which share a common prefix. These prefix-based URLs are useful for sharing a set of objects. +Restrictions can also be placed on the ip that the resource is allowed +to be accessed from. This can be useful for locking down where the urls +can be used from. + ------------ Client Usage ------------ @@ -148,6 +152,52 @@ Another valid URL:: temp_url_expires=1323479485& temp_url_prefix=pre +If you wish to lock down the ip ranges from where the resource can be accessed +to the ip 1.2.3.4:: + + import hmac + from hashlib import sha1 + from time import time + method = 'GET' + expires = int(time() + 60) + path = '/v1/AUTH_account/container/object' + ip_range = '1.2.3.4' + key = 'mykey' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip_range, method, expires, path) + sig = hmac.new(key, hmac_body, sha1).hexdigest() + +The generated signature would only be valid from the ip ``1.2.3.4``. The +middleware detects an ip-based temporary URL by a query parameter called +``temp_url_ip_range``. So, if ``sig`` and ``expires`` would end up like +above, following URL would be valid:: + + https://swift-cluster.example.com/v1/AUTH_account/container/object? + temp_url_sig=da39a3ee5e6b4b0d3255bfef95601890afd80709& + temp_url_expires=1323479485& + temp_url_ip_range=1.2.3.4 + +Similarly to lock down the ip to a range of ``1.2.3.X`` so starting +from the ip ``1.2.3.0`` to ``1.2.3.255``. + + import hmac + from hashlib import sha1 + from time import time + method = 'GET' + expires = int(time() + 60) + path = '/v1/AUTH_account/container/object' + ip_range = '1.2.3.0/24' + key = 'mykey' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip_range, method, expires, path) + sig = hmac.new(key, hmac_body, sha1).hexdigest() + +Then the following url would be valid + + https://swift-cluster.example.com/v1/AUTH_account/container/object? + temp_url_sig=da39a3ee5e6b4b0d3255bfef95601890afd80709& + temp_url_expires=1323479485& + temp_url_ip_range=1.2.3.0/24 + + Any alteration of the resource path or query arguments of a temporary URL would result in ``401 Unauthorized``. Similarly, a ``PUT`` where ``GET`` was the allowed method would be rejected with ``401 Unauthorized``. @@ -239,7 +289,6 @@ This middleware understands the following configuration settings: to be used when calculating the signature for a temporary URL. Default: ``sha1 sha256 sha512`` - """ __all__ = ['TempURL', 'filter_factory', @@ -252,8 +301,10 @@ import binascii from calendar import timegm import functools import hashlib +import six from os.path import basename from time import time, strftime, strptime, gmtime +from ipaddress import ip_address, ip_network from six.moves.urllib.parse import parse_qs from six.moves.urllib.parse import urlencode @@ -446,7 +497,7 @@ class TempURL(object): return self.app(env, start_response) info = self._get_temp_url_info(env) temp_url_sig, temp_url_expires, temp_url_prefix, filename,\ - inline_disposition = info + inline_disposition, temp_url_ip_range = info if temp_url_sig is None and temp_url_expires is None: return self.app(env, start_response) if not temp_url_sig or not temp_url_expires: @@ -474,6 +525,18 @@ class TempURL(object): account, container, obj = self._get_path_parts(env) if not account: return self._invalid(env, start_response) + + if temp_url_ip_range: + client_address = env.get('REMOTE_ADDR') + if client_address is None: + return self._invalid(env, start_response) + try: + allowed_ip_ranges = ip_network(six.u(temp_url_ip_range)) + if ip_address(six.u(client_address)) not in allowed_ip_ranges: + return self._invalid(env, start_response) + except ValueError: + return self._invalid(env, start_response) + keys = self._get_keys(env) if not keys: return self._invalid(env, start_response) @@ -489,10 +552,11 @@ class TempURL(object): hmac for method in ('HEAD', 'GET', 'POST', 'PUT') for hmac in self._get_hmacs( env, temp_url_expires, path, keys, hash_algorithm, - request_method=method)] + request_method=method, ip_range=temp_url_ip_range)] else: hmac_vals = self._get_hmacs( - env, temp_url_expires, path, keys, hash_algorithm) + env, temp_url_expires, path, keys, hash_algorithm, + ip_range=temp_url_ip_range) is_valid_hmac = False hmac_scope = None @@ -594,18 +658,22 @@ class TempURL(object): def _get_temp_url_info(self, env): """ - Returns the provided temporary URL parameters (sig, expires, prefix), - if given and syntactically valid. Either sig, expires or prefix could - be None if not provided. If provided, expires is also - converted to an int if possible or 0 if not, and checked for - expiration (returns 0 if expired). + Returns the provided temporary URL parameters (sig, expires, prefix, + temp_url_ip_range), if given and syntactically valid. + Either sig, expires or prefix could be None if not provided. + If provided, expires is also converted to an int if possible or 0 + if not, and checked for expiration (returns 0 if expired). :param env: The WSGI environment for the request. - :returns: (sig, expires, prefix, filename, inline) as described above. + :returns: (sig, expires, prefix, filename, inline, + temp_url_ip_range) as described above. """ temp_url_sig = temp_url_expires = temp_url_prefix = filename =\ inline = None + temp_url_ip_range = None qs = parse_qs(env.get('QUERY_STRING', ''), keep_blank_values=True) + if 'temp_url_ip_range' in qs: + temp_url_ip_range = qs['temp_url_ip_range'][0] if 'temp_url_sig' in qs: temp_url_sig = qs['temp_url_sig'][0] if 'temp_url_expires' in qs: @@ -627,7 +695,7 @@ class TempURL(object): if 'inline' in qs: inline = True return (temp_url_sig, temp_url_expires, temp_url_prefix, filename, - inline) + inline, temp_url_ip_range) def _get_keys(self, env): """ @@ -658,7 +726,7 @@ class TempURL(object): [(ck, CONTAINER_SCOPE) for ck in container_keys]) def _get_hmacs(self, env, expires, path, scoped_keys, hash_algorithm, - request_method=None): + request_method=None, ip_range=None): """ :param env: The WSGI environment for the request. :param expires: Unix timestamp as an int for when the URL @@ -671,15 +739,20 @@ class TempURL(object): does not match, you may wish to override with GET to still allow the HEAD. - + :param ip_range: The ip range from which the resource is allowed + to be accessed :returns: a list of (hmac, scope) 2-tuples """ if not request_method: request_method = env['REQUEST_METHOD'] digest = functools.partial(hashlib.new, hash_algorithm) + return [ - (get_hmac(request_method, path, expires, key, digest), scope) + (get_hmac( + request_method, path, expires, key, + digest=digest, ip_range=ip_range + ), scope) for (key, scope) in scoped_keys] def _invalid(self, env, start_response): diff --git a/swift/common/utils.py b/swift/common/utils.py index 64717f91d2..b765f3ab30 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -255,7 +255,8 @@ except InvalidHashPathConfigError: pass -def get_hmac(request_method, path, expires, key, digest=sha1): +def get_hmac(request_method, path, expires, key, digest=sha1, + ip_range=None): """ Returns the hexdigest string of the HMAC (see RFC 2104) for the request. @@ -267,18 +268,31 @@ def get_hmac(request_method, path, expires, key, digest=sha1): :param key: HMAC shared secret. :param digest: constructor for the digest to use in calculating the HMAC Defaults to SHA1 - + :param ip_range: The ip range from which the resource is allowed + to be accessed. We need to put the ip_range as the + first argument to hmac to avoid manipulation of the path + due to newlines being valid in paths + e.g. /v1/a/c/o\\n127.0.0.1 :returns: hexdigest str of the HMAC for the request using the specified digest algorithm. """ - parts = (request_method, str(expires), path) + # These are the three mandatory fields. + parts = [request_method, str(expires), path] + formats = [b"%s", b"%s", b"%s"] + + if ip_range: + parts.insert(0, ip_range) + formats.insert(0, b"ip=%s") + if not isinstance(key, six.binary_type): key = key.encode('utf8') - return hmac.new( - key, b'\n'.join( - x if isinstance(x, six.binary_type) else x.encode('utf8') - for x in parts), - digest).hexdigest() + + message = b'\n'.join( + fmt % (part if isinstance(part, six.binary_type) + else part.encode("utf-8")) + for fmt, part in zip(formats, parts)) + + return hmac.new(key, message, digest).hexdigest() # Used by get_swift_info and register_swift_info to store information about diff --git a/test/unit/common/middleware/test_tempurl.py b/test/unit/common/middleware/test_tempurl.py index 8456aab6b3..91114f90e2 100644 --- a/test/unit/common/middleware/test_tempurl.py +++ b/test/unit/common/middleware/test_tempurl.py @@ -812,6 +812,52 @@ class TestTempURL(unittest.TestCase): self.assertIn('Temp URL invalid', resp.body) self.assertIn('Www-Authenticate', resp.headers) + def test_ip_range_value_error(self): + method = 'GET' + expires = int(time() + 86400) + path = '/v1/a/c/o' + key = 'abc' + ip = '127.0.0.1' + not_an_ip = 'abcd' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request( + path, keys=[key], + environ={ + 'QUERY_STRING': + 'temp_url_sig=%s&temp_url_expires=%s&temp_url_ip_range=%s' + % (sig, expires, not_an_ip), + 'REMOTE_ADDR': '127.0.0.1' + }, + ) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 401) + self.assertIn('Temp URL invalid', resp.body) + self.assertIn('Www-Authenticate', resp.headers) + + def test_bad_ip_range_invalid(self): + method = 'GET' + expires = int(time() + 86400) + path = '/v1/a/c/o' + key = 'abc' + ip = '127.0.0.1' + bad_ip = '127.0.0.2' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request( + path, keys=[key], + environ={ + 'QUERY_STRING': + 'temp_url_sig=%s&temp_url_expires=%s&temp_url_ip_range=%s' + % (sig, expires, ip), + 'REMOTE_ADDR': bad_ip + }, + ) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 401) + self.assertIn('Temp URL invalid', resp.body) + self.assertIn('Www-Authenticate', resp.headers) + def test_different_key_invalid(self): method = 'GET' expires = int(time() + 86400) @@ -1098,44 +1144,51 @@ class TestTempURL(unittest.TestCase): self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s' % ( s, e)}), - (s, e_ts, None, None, None)) + (s, e_ts, None, None, None, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&temp_url_prefix=%s' % (s, e, 'prefix')}), - (s, e_ts, 'prefix', None, None)) + (s, e_ts, 'prefix', None, None, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' 'filename=bobisyouruncle' % (s, e)}), - (s, e_ts, None, 'bobisyouruncle', None)) + (s, e_ts, None, 'bobisyouruncle', None, None)) self.assertEqual( self.tempurl._get_temp_url_info({}), - (None, None, None, None, None)) + (None, None, None, None, None, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_expires=%s' % e}), - (None, e_ts, None, None, None)) + (None, e_ts, None, None, None, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s' % s}), - (s, None, None, None, None)) + (s, None, None, None, None, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=bad' % ( s)}), - (s, 0, None, None, None)) + (s, 0, None, None, None, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' 'inline=' % (s, e)}), - (s, e_ts, None, None, True)) + (s, e_ts, None, None, True, None)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' 'filename=bobisyouruncle&inline=' % (s, e)}), - (s, e_ts, None, 'bobisyouruncle', True)) + (s, e_ts, None, 'bobisyouruncle', True, None)) + self.assertEqual( + self.tempurl._get_temp_url_info( + {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'filename=bobisyouruncle&inline=' + '&temp_url_ip_range=127.0.0.1' % (s, e)}), + (s, e_ts, None, 'bobisyouruncle', True, '127.0.0.1')) + e_ts = int(time() - 1) e_8601 = strftime(tempurl.EXPIRES_ISO8601_FORMAT, gmtime(e_ts)) for e in (e_ts, e_8601): @@ -1143,14 +1196,14 @@ class TestTempURL(unittest.TestCase): self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s' % ( s, e)}), - (s, 0, None, None, None)) + (s, 0, None, None, None, None)) # Offsets not supported (yet?). e_8601 = strftime('%Y-%m-%dT%H:%M:%S+0000', gmtime(e_ts)) self.assertEqual( self.tempurl._get_temp_url_info( {'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s' % ( s, e_8601)}), - (s, 0, None, None, None)) + (s, 0, None, None, None, None)) def test_get_hmacs(self): self.assertEqual( @@ -1165,6 +1218,15 @@ class TestTempURL(unittest.TestCase): [('240866478d94bbe683ab1d25fba52c7d0df21a60951' '4fe6a493dc30f951d2748abc51da0cbc633cd1e0acf' '6fadd3af3aedff00ee3d3434dc6a4c423e74adfc4a', 'account')]) + self.assertEqual( + self.tempurl._get_hmacs( + {'REQUEST_METHOD': 'HEAD'}, 1, '/v1/a/c/o', + [('abc', 'account')], 'sha512', request_method='GET', + ip_range='127.0.0.1' + ), + [('b713f99a66911cdf41dbcdff16db3efbd1ca89340a20' + '86cc2ed88f0d3a74c7159e7687a312b12345d3721b7b' + '94e36c2753d7cc01e9a91cc318c5081d788f2cfe', 'account')]) def test_invalid(self): @@ -1321,6 +1383,136 @@ class TestTempURL(unittest.TestCase): for str_value in results: self.assertIsInstance(str_value, str) + @mock.patch('swift.common.middleware.tempurl.time', return_value=0) + def test_get_valid_with_ip_range(self, mock_time): + method = 'GET' + expires = (((24 + 1) * 60 + 1) * 60) + 1 + path = '/v1/a/c/o' + key = 'abc' + ip_range = '127.0.0.0/29' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip_range, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request(path, keys=[key], environ={ + 'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'temp_url_ip_range=%s' % (sig, expires, ip_range), + 'REMOTE_ADDR': '127.0.0.1'}, + ) + self.tempurl.app = FakeApp(iter([('200 Ok', (), '123')])) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 200) + self.assertIn('expires', resp.headers) + self.assertEqual('Fri, 02 Jan 1970 01:01:01 GMT', + resp.headers['expires']) + self.assertEqual(req.environ['swift.authorize_override'], True) + self.assertEqual(req.environ['REMOTE_USER'], '.wsgi.tempurl') + + @mock.patch('swift.common.middleware.tempurl.time', return_value=0) + def test_get_valid_with_ip_from_remote_addr(self, mock_time): + method = 'GET' + expires = (((24 + 1) * 60 + 1) * 60) + 1 + path = '/v1/a/c/o' + key = 'abc' + ip = '127.0.0.1' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request(path, keys=[key], environ={ + 'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'temp_url_ip_range=%s' % (sig, expires, ip), + 'REMOTE_ADDR': ip}, + ) + self.tempurl.app = FakeApp(iter([('200 Ok', (), '123')])) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 200) + self.assertIn('expires', resp.headers) + self.assertEqual('Fri, 02 Jan 1970 01:01:01 GMT', + resp.headers['expires']) + self.assertEqual(req.environ['swift.authorize_override'], True) + self.assertEqual(req.environ['REMOTE_USER'], '.wsgi.tempurl') + + def test_get_valid_with_fake_ip_from_x_forwarded_for(self): + method = 'GET' + expires = (((24 + 1) * 60 + 1) * 60) + 1 + path = '/v1/a/c/o' + key = 'abc' + ip = '127.0.0.1' + remote_addr = '127.0.0.2' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request(path, keys=[key], environ={ + 'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'temp_url_ip_range=%s' % (sig, expires, ip), + 'REMOTE_ADDR': remote_addr}, + headers={'x-forwarded-for': ip}) + self.tempurl.app = FakeApp(iter([('200 Ok', (), '123')])) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 401) + self.assertIn('Temp URL invalid', resp.body) + self.assertIn('Www-Authenticate', resp.headers) + + @mock.patch('swift.common.middleware.tempurl.time', return_value=0) + def test_get_valid_with_single_ipv6(self, mock_time): + method = 'GET' + expires = (((24 + 1) * 60 + 1) * 60) + 1 + path = '/v1/a/c/o' + key = 'abc' + ip = '2001:db8::' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request(path, keys=[key], environ={ + 'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'temp_url_ip_range=%s' % (sig, expires, ip), + 'REMOTE_ADDR': '2001:db8::'}, + ) + self.tempurl.app = FakeApp(iter([('200 Ok', (), '123')])) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 200) + self.assertIn('expires', resp.headers) + self.assertEqual('Fri, 02 Jan 1970 01:01:01 GMT', + resp.headers['expires']) + self.assertEqual(req.environ['swift.authorize_override'], True) + self.assertEqual(req.environ['REMOTE_USER'], '.wsgi.tempurl') + + @mock.patch('swift.common.middleware.tempurl.time', return_value=0) + def test_get_valid_with_ipv6_range(self, mock_time): + method = 'GET' + expires = (((24 + 1) * 60 + 1) * 60) + 1 + path = '/v1/a/c/o' + key = 'abc' + ip_range = '2001:db8::/127' + hmac_body = 'ip=%s\n%s\n%s\n%s' % (ip_range, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request(path, keys=[key], environ={ + 'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'temp_url_ip_range=%s' % (sig, expires, ip_range), + 'REMOTE_ADDR': '2001:db8::'}, + ) + self.tempurl.app = FakeApp(iter([('200 Ok', (), '123')])) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 200) + self.assertIn('expires', resp.headers) + self.assertEqual('Fri, 02 Jan 1970 01:01:01 GMT', + resp.headers['expires']) + self.assertEqual(req.environ['swift.authorize_override'], True) + self.assertEqual(req.environ['REMOTE_USER'], '.wsgi.tempurl') + + def test_get_valid_with_no_client_address(self): + method = 'GET' + expires = (((24 + 1) * 60 + 1) * 60) + 1 + path = '/v1/a/c/o' + key = 'abc' + ip = '127.0.0.1' + hmac_body = '%s\n%s\n%s\n%s' % (ip, method, expires, path) + sig = hmac.new(key, hmac_body, hashlib.sha1).hexdigest() + req = self._make_request(path, keys=[key], environ={ + 'QUERY_STRING': 'temp_url_sig=%s&temp_url_expires=%s&' + 'temp_url_ip_range=%s' % (sig, expires, ip)}, + ) + self.tempurl.app = FakeApp(iter([('200 Ok', (), '123')])) + resp = req.get_response(self.tempurl) + self.assertEqual(resp.status_int, 401) + self.assertIn('Temp URL invalid', resp.body) + self.assertIn('Www-Authenticate', resp.headers) + class TestSwiftInfo(unittest.TestCase): def setUp(self): diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index f399e7b98f..5c020e7b7a 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -3629,6 +3629,16 @@ cluster_dfw1 = http://dfw1.host/v1/ utils.get_hmac('GET', '/path', 1, 'abc'), 'b17f6ff8da0e251737aa9e3ee69a881e3e092e2f') + def test_get_hmac_ip_range(self): + self.assertEqual( + utils.get_hmac('GET', '/path', 1, 'abc', ip_range='127.0.0.1'), + 'b30dde4d2b8562b8496466c3b46b2b9ac5054461') + + def test_get_hmac_ip_range_non_binary_type(self): + self.assertEqual( + utils.get_hmac(u'GET', u'/path', 1, u'abc', ip_range=u'127.0.0.1'), + 'b30dde4d2b8562b8496466c3b46b2b9ac5054461') + def test_parse_override_options(self): # When override_ is passed in, it takes precedence. opts = utils.parse_override_options( diff --git a/test/unit/common/test_wsgi.py b/test/unit/common/test_wsgi.py index 5c6b91c0cb..c64f01d549 100644 --- a/test/unit/common/test_wsgi.py +++ b/test/unit/common/test_wsgi.py @@ -1225,6 +1225,45 @@ class TestProxyProtocol(unittest.TestCase): lines = [l for l in bytes_out.split(b"\r\n") if l] self.assertIn(b"200 OK", lines[0]) + def test_address_and_environ(self): + # Make an object we can exercise... note the base class's __init__() + # does a bunch of work, so we just new up an object like eventlet.wsgi + # does. + dummy_env = {'OTHER_ENV_KEY': 'OTHER_ENV_VALUE'} + mock_protocol = mock.Mock(get_environ=lambda s: dummy_env) + patcher = mock.patch( + 'swift.common.wsgi.SwiftHttpProtocol', mock_protocol + ) + self.mock_super = patcher.start() + self.addCleanup(patcher.stop) + + proto_class = wsgi.SwiftHttpProxiedProtocol + try: + proxy_obj = types.InstanceType(proto_class) + except AttributeError: + proxy_obj = proto_class.__new__(proto_class) + + # Install some convenience mocks + proxy_obj.server = Namespace(app=Namespace(logger=mock.Mock()), + url_length_limit=777, + log=mock.Mock()) + proxy_obj.send_error = mock.Mock() + + proxy_obj.rfile = BytesIO( + b'PROXY TCP4 111.111.111.111 222.222.222.222 111 222' + ) + + assert proxy_obj.handle() + + self.assertEqual(proxy_obj.client_address, ('111.111.111.111', '111')) + self.assertEqual(proxy_obj.proxy_address, ('222.222.222.222', '222')) + expected_env = { + 'SERVER_PORT': '222', + 'SERVER_ADDR': '222.222.222.222', + 'OTHER_ENV_KEY': 'OTHER_ENV_VALUE' + } + self.assertEqual(proxy_obj.get_environ(), expected_env) + class TestServersPerPortStrategy(unittest.TestCase): def setUp(self):