add parameter to limit max connections on WebSocket servers (fixes #84)
This commit is contained in:
@@ -643,7 +643,8 @@ class WebSocketProtocol:
|
|||||||
'serveFlashSocketPolicy',
|
'serveFlashSocketPolicy',
|
||||||
'flashSocketPolicy',
|
'flashSocketPolicy',
|
||||||
'allowedOrigins',
|
'allowedOrigins',
|
||||||
'allowedOriginsPatterns']
|
'allowedOriginsPatterns',
|
||||||
|
'maxConnections']
|
||||||
"""
|
"""
|
||||||
Configuration attributes specific to servers.
|
Configuration attributes specific to servers.
|
||||||
"""
|
"""
|
||||||
@@ -3037,22 +3038,30 @@ class WebSocketServerProtocol(WebSocketProtocol):
|
|||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
self._wskey = key
|
self._wskey = key
|
||||||
|
|
||||||
# WebSocket handshake validated => produce opening handshake response
|
# DoS protection
|
||||||
|
|
||||||
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
|
|
||||||
# the connection. onConnect() may throw, in which case the connection is denied, or it
|
|
||||||
# may return a protocol from the protocols provided by client or None.
|
|
||||||
#
|
#
|
||||||
request = ConnectionRequest(self.peer,
|
if self.maxConnections > 0 and self.factory.countConnections > self.maxConnections:
|
||||||
self.http_headers,
|
|
||||||
self.http_request_host,
|
# maximum number of concurrent connections reached
|
||||||
self.http_request_path,
|
#
|
||||||
self.http_request_params,
|
self.failHandshake("maximum number of connections reached", code=http.SERVICE_UNAVAILABLE[0])
|
||||||
self.websocket_version,
|
|
||||||
self.websocket_origin,
|
else:
|
||||||
self.websocket_protocols,
|
# WebSocket handshake validated => produce opening handshake response
|
||||||
self.websocket_extensions)
|
#
|
||||||
self._onConnect(request)
|
request = ConnectionRequest(self.peer,
|
||||||
|
self.http_headers,
|
||||||
|
self.http_request_host,
|
||||||
|
self.http_request_path,
|
||||||
|
self.http_request_params,
|
||||||
|
self.websocket_version,
|
||||||
|
self.websocket_origin,
|
||||||
|
self.websocket_protocols,
|
||||||
|
self.websocket_extensions)
|
||||||
|
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
|
||||||
|
# the connection. onConnect() may throw, in which case the connection is denied, or it
|
||||||
|
# may return a protocol from the protocols provided by client or None.
|
||||||
|
self._onConnect(request)
|
||||||
|
|
||||||
elif self.serveFlashSocketPolicy or self.debug:
|
elif self.serveFlashSocketPolicy or self.debug:
|
||||||
flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00")
|
flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00")
|
||||||
@@ -3522,6 +3531,9 @@ class WebSocketServerFactory(WebSocketFactory):
|
|||||||
self.allowedOrigins = ["*"]
|
self.allowedOrigins = ["*"]
|
||||||
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
||||||
|
|
||||||
|
# maximum number of concurrent connections
|
||||||
|
self.maxConnections = 0
|
||||||
|
|
||||||
def setProtocolOptions(self,
|
def setProtocolOptions(self,
|
||||||
versions=None,
|
versions=None,
|
||||||
allowHixie76=None,
|
allowHixie76=None,
|
||||||
@@ -3544,7 +3556,8 @@ class WebSocketServerFactory(WebSocketFactory):
|
|||||||
autoPingSize=None,
|
autoPingSize=None,
|
||||||
serveFlashSocketPolicy=None,
|
serveFlashSocketPolicy=None,
|
||||||
flashSocketPolicy=None,
|
flashSocketPolicy=None,
|
||||||
allowedOrigins=None):
|
allowedOrigins=None,
|
||||||
|
maxConnections=None):
|
||||||
"""
|
"""
|
||||||
Set WebSocket protocol options used as defaults for new protocol instances.
|
Set WebSocket protocol options used as defaults for new protocol instances.
|
||||||
|
|
||||||
@@ -3595,6 +3608,8 @@ class WebSocketServerFactory(WebSocketFactory):
|
|||||||
:type flashSocketPolicy: str or None
|
:type flashSocketPolicy: str or None
|
||||||
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
|
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
|
||||||
:type allowedOrigins: list or None
|
:type allowedOrigins: list or None
|
||||||
|
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
|
||||||
|
:type maxConnections: int or None
|
||||||
"""
|
"""
|
||||||
if allowHixie76 is not None and allowHixie76 != self.allowHixie76:
|
if allowHixie76 is not None and allowHixie76 != self.allowHixie76:
|
||||||
self.allowHixie76 = allowHixie76
|
self.allowHixie76 = allowHixie76
|
||||||
@@ -3671,6 +3686,11 @@ class WebSocketServerFactory(WebSocketFactory):
|
|||||||
self.allowedOrigins = allowedOrigins
|
self.allowedOrigins = allowedOrigins
|
||||||
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
||||||
|
|
||||||
|
if maxConnections is not None and maxConnections != self.maxConnections:
|
||||||
|
assert(type(maxConnections) in six.integer_types)
|
||||||
|
assert(maxConnections >= 0)
|
||||||
|
self.maxConnections = maxConnections
|
||||||
|
|
||||||
def getConnectionCount(self):
|
def getConnectionCount(self):
|
||||||
"""
|
"""
|
||||||
Get number of currently connected clients.
|
Get number of currently connected clients.
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
|
factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
|
||||||
factory.protocol = MyServerProtocol
|
factory.protocol = MyServerProtocol
|
||||||
|
# factory.setProtocolOptions(maxConnections=2)
|
||||||
|
|
||||||
reactor.listenTCP(9000, factory)
|
reactor.listenTCP(9000, factory)
|
||||||
reactor.run()
|
reactor.run()
|
||||||
|
|||||||
Reference in New Issue
Block a user