wsgi: Unix socket address representation; Thanks to Samuel Merritt

https://github.com/eventlet/eventlet/pull/320
This commit is contained in:
Sergey Shepelev
2016-08-09 23:46:37 +05:00
parent 15c5f02112
commit 4f0913d084
2 changed files with 46 additions and 17 deletions

View File

@@ -44,6 +44,15 @@ def format_date_time(timestamp):
) )
def addr_to_host_port(addr):
host = 'unix'
port = ''
if isinstance(addr, tuple):
host = addr[0]
port = addr[1]
return (host, port)
# Collections of error codes to compare against. Not all attributes are set # Collections of error codes to compare against. Not all attributes are set
# on errno module on all platforms, so some are literals :( # on errno module on all platforms, so some are literals :(
BAD_SOCK = set((errno.EBADF, 10053)) BAD_SOCK = set((errno.EBADF, 10053))
@@ -536,8 +545,8 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
self.close_connection = 1 self.close_connection = 1
self.server.log.error(( self.server.log.error((
'chunked encoding error while discarding request body.' 'chunked encoding error while discarding request body.'
+ ' ip={0} request="{1}" error="{2}"').format( + ' client={0} request="{1}" error="{2}"').format(
self.get_client_ip(), self.requestline, e, self.get_client_address()[0], self.requestline, e,
)) ))
finish = time.time() finish = time.time()
@@ -545,9 +554,11 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
hook(self.environ, *args, **kwargs) hook(self.environ, *args, **kwargs)
if self.server.log_output: if self.server.log_output:
client_host, client_port = self.get_client_address()
self.server.log.info(self.server.log_format % { self.server.log.info(self.server.log_format % {
'client_ip': self.get_client_ip(), 'client_ip': client_host,
'client_port': self.client_address[1], 'client_port': client_port,
'date_time': self.log_date_time_string(), 'date_time': self.log_date_time_string(),
'request_line': self.requestline, 'request_line': self.requestline,
'status_code': status_code[0], 'status_code': status_code[0],
@@ -555,13 +566,14 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
'wall_seconds': finish - start, 'wall_seconds': finish - start,
}) })
def get_client_ip(self): def get_client_address(self):
client_ip = self.client_address[0] host, port = addr_to_host_port(self.client_address)
if self.server.log_x_forwarded_for: if self.server.log_x_forwarded_for:
forward = self.headers.get('X-Forwarded-For', '').replace(' ', '') forward = self.headers.get('X-Forwarded-For', '').replace(' ', '')
if forward: if forward:
client_ip = "%s,%s" % (forward, client_ip) host = forward + ',' + host
return client_ip return (host, port)
def get_environ(self): def get_environ(self):
env = self.server.get_environ() env = self.server.get_environ()
@@ -587,11 +599,13 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
env['CONTENT_LENGTH'] = length env['CONTENT_LENGTH'] = length
env['SERVER_PROTOCOL'] = 'HTTP/1.0' env['SERVER_PROTOCOL'] = 'HTTP/1.0'
host, port = self.request.getsockname()[:2] sockname = self.request.getsockname()
env['SERVER_NAME'] = host server_addr = addr_to_host_port(sockname)
env['SERVER_PORT'] = str(port) env['SERVER_NAME'] = server_addr[0]
env['REMOTE_ADDR'] = self.client_address[0] env['SERVER_PORT'] = str(server_addr[1])
env['REMOTE_PORT'] = str(self.client_address[1]) client_addr = addr_to_host_port(self.client_address)
env['REMOTE_ADDR'] = client_addr[0]
env['REMOTE_PORT'] = str(client_addr[1])
env['GATEWAY_INTERFACE'] = 'CGI/1.1' env['GATEWAY_INTERFACE'] = 'CGI/1.1'
try: try:

View File

@@ -1515,19 +1515,34 @@ class TestHttpd(_TestBase):
self.assertEqual(result.headers_original[random_case_header[0]], random_case_header[1]) self.assertEqual(result.headers_original[random_case_header[0]], random_case_header[1])
def test_log_unix_address(self): def test_log_unix_address(self):
def app(environ, start_response):
start_response('200 OK', [])
return ['\n{0}={1}\n'.format(k, v).encode() for k, v in environ.items()]
tempdir = tempfile.mkdtemp('eventlet_test_log_unix_address') tempdir = tempfile.mkdtemp('eventlet_test_log_unix_address')
path = ''
try: try:
sock = eventlet.listen(tempdir + '/socket', socket.AF_UNIX) server_sock = eventlet.listen(tempdir + '/socket', socket.AF_UNIX)
path = sock.getsockname() path = server_sock.getsockname()
log = six.StringIO() log = six.StringIO()
self.spawn_server(sock=sock, log=log) self.spawn_server(site=app, sock=server_sock, log=log)
eventlet.sleep(0) # need to enter server loop eventlet.sleep(0) # need to enter server loop
assert 'http:' + path in log.getvalue() assert 'http:' + path in log.getvalue()
client_sock = eventlet.connect(path, family=socket.AF_UNIX)
client_sock.sendall(b'GET / HTTP/1.0\r\nHost: localhost\r\n\r\n')
result = read_http(client_sock)
client_sock.close()
assert '\nunix -' in log.getvalue()
finally: finally:
shutil.rmtree(tempdir) shutil.rmtree(tempdir)
assert result.status == 'HTTP/1.1 200 OK', repr(result) + log.getvalue()
assert b'\nSERVER_NAME=unix\n' in result.body
assert b'\nSERVER_PORT=\n' in result.body
assert b'\nREMOTE_ADDR=unix\n' in result.body
assert b'\nREMOTE_PORT=\n' in result.body
def test_headers_raw(self): def test_headers_raw(self):
def app(environ, start_response): def app(environ, start_response):
start_response('200 OK', []) start_response('200 OK', [])