Merge "Update SERVER_ADDR/SERVER_PORT from PROXY protocol"
This commit is contained in:
@@ -447,6 +447,10 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
|
|||||||
See http://www.haproxy.org/download/1.7/doc/proxy-protocol.txt for
|
See http://www.haproxy.org/download/1.7/doc/proxy-protocol.txt for
|
||||||
protocol details.
|
protocol details.
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, *a, **kw):
|
||||||
|
self.proxy_address = None
|
||||||
|
SwiftHttpProtocol.__init__(self, *a, **kw)
|
||||||
|
|
||||||
def handle_error(self, connection_line):
|
def handle_error(self, connection_line):
|
||||||
if not six.PY2:
|
if not six.PY2:
|
||||||
connection_line = connection_line.decode('latin-1')
|
connection_line = connection_line.decode('latin-1')
|
||||||
@@ -477,16 +481,20 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
|
|||||||
connection_line = self.rfile.readline(self.server.url_length_limit)
|
connection_line = self.rfile.readline(self.server.url_length_limit)
|
||||||
|
|
||||||
if connection_line.startswith(b'PROXY'):
|
if connection_line.startswith(b'PROXY'):
|
||||||
proxy_parts = connection_line.split(b' ')
|
proxy_parts = connection_line.strip(b'\r\n').split(b' ')
|
||||||
if len(proxy_parts) >= 2 and proxy_parts[0] == b'PROXY':
|
if len(proxy_parts) >= 2 and proxy_parts[0] == b'PROXY':
|
||||||
if proxy_parts[1] in (b'TCP4', b'TCP6') and \
|
if proxy_parts[1] in (b'TCP4', b'TCP6') and \
|
||||||
len(proxy_parts) == 6:
|
len(proxy_parts) == 6:
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
self.client_address = (proxy_parts[2], proxy_parts[4])
|
self.client_address = (proxy_parts[2], proxy_parts[4])
|
||||||
|
self.proxy_address = (proxy_parts[3], proxy_parts[5])
|
||||||
else:
|
else:
|
||||||
self.client_address = (
|
self.client_address = (
|
||||||
proxy_parts[2].decode('latin-1'),
|
proxy_parts[2].decode('latin-1'),
|
||||||
proxy_parts[4].decode('latin-1'))
|
proxy_parts[4].decode('latin-1'))
|
||||||
|
self.proxy_address = (
|
||||||
|
proxy_parts[3].decode('latin-1'),
|
||||||
|
proxy_parts[5].decode('latin-1'))
|
||||||
elif proxy_parts[1].startswith(b'UNKNOWN'):
|
elif proxy_parts[1].startswith(b'UNKNOWN'):
|
||||||
# "UNKNOWN", in PROXY protocol version 1, means "not
|
# "UNKNOWN", in PROXY protocol version 1, means "not
|
||||||
# TCP4 or TCP6". This includes completely legitimate
|
# TCP4 or TCP6". This includes completely legitimate
|
||||||
@@ -505,6 +513,16 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
|
|||||||
|
|
||||||
return SwiftHttpProtocol.handle(self)
|
return SwiftHttpProtocol.handle(self)
|
||||||
|
|
||||||
|
def get_environ(self):
|
||||||
|
environ = SwiftHttpProtocol.get_environ(self)
|
||||||
|
if self.proxy_address:
|
||||||
|
environ['SERVER_ADDR'] = self.proxy_address[0]
|
||||||
|
environ['SERVER_PORT'] = self.proxy_address[1]
|
||||||
|
if self.proxy_address[1] == '443':
|
||||||
|
environ['wsgi.url_scheme'] = 'https'
|
||||||
|
environ['HTTPS'] = 'on'
|
||||||
|
return environ
|
||||||
|
|
||||||
|
|
||||||
def run_server(conf, logger, sock, global_conf=None):
|
def run_server(conf, logger, sock, global_conf=None):
|
||||||
# Ensure TZ environment variable exists to avoid stat('/etc/localtime') on
|
# Ensure TZ environment variable exists to avoid stat('/etc/localtime') on
|
||||||
|
|||||||
@@ -1093,9 +1093,17 @@ class TestProxyProtocol(unittest.TestCase):
|
|||||||
|
|
||||||
def dinky_app(env, start_response):
|
def dinky_app(env, start_response):
|
||||||
start_response("200 OK", [])
|
start_response("200 OK", [])
|
||||||
body = "got addr: %s %s\r\n" % (
|
body = '\r\n'.join([
|
||||||
env.get("REMOTE_ADDR", "<missing>"),
|
'got addr: %s %s' % (
|
||||||
env.get("REMOTE_PORT", "<missing>"))
|
env.get("REMOTE_ADDR", "<missing>"),
|
||||||
|
env.get("REMOTE_PORT", "<missing>")),
|
||||||
|
'on addr: %s %s' % (
|
||||||
|
env.get("SERVER_ADDR", "<missing>"),
|
||||||
|
env.get("SERVER_PORT", "<missing>")),
|
||||||
|
'https is %s (scheme %s)' % (
|
||||||
|
env.get("HTTPS", "<missing>"),
|
||||||
|
env.get("wsgi.url_scheme", "<missing>")),
|
||||||
|
]) + '\r\n'
|
||||||
return [body.encode("utf-8")]
|
return [body.encode("utf-8")]
|
||||||
|
|
||||||
fake_tcp_socket = mock.Mock(
|
fake_tcp_socket = mock.Mock(
|
||||||
@@ -1107,6 +1115,7 @@ class TestProxyProtocol(unittest.TestCase):
|
|||||||
# KeyboardInterrupt breaks the WSGI server out of
|
# KeyboardInterrupt breaks the WSGI server out of
|
||||||
# its infinite accept-process-close loop.
|
# its infinite accept-process-close loop.
|
||||||
KeyboardInterrupt]))
|
KeyboardInterrupt]))
|
||||||
|
del fake_listen_socket.do_handshake
|
||||||
|
|
||||||
# If we let the WSGI server close rfile/wfile then we can't access
|
# If we let the WSGI server close rfile/wfile then we can't access
|
||||||
# their contents any more.
|
# their contents any more.
|
||||||
@@ -1121,6 +1130,22 @@ class TestProxyProtocol(unittest.TestCase):
|
|||||||
return wfile.getvalue()
|
return wfile.getvalue()
|
||||||
|
|
||||||
def test_request_with_proxy(self):
|
def test_request_with_proxy(self):
|
||||||
|
bytes_out = self._run_bytes_through_protocol((
|
||||||
|
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 4433\r\n"
|
||||||
|
b"GET /someurl HTTP/1.0\r\n"
|
||||||
|
b"User-Agent: something or other\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
), wsgi.SwiftHttpProxiedProtocol)
|
||||||
|
|
||||||
|
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||||
|
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||||
|
self.assertEqual(lines[-3:], [
|
||||||
|
b"got addr: 192.168.0.1 56423",
|
||||||
|
b"on addr: 192.168.0.11 4433",
|
||||||
|
b"https is <missing> (scheme http)",
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_request_with_proxy_https(self):
|
||||||
bytes_out = self._run_bytes_through_protocol((
|
bytes_out = self._run_bytes_through_protocol((
|
||||||
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
|
b"PROXY TCP4 192.168.0.1 192.168.0.11 56423 443\r\n"
|
||||||
b"GET /someurl HTTP/1.0\r\n"
|
b"GET /someurl HTTP/1.0\r\n"
|
||||||
@@ -1130,7 +1155,11 @@ class TestProxyProtocol(unittest.TestCase):
|
|||||||
|
|
||||||
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
lines = [l for l in bytes_out.split(b"\r\n") if l]
|
||||||
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
self.assertEqual(lines[0], b"HTTP/1.1 200 OK") # sanity check
|
||||||
self.assertEqual(lines[-1], b"got addr: 192.168.0.1 56423")
|
self.assertEqual(lines[-3:], [
|
||||||
|
b"got addr: 192.168.0.1 56423",
|
||||||
|
b"on addr: 192.168.0.11 443",
|
||||||
|
b"https is on (scheme https)",
|
||||||
|
])
|
||||||
|
|
||||||
def test_multiple_requests_with_proxy(self):
|
def test_multiple_requests_with_proxy(self):
|
||||||
bytes_out = self._run_bytes_through_protocol((
|
bytes_out = self._run_bytes_through_protocol((
|
||||||
@@ -1150,6 +1179,10 @@ class TestProxyProtocol(unittest.TestCase):
|
|||||||
# the address in the PROXY line is applied to every request
|
# the address in the PROXY line is applied to every request
|
||||||
addr_lines = [l for l in lines if l.startswith(b"got addr")]
|
addr_lines = [l for l in lines if l.startswith(b"got addr")]
|
||||||
self.assertEqual(addr_lines, [b"got addr: 192.168.0.1 56423"] * 2)
|
self.assertEqual(addr_lines, [b"got addr: 192.168.0.1 56423"] * 2)
|
||||||
|
addr_lines = [l for l in lines if l.startswith(b"on addr")]
|
||||||
|
self.assertEqual(addr_lines, [b"on addr: 192.168.0.11 443"] * 2)
|
||||||
|
addr_lines = [l for l in lines if l.startswith(b"https is")]
|
||||||
|
self.assertEqual(addr_lines, [b"https is on (scheme https)"] * 2)
|
||||||
|
|
||||||
def test_missing_proxy_line(self):
|
def test_missing_proxy_line(self):
|
||||||
bytes_out = self._run_bytes_through_protocol((
|
bytes_out = self._run_bytes_through_protocol((
|
||||||
|
|||||||
Reference in New Issue
Block a user