Merge pull request #830 from agronholm/x_forwarded_for
Added support for the X-Forwarded-For HTTP header
This commit is contained in:
@@ -316,3 +316,49 @@ class WebSocketOriginMatching(unittest.TestCase):
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('null'), None, 80, [])
|
||||
)
|
||||
|
||||
|
||||
class WebSocketXForwardedFor(unittest.TestCase):
|
||||
"""
|
||||
Test that (only) a trusted X-Forwarded-For can replace the peer address.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.factory = WebSocketServerFactory()
|
||||
self.factory.setProtocolOptions(
|
||||
trustXForwardedFor=2
|
||||
)
|
||||
self.proto = WebSocketServerProtocol()
|
||||
self.proto.transport = StringTransport()
|
||||
self.proto.factory = self.factory
|
||||
self.proto.failHandshake = Mock()
|
||||
self.proto._connectionMade()
|
||||
|
||||
def tearDown(self):
|
||||
for call in [
|
||||
self.proto.autoPingPendingCall,
|
||||
self.proto.autoPingTimeoutCall,
|
||||
self.proto.openHandshakeTimeoutCall,
|
||||
self.proto.closeHandshakeTimeoutCall,
|
||||
]:
|
||||
if call is not None:
|
||||
call.cancel()
|
||||
|
||||
def test_trusted_addresses(self):
|
||||
self.proto.data = b"\r\n".join([
|
||||
b'GET /ws HTTP/1.1',
|
||||
b'Host: www.example.com',
|
||||
b'Origin: http://www.example.com',
|
||||
b'Sec-WebSocket-Version: 13',
|
||||
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||
b'Connection: keep-alive, Upgrade',
|
||||
b'Upgrade: websocket',
|
||||
b'X-Forwarded-For: 1.2.3.4, 2.3.4.5, 111.222.33.44',
|
||||
b'\r\n', # last string doesn't get a \r\n from join()
|
||||
])
|
||||
self.proto.consumeData()
|
||||
|
||||
self.assertEquals(
|
||||
self.proto.peer, "2.3.4.5",
|
||||
"The second address in X-Forwarded-For should have been picked as the peer address")
|
||||
|
||||
@@ -126,7 +126,8 @@ class IWebSocketServerChannelFactory(object):
|
||||
flashSocketPolicy=None,
|
||||
allowedOrigins=None,
|
||||
allowNullOrigin=False,
|
||||
maxConnections=None):
|
||||
maxConnections=None,
|
||||
trustXForwardedFor=0):
|
||||
"""
|
||||
Set WebSocket protocol options used as defaults for new protocol instances.
|
||||
|
||||
@@ -201,6 +202,9 @@ class IWebSocketServerChannelFactory(object):
|
||||
|
||||
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
|
||||
:type maxConnections: int or None
|
||||
|
||||
:param trustXForwardedFor: Number of trusted web servers in front of this server that add their own X-Forwarded-For header (default: `0`)
|
||||
:type trustXForwardedFor: int
|
||||
"""
|
||||
|
||||
@public
|
||||
|
||||
@@ -528,7 +528,8 @@ class WebSocketProtocol(object):
|
||||
'allowedOrigins',
|
||||
'allowedOriginsPatterns',
|
||||
'allowNullOrigin',
|
||||
'maxConnections']
|
||||
'maxConnections',
|
||||
'trustXForwardedFor']
|
||||
"""
|
||||
Configuration attributes specific to servers.
|
||||
"""
|
||||
@@ -2491,6 +2492,13 @@ class WebSocketServerProtocol(WebSocketProtocol):
|
||||
except Exception as e:
|
||||
return self.failHandshake("Error during parsing of HTTP status line / request headers : {0}".format(e))
|
||||
|
||||
# replace self.peer if the x-forwarded-for header is present and trusted
|
||||
#
|
||||
if 'x-forwarded-for' in self.http_headers and self.trustXForwardedFor:
|
||||
addresses = [x.strip() for x in self.http_headers['x-forwarded-for'].split(',')]
|
||||
trusted_addresses = addresses[-self.trustXForwardedFor:]
|
||||
self.peer = trusted_addresses[0]
|
||||
|
||||
# validate WebSocket opening handshake client request
|
||||
#
|
||||
self.log.debug(
|
||||
@@ -3179,6 +3187,9 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
# maximum number of concurrent connections
|
||||
self.maxConnections = 0
|
||||
|
||||
# number of trusted web servers in front of this server
|
||||
self.trustXForwardedFor = 0
|
||||
|
||||
def setProtocolOptions(self,
|
||||
versions=None,
|
||||
webStatus=None,
|
||||
@@ -3202,7 +3213,8 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
flashSocketPolicy=None,
|
||||
allowedOrigins=None,
|
||||
allowNullOrigin=False,
|
||||
maxConnections=None):
|
||||
maxConnections=None,
|
||||
trustXForwardedFor=None):
|
||||
"""
|
||||
Implements :func:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory.setProtocolOptions`
|
||||
"""
|
||||
@@ -3285,6 +3297,11 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
assert(maxConnections >= 0)
|
||||
self.maxConnections = maxConnections
|
||||
|
||||
if trustXForwardedFor is not None and trustXForwardedFor != self.trustXForwardedFor:
|
||||
assert(type(trustXForwardedFor) in six.integer_types)
|
||||
assert(trustXForwardedFor >= 0)
|
||||
self.trustXForwardedFor = trustXForwardedFor
|
||||
|
||||
def getConnectionCount(self):
|
||||
"""
|
||||
Get number of currently connected clients.
|
||||
|
||||
@@ -529,6 +529,7 @@ Server-Only Options
|
||||
- flashSocketPolicy: the actual flash policy to serve (default one allows everything)
|
||||
- allowedOrigins: a list of origins to allow, with embedded `*`'s for wildcards; these are turned into regular expressions (e.g. `https://*.example.com:443` becomes `^https://.*\.example\.com:443$`). When doing the matching, the origin is **always** of the form `scheme://host:port` with an explicit port. By default, we match with `*` (that is, anything). To match all subdomains of `example.com` on any scheme and port, you'd need `*://*.example.com:*`
|
||||
- maxConnections: total concurrent connections allowed (default 0, unlimited)
|
||||
- trustXForwardedFor: number of trusted web servers (reverse proxies) in front of this server which set the X-Forwarded-For header
|
||||
|
||||
|
||||
Client-Only Options
|
||||
|
||||
Reference in New Issue
Block a user