From 899df3da79a54425317562e0b74c4e083a1631f7 Mon Sep 17 00:00:00 2001 From: Tobias Oberstein Date: Tue, 10 Mar 2015 14:35:01 +0100 Subject: [PATCH] add parameter to limit max connections on WebSocket servers (fixes #84) --- autobahn/websocket/protocol.py | 54 ++++++++++++++++------- examples/twisted/websocket/echo/server.py | 1 + 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/autobahn/websocket/protocol.py b/autobahn/websocket/protocol.py index a7d7ff36..31c6d922 100755 --- a/autobahn/websocket/protocol.py +++ b/autobahn/websocket/protocol.py @@ -643,7 +643,8 @@ class WebSocketProtocol: 'serveFlashSocketPolicy', 'flashSocketPolicy', 'allowedOrigins', - 'allowedOriginsPatterns'] + 'allowedOriginsPatterns', + 'maxConnections'] """ Configuration attributes specific to servers. """ @@ -3037,22 +3038,30 @@ class WebSocketServerProtocol(WebSocketProtocol): # noinspection PyUnboundLocalVariable self._wskey = key - # WebSocket handshake validated => produce opening handshake response - - # Now fire onConnect() on derived class, to give that class a chance to accept or deny - # the connection. onConnect() may throw, in which case the connection is denied, or it - # may return a protocol from the protocols provided by client or None. + # DoS protection # - request = ConnectionRequest(self.peer, - self.http_headers, - self.http_request_host, - self.http_request_path, - self.http_request_params, - self.websocket_version, - self.websocket_origin, - self.websocket_protocols, - self.websocket_extensions) - self._onConnect(request) + if self.maxConnections > 0 and self.factory.countConnections > self.maxConnections: + + # maximum number of concurrent connections reached + # + self.failHandshake("maximum number of connections reached", code=http.SERVICE_UNAVAILABLE[0]) + + else: + # WebSocket handshake validated => produce opening handshake response + # + request = ConnectionRequest(self.peer, + self.http_headers, + self.http_request_host, + self.http_request_path, + self.http_request_params, + self.websocket_version, + self.websocket_origin, + self.websocket_protocols, + self.websocket_extensions) + # Now fire onConnect() on derived class, to give that class a chance to accept or deny + # the connection. onConnect() may throw, in which case the connection is denied, or it + # may return a protocol from the protocols provided by client or None. + self._onConnect(request) elif self.serveFlashSocketPolicy or self.debug: flash_policy_file_request = self.data.find(b"\x00") @@ -3522,6 +3531,9 @@ class WebSocketServerFactory(WebSocketFactory): self.allowedOrigins = ["*"] self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins) + # maximum number of concurrent connections + self.maxConnections = 0 + def setProtocolOptions(self, versions=None, allowHixie76=None, @@ -3544,7 +3556,8 @@ class WebSocketServerFactory(WebSocketFactory): autoPingSize=None, serveFlashSocketPolicy=None, flashSocketPolicy=None, - allowedOrigins=None): + allowedOrigins=None, + maxConnections=None): """ Set WebSocket protocol options used as defaults for new protocol instances. @@ -3595,6 +3608,8 @@ class WebSocketServerFactory(WebSocketFactory): :type flashSocketPolicy: str or None :param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character). :type allowedOrigins: list or None + :param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`). + :type maxConnections: int or None """ if allowHixie76 is not None and allowHixie76 != self.allowHixie76: self.allowHixie76 = allowHixie76 @@ -3671,6 +3686,11 @@ class WebSocketServerFactory(WebSocketFactory): self.allowedOrigins = allowedOrigins self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins) + if maxConnections is not None and maxConnections != self.maxConnections: + assert(type(maxConnections) in six.integer_types) + assert(maxConnections >= 0) + self.maxConnections = maxConnections + def getConnectionCount(self): """ Get number of currently connected clients. diff --git a/examples/twisted/websocket/echo/server.py b/examples/twisted/websocket/echo/server.py index 76a2376d..066cea76 100644 --- a/examples/twisted/websocket/echo/server.py +++ b/examples/twisted/websocket/echo/server.py @@ -60,6 +60,7 @@ if __name__ == '__main__': factory = WebSocketServerFactory("ws://localhost:9000", debug=False) factory.protocol = MyServerProtocol + # factory.setProtocolOptions(maxConnections=2) reactor.listenTCP(9000, factory) reactor.run()