add parameter to limit max connections on WebSocket servers (fixes #84)
This commit is contained in:
@@ -643,7 +643,8 @@ class WebSocketProtocol:
|
||||
'serveFlashSocketPolicy',
|
||||
'flashSocketPolicy',
|
||||
'allowedOrigins',
|
||||
'allowedOriginsPatterns']
|
||||
'allowedOriginsPatterns',
|
||||
'maxConnections']
|
||||
"""
|
||||
Configuration attributes specific to servers.
|
||||
"""
|
||||
@@ -3037,22 +3038,30 @@ class WebSocketServerProtocol(WebSocketProtocol):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
self._wskey = key
|
||||
|
||||
# WebSocket handshake validated => produce opening handshake response
|
||||
|
||||
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
|
||||
# the connection. onConnect() may throw, in which case the connection is denied, or it
|
||||
# may return a protocol from the protocols provided by client or None.
|
||||
# DoS protection
|
||||
#
|
||||
request = ConnectionRequest(self.peer,
|
||||
self.http_headers,
|
||||
self.http_request_host,
|
||||
self.http_request_path,
|
||||
self.http_request_params,
|
||||
self.websocket_version,
|
||||
self.websocket_origin,
|
||||
self.websocket_protocols,
|
||||
self.websocket_extensions)
|
||||
self._onConnect(request)
|
||||
if self.maxConnections > 0 and self.factory.countConnections > self.maxConnections:
|
||||
|
||||
# maximum number of concurrent connections reached
|
||||
#
|
||||
self.failHandshake("maximum number of connections reached", code=http.SERVICE_UNAVAILABLE[0])
|
||||
|
||||
else:
|
||||
# WebSocket handshake validated => produce opening handshake response
|
||||
#
|
||||
request = ConnectionRequest(self.peer,
|
||||
self.http_headers,
|
||||
self.http_request_host,
|
||||
self.http_request_path,
|
||||
self.http_request_params,
|
||||
self.websocket_version,
|
||||
self.websocket_origin,
|
||||
self.websocket_protocols,
|
||||
self.websocket_extensions)
|
||||
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
|
||||
# the connection. onConnect() may throw, in which case the connection is denied, or it
|
||||
# may return a protocol from the protocols provided by client or None.
|
||||
self._onConnect(request)
|
||||
|
||||
elif self.serveFlashSocketPolicy or self.debug:
|
||||
flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00")
|
||||
@@ -3522,6 +3531,9 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
self.allowedOrigins = ["*"]
|
||||
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
||||
|
||||
# maximum number of concurrent connections
|
||||
self.maxConnections = 0
|
||||
|
||||
def setProtocolOptions(self,
|
||||
versions=None,
|
||||
allowHixie76=None,
|
||||
@@ -3544,7 +3556,8 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
autoPingSize=None,
|
||||
serveFlashSocketPolicy=None,
|
||||
flashSocketPolicy=None,
|
||||
allowedOrigins=None):
|
||||
allowedOrigins=None,
|
||||
maxConnections=None):
|
||||
"""
|
||||
Set WebSocket protocol options used as defaults for new protocol instances.
|
||||
|
||||
@@ -3595,6 +3608,8 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
:type flashSocketPolicy: str or None
|
||||
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
|
||||
:type allowedOrigins: list or None
|
||||
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
|
||||
:type maxConnections: int or None
|
||||
"""
|
||||
if allowHixie76 is not None and allowHixie76 != self.allowHixie76:
|
||||
self.allowHixie76 = allowHixie76
|
||||
@@ -3671,6 +3686,11 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
self.allowedOrigins = allowedOrigins
|
||||
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
||||
|
||||
if maxConnections is not None and maxConnections != self.maxConnections:
|
||||
assert(type(maxConnections) in six.integer_types)
|
||||
assert(maxConnections >= 0)
|
||||
self.maxConnections = maxConnections
|
||||
|
||||
def getConnectionCount(self):
|
||||
"""
|
||||
Get number of currently connected clients.
|
||||
|
||||
@@ -60,6 +60,7 @@ if __name__ == '__main__':
|
||||
|
||||
factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
|
||||
factory.protocol = MyServerProtocol
|
||||
# factory.setProtocolOptions(maxConnections=2)
|
||||
|
||||
reactor.listenTCP(9000, factory)
|
||||
reactor.run()
|
||||
|
||||
Reference in New Issue
Block a user