Merge "Update SERVER_ADDR/SERVER_PORT from PROXY protocol"
This commit is contained in:
@@ -447,6 +447,10 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
|
||||
See http://www.haproxy.org/download/1.7/doc/proxy-protocol.txt for
|
||||
protocol details.
|
||||
"""
|
||||
def __init__(self, *a, **kw):
|
||||
self.proxy_address = None
|
||||
SwiftHttpProtocol.__init__(self, *a, **kw)
|
||||
|
||||
def handle_error(self, connection_line):
|
||||
if not six.PY2:
|
||||
connection_line = connection_line.decode('latin-1')
|
||||
@@ -477,16 +481,20 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
|
||||
connection_line = self.rfile.readline(self.server.url_length_limit)
|
||||
|
||||
if connection_line.startswith(b'PROXY'):
|
||||
proxy_parts = connection_line.split(b' ')
|
||||
proxy_parts = connection_line.strip(b'\r\n').split(b' ')
|
||||
if len(proxy_parts) >= 2 and proxy_parts[0] == b'PROXY':
|
||||
if proxy_parts[1] in (b'TCP4', b'TCP6') and \
|
||||
len(proxy_parts) == 6:
|
||||
if six.PY2:
|
||||
self.client_address = (proxy_parts[2], proxy_parts[4])
|
||||
self.proxy_address = (proxy_parts[3], proxy_parts[5])
|
||||
else:
|
||||
self.client_address = (
|
||||
proxy_parts[2].decode('latin-1'),
|
||||
proxy_parts[4].decode('latin-1'))
|
||||
self.proxy_address = (
|
||||
proxy_parts[3].decode('latin-1'),
|
||||
proxy_parts[5].decode('latin-1'))
|
||||
elif proxy_parts[1].startswith(b'UNKNOWN'):
|
||||
# "UNKNOWN", in PROXY protocol version 1, means "not
|
||||
# TCP4 or TCP6". This includes completely legitimate
|
||||
@@ -505,6 +513,16 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
|
||||
|
||||
return SwiftHttpProtocol.handle(self)
|
||||
|
||||
def get_environ(self):
|
||||
environ = SwiftHttpProtocol.get_environ(self)
|
||||
if self.proxy_address:
|
||||
environ['SERVER_ADDR'] = self.proxy_address[0]
|
||||
environ['SERVER_PORT'] = self.proxy_address[1]
|
||||
if self.proxy_address[1] == '443':
|
||||
environ['wsgi.url_scheme'] = 'https'
|
||||
environ['HTTPS'] = 'on'
|
||||
return environ
|
||||
|
||||
|
||||
def run_server(conf, logger, sock, global_conf=None):
|
||||
# Ensure TZ environment variable exists to avoid stat('/etc/localtime') on
|
||||
|
||||
@@ -1093,9 +1093,17 @@ class TestProxyProtocol(unittest.TestCase):
|
||||
|
||||
def dinky_app(env, start_response):
|
||||
start_response("200 OK", [])
|
||||
body = "got addr: %s %s\r\n" % (
|
||||
env.get("REMOTE_ADDR", "<missing>"),
|
||||
env.get("REMOTE_PORT", "<missing>"))
|
||||
body = '\r\n'.join([
|
||||
'got addr: %s %s' % (
|
||||
env.get("REMOTE_ADDR", "<missing>"),
|
||||
env.get("REMOTE_PORT", "<missing>")),
|
||||
'on addr: %s %s' % (
|
||||
env.get("SERVER_ADDR", "<missing>"),
|
||||
env.get("SERVER_PORT", "<missing>")),
|
||||
'https is %s (scheme %s)' % (
|
||||
env.get("HTTPS", "<missing>"),
|
||||
env.get("wsgi.url_scheme", "<missing>")),
|
||||
]) + '\r\n'
|
||||
return [body.encode("utf-8")]
|
||||
|
||||
fake_tcp_socket = mock.Mock(
|
||||
@@ -1107,6 +1115,7 @@ class TestProxyProtocol(unittest.TestCase):
|
||||
# KeyboardInterrupt breaks the WSGI server out of
|
||||
# its infinite accept-process-close loop.
|
||||
KeyboardInterrupt]))
|
||||
del fake_listen_socket.do_handshake
|
||||
|
||||
# If we let the WSGI server close rfile/wfile then we can't access
|
||||
# their contents any more.
|
||||
@@ -1121,6 +1130,22 @@ class TestProxyProtocol(unittest.TestCase):
|
||||
return wfile.getvalue()
|
||||
|
||||
def test_request_with_proxy(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 4433\r\n"
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
b"User-Agent: something or other\r\n"
|
||||
b"\r\n"
|
||||
), wsgi.SwiftHttpProxiedProtocol)
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
self.assertEqual(lines[-3:], [
|
||||
b"got addr: 192.168.0.1 56423",
|
||||
b"on addr: 192.168.0.11 4433",
|
||||
b"https is <missing> (scheme http)",
|
||||
])
|
||||
|
||||
def test_request_with_proxy_https(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
|
||||
b"GET /someurl HTTP/1.0\r\n"
|
||||
@@ -1130,7 +1155,11 @@ class TestProxyProtocol(unittest.TestCase):
|
||||
|
||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||
self.assertEqual(lines[-1], b"got addr: 192.168.0.1 56423")
|
||||
self.assertEqual(lines[-3:], [
|
||||
b"got addr: 192.168.0.1 56423",
|
||||
b"on addr: 192.168.0.11 443",
|
||||
b"https is on (scheme https)",
|
||||
])
|
||||
|
||||
def test_multiple_requests_with_proxy(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
@@ -1150,6 +1179,10 @@ class TestProxyProtocol(unittest.TestCase):
|
||||
# the address in the PROXY line is applied to every request
|
||||
addr_lines = [l for l in lines if l.startswith(b"got addr")]
|
||||
self.assertEqual(addr_lines, [b"got addr: 192.168.0.1 56423"] * 2)
|
||||
addr_lines = [l for l in lines if l.startswith(b"on addr")]
|
||||
self.assertEqual(addr_lines, [b"on addr: 192.168.0.11 443"] * 2)
|
||||
addr_lines = [l for l in lines if l.startswith(b"https is")]
|
||||
self.assertEqual(addr_lines, [b"https is on (scheme https)"] * 2)
|
||||
|
||||
def test_missing_proxy_line(self):
|
||||
bytes_out = self._run_bytes_through_protocol((
|
||||
|
||||
Reference in New Issue
Block a user