Added support for logging x-forwarded-for header in wsgi. Unit test to ensure it works as designed.

This commit is contained in:
Ryan Williams
2009-11-25 01:29:21 -05:00
parent 0b9db4a713
commit 118490147a
2 changed files with 51 additions and 10 deletions

View File

@@ -286,12 +286,20 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
finish = time.time()
self.server.log_message('%s - - [%s] "%s" %s %s %.6f' % (
self.client_address[0],
self.get_client_ip(),
self.log_date_time_string(),
self.requestline,
status_code[0],
length[0],
finish - start))
def get_client_ip(self):
client_ip = self.client_address[0]
if self.server.log_x_forwarded_for:
forward = self.headers.get('X-Forwarded-For', '').replace(' ', '')
if forward:
client_ip = "%s,%s" % (forward, client_ip)
return client_ip
def get_environ(self):
env = self.server.get_environ()
@@ -361,7 +369,8 @@ class Server(BaseHTTPServer.HTTPServer):
environ=None,
max_http_version=None,
protocol=HttpProtocol,
minimum_chunk_size=None):
minimum_chunk_size=None,
log_x_forwarded_for=True):
self.outstanding_requests = 0
self.socket = socket
@@ -377,6 +386,7 @@ class Server(BaseHTTPServer.HTTPServer):
self.pid = os.getpid()
if minimum_chunk_size is not None:
protocol.minimum_chunk_size = minimum_chunk_size
self.log_x_forwarded_for = log_x_forwarded_for
def get_environ(self):
socket = self.socket
@@ -407,7 +417,8 @@ def server(sock, site,
max_http_version=DEFAULT_MAX_HTTP_VERSION,
protocol=HttpProtocol,
server_event=None,
minimum_chunk_size=None):
minimum_chunk_size=None,
log_x_forwarded_for=True):
""" Start up a wsgi server handling requests from the supplied server socket.
This function loops forever.
@@ -418,7 +429,8 @@ def server(sock, site,
environ=None,
max_http_version=max_http_version,
protocol=protocol,
minimum_chunk_size=minimum_chunk_size)
minimum_chunk_size=minimum_chunk_size,
log_x_forwarded_for=log_x_forwarded_for)
if server_event is not None:
server_event.send(serv)
if max_size is None:

View File

@@ -278,7 +278,7 @@ class TestHttpd(LimitedTestCase):
server_sock = api.ssl_listener(('localhost', 0), certificate_file, private_key_file)
api.spawn(wsgi.server, server_sock, wsgi_app)
api.spawn(wsgi.server, server_sock, wsgi_app, log=StringIO())
sock = api.connect_tcp(('localhost', server_sock.getsockname()[1]))
sock = util.wrap_ssl(sock)
@@ -294,7 +294,7 @@ class TestHttpd(LimitedTestCase):
certificate_file = os.path.join(os.path.dirname(__file__), 'test_server.crt')
private_key_file = os.path.join(os.path.dirname(__file__), 'test_server.key')
server_sock = api.ssl_listener(('localhost', 0), certificate_file, private_key_file)
api.spawn(wsgi.server, server_sock, wsgi_app)
api.spawn(wsgi.server, server_sock, wsgi_app, log=StringIO())
sock = api.connect_tcp(('localhost', server_sock.getsockname()[1]))
sock = util.wrap_ssl(sock)
@@ -354,6 +354,7 @@ class TestHttpd(LimitedTestCase):
def wsgi_app(environ, start_response):
start_response('200 OK', [('Content-Length', '7')])
return ['testing']
self.site.application = wsgi_app
sock = api.connect_tcp(('localhost', self.port))
fd = sock.makeGreenFile()
fd.write('GET /a HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
@@ -363,7 +364,7 @@ class TestHttpd(LimitedTestCase):
def test_017_ssl_zeroreturnerror(self):
def server(sock, site, log=None):
def server(sock, site, log):
try:
serv = wsgi.Server(sock, sock.getsockname(), site, log)
client_socket = sock.accept()
@@ -375,7 +376,7 @@ class TestHttpd(LimitedTestCase):
return False
def wsgi_app(environ, start_response):
start_response('200 OK', {})
start_response('200 OK', [])
return [environ['wsgi.input'].read()]
certificate_file = os.path.join(os.path.dirname(__file__), 'test_server.crt')
@@ -384,7 +385,7 @@ class TestHttpd(LimitedTestCase):
sock = api.ssl_listener(('localhost', 0), certificate_file, private_key_file)
from eventlet import coros
server_coro = coros.execute(server, sock, wsgi_app)
server_coro = coros.execute(server, sock, wsgi_app, self.logfile)
client = api.connect_tcp(('localhost', sock.getsockname()[1]))
client = util.wrap_ssl(client)
@@ -431,6 +432,34 @@ class TestHttpd(LimitedTestCase):
'4\r\n hai\r\n0\r\n\r\n')
self.assert_('hello!' in fd.read())
def test_020_x_forwarded_for(self):
sock = api.connect_tcp(('localhost', self.port))
sock.sendall('GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 1.2.3.4, 5.6.7.8\r\n\r\n')
sock.recv(1024)
sock.close()
self.assert_('1.2.3.4,5.6.7.8,127.0.0.1' in self.logfile.getvalue())
# turning off the option should work too
self.logfile = StringIO()
api.kill(self.killer)
listener = api.tcp_listener(('localhost', 0))
self.port = listener.getsockname()[1]
self.killer = api.spawn(
wsgi.server,
listener,
self.site,
max_size=128,
log=self.logfile,
log_x_forwarded_for=False)
sock = api.connect_tcp(('localhost', self.port))
sock.sendall('GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 1.2.3.4, 5.6.7.8\r\n\r\n')
sock.recv(1024)
sock.close()
self.assert_('1.2.3.4' not in self.logfile.getvalue())
self.assert_('5.6.7.8' not in self.logfile.getvalue())
self.assert_('127.0.0.1' in self.logfile.getvalue())
if __name__ == '__main__':
main()