add parameter to limit max connections on WebSocket servers (fixes #84)

This commit is contained in:
Tobias Oberstein
2015-03-10 14:35:01 +01:00
parent 028ae118ad
commit 899df3da79
2 changed files with 38 additions and 17 deletions

View File

@@ -643,7 +643,8 @@ class WebSocketProtocol:
'serveFlashSocketPolicy',
'flashSocketPolicy',
'allowedOrigins',
'allowedOriginsPatterns']
'allowedOriginsPatterns',
'maxConnections']
"""
Configuration attributes specific to servers.
"""
@@ -3037,22 +3038,30 @@ class WebSocketServerProtocol(WebSocketProtocol):
# noinspection PyUnboundLocalVariable
self._wskey = key
# WebSocket handshake validated => produce opening handshake response
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
# the connection. onConnect() may throw, in which case the connection is denied, or it
# may return a protocol from the protocols provided by client or None.
# DoS protection
#
request = ConnectionRequest(self.peer,
self.http_headers,
self.http_request_host,
self.http_request_path,
self.http_request_params,
self.websocket_version,
self.websocket_origin,
self.websocket_protocols,
self.websocket_extensions)
self._onConnect(request)
if self.maxConnections > 0 and self.factory.countConnections > self.maxConnections:
# maximum number of concurrent connections reached
#
self.failHandshake("maximum number of connections reached", code=http.SERVICE_UNAVAILABLE[0])
else:
# WebSocket handshake validated => produce opening handshake response
#
request = ConnectionRequest(self.peer,
self.http_headers,
self.http_request_host,
self.http_request_path,
self.http_request_params,
self.websocket_version,
self.websocket_origin,
self.websocket_protocols,
self.websocket_extensions)
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
# the connection. onConnect() may throw, in which case the connection is denied, or it
# may return a protocol from the protocols provided by client or None.
self._onConnect(request)
elif self.serveFlashSocketPolicy or self.debug:
flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00")
@@ -3522,6 +3531,9 @@ class WebSocketServerFactory(WebSocketFactory):
self.allowedOrigins = ["*"]
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
# maximum number of concurrent connections
self.maxConnections = 0
def setProtocolOptions(self,
versions=None,
allowHixie76=None,
@@ -3544,7 +3556,8 @@ class WebSocketServerFactory(WebSocketFactory):
autoPingSize=None,
serveFlashSocketPolicy=None,
flashSocketPolicy=None,
allowedOrigins=None):
allowedOrigins=None,
maxConnections=None):
"""
Set WebSocket protocol options used as defaults for new protocol instances.
@@ -3595,6 +3608,8 @@ class WebSocketServerFactory(WebSocketFactory):
:type flashSocketPolicy: str or None
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
:type allowedOrigins: list or None
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
:type maxConnections: int or None
"""
if allowHixie76 is not None and allowHixie76 != self.allowHixie76:
self.allowHixie76 = allowHixie76
@@ -3671,6 +3686,11 @@ class WebSocketServerFactory(WebSocketFactory):
self.allowedOrigins = allowedOrigins
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
if maxConnections is not None and maxConnections != self.maxConnections:
assert(type(maxConnections) in six.integer_types)
assert(maxConnections >= 0)
self.maxConnections = maxConnections
def getConnectionCount(self):
"""
Get number of currently connected clients.

View File

@@ -60,6 +60,7 @@ if __name__ == '__main__':
factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
factory.protocol = MyServerProtocol
# factory.setProtocolOptions(maxConnections=2)
reactor.listenTCP(9000, factory)
reactor.run()