add parameter to limit max connections on WebSocket servers (fixes #84)

This commit is contained in:
Tobias Oberstein
2015-03-10 14:35:01 +01:00
parent 028ae118ad
commit 899df3da79
2 changed files with 38 additions and 17 deletions

View File

@@ -643,7 +643,8 @@ class WebSocketProtocol:
'serveFlashSocketPolicy', 'serveFlashSocketPolicy',
'flashSocketPolicy', 'flashSocketPolicy',
'allowedOrigins', 'allowedOrigins',
'allowedOriginsPatterns'] 'allowedOriginsPatterns',
'maxConnections']
""" """
Configuration attributes specific to servers. Configuration attributes specific to servers.
""" """
@@ -3037,22 +3038,30 @@ class WebSocketServerProtocol(WebSocketProtocol):
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
self._wskey = key self._wskey = key
# WebSocket handshake validated => produce opening handshake response # DoS protection
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
# the connection. onConnect() may throw, in which case the connection is denied, or it
# may return a protocol from the protocols provided by client or None.
# #
request = ConnectionRequest(self.peer, if self.maxConnections > 0 and self.factory.countConnections > self.maxConnections:
self.http_headers,
self.http_request_host, # maximum number of concurrent connections reached
self.http_request_path, #
self.http_request_params, self.failHandshake("maximum number of connections reached", code=http.SERVICE_UNAVAILABLE[0])
self.websocket_version,
self.websocket_origin, else:
self.websocket_protocols, # WebSocket handshake validated => produce opening handshake response
self.websocket_extensions) #
self._onConnect(request) request = ConnectionRequest(self.peer,
self.http_headers,
self.http_request_host,
self.http_request_path,
self.http_request_params,
self.websocket_version,
self.websocket_origin,
self.websocket_protocols,
self.websocket_extensions)
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
# the connection. onConnect() may throw, in which case the connection is denied, or it
# may return a protocol from the protocols provided by client or None.
self._onConnect(request)
elif self.serveFlashSocketPolicy or self.debug: elif self.serveFlashSocketPolicy or self.debug:
flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00") flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00")
@@ -3522,6 +3531,9 @@ class WebSocketServerFactory(WebSocketFactory):
self.allowedOrigins = ["*"] self.allowedOrigins = ["*"]
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins) self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
# maximum number of concurrent connections
self.maxConnections = 0
def setProtocolOptions(self, def setProtocolOptions(self,
versions=None, versions=None,
allowHixie76=None, allowHixie76=None,
@@ -3544,7 +3556,8 @@ class WebSocketServerFactory(WebSocketFactory):
autoPingSize=None, autoPingSize=None,
serveFlashSocketPolicy=None, serveFlashSocketPolicy=None,
flashSocketPolicy=None, flashSocketPolicy=None,
allowedOrigins=None): allowedOrigins=None,
maxConnections=None):
""" """
Set WebSocket protocol options used as defaults for new protocol instances. Set WebSocket protocol options used as defaults for new protocol instances.
@@ -3595,6 +3608,8 @@ class WebSocketServerFactory(WebSocketFactory):
:type flashSocketPolicy: str or None :type flashSocketPolicy: str or None
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character). :param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
:type allowedOrigins: list or None :type allowedOrigins: list or None
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
:type maxConnections: int or None
""" """
if allowHixie76 is not None and allowHixie76 != self.allowHixie76: if allowHixie76 is not None and allowHixie76 != self.allowHixie76:
self.allowHixie76 = allowHixie76 self.allowHixie76 = allowHixie76
@@ -3671,6 +3686,11 @@ class WebSocketServerFactory(WebSocketFactory):
self.allowedOrigins = allowedOrigins self.allowedOrigins = allowedOrigins
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins) self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
if maxConnections is not None and maxConnections != self.maxConnections:
assert(type(maxConnections) in six.integer_types)
assert(maxConnections >= 0)
self.maxConnections = maxConnections
def getConnectionCount(self): def getConnectionCount(self):
""" """
Get number of currently connected clients. Get number of currently connected clients.

View File

@@ -60,6 +60,7 @@ if __name__ == '__main__':
factory = WebSocketServerFactory("ws://localhost:9000", debug=False) factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
factory.protocol = MyServerProtocol factory.protocol = MyServerProtocol
# factory.setProtocolOptions(maxConnections=2)
reactor.listenTCP(9000, factory) reactor.listenTCP(9000, factory)
reactor.run() reactor.run()