add parameter to limit max connections on WebSocket servers (fixes #84)
This commit is contained in:
		@@ -643,7 +643,8 @@ class WebSocketProtocol:
 | 
			
		||||
                           'serveFlashSocketPolicy',
 | 
			
		||||
                           'flashSocketPolicy',
 | 
			
		||||
                           'allowedOrigins',
 | 
			
		||||
                           'allowedOriginsPatterns']
 | 
			
		||||
                           'allowedOriginsPatterns',
 | 
			
		||||
                           'maxConnections']
 | 
			
		||||
    """
 | 
			
		||||
   Configuration attributes specific to servers.
 | 
			
		||||
   """
 | 
			
		||||
@@ -3037,11 +3038,16 @@ class WebSocketServerProtocol(WebSocketProtocol):
 | 
			
		||||
                # noinspection PyUnboundLocalVariable
 | 
			
		||||
                self._wskey = key
 | 
			
		||||
 | 
			
		||||
            # WebSocket handshake validated => produce opening handshake response
 | 
			
		||||
            # DoS protection
 | 
			
		||||
            #
 | 
			
		||||
            if self.maxConnections > 0 and self.factory.countConnections > self.maxConnections:
 | 
			
		||||
 | 
			
		||||
            # Now fire onConnect() on derived class, to give that class a chance to accept or deny
 | 
			
		||||
            # the connection. onConnect() may throw, in which case the connection is denied, or it
 | 
			
		||||
            # may return a protocol from the protocols provided by client or None.
 | 
			
		||||
                # maximum number of concurrent connections reached
 | 
			
		||||
                #
 | 
			
		||||
                self.failHandshake("maximum number of connections reached", code=http.SERVICE_UNAVAILABLE[0])
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                # WebSocket handshake validated => produce opening handshake response
 | 
			
		||||
                #
 | 
			
		||||
                request = ConnectionRequest(self.peer,
 | 
			
		||||
                                            self.http_headers,
 | 
			
		||||
@@ -3052,6 +3058,9 @@ class WebSocketServerProtocol(WebSocketProtocol):
 | 
			
		||||
                                            self.websocket_origin,
 | 
			
		||||
                                            self.websocket_protocols,
 | 
			
		||||
                                            self.websocket_extensions)
 | 
			
		||||
                # Now fire onConnect() on derived class, to give that class a chance to accept or deny
 | 
			
		||||
                # the connection. onConnect() may throw, in which case the connection is denied, or it
 | 
			
		||||
                # may return a protocol from the protocols provided by client or None.
 | 
			
		||||
                self._onConnect(request)
 | 
			
		||||
 | 
			
		||||
        elif self.serveFlashSocketPolicy or self.debug:
 | 
			
		||||
@@ -3522,6 +3531,9 @@ class WebSocketServerFactory(WebSocketFactory):
 | 
			
		||||
        self.allowedOrigins = ["*"]
 | 
			
		||||
        self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
 | 
			
		||||
 | 
			
		||||
        # maximum number of concurrent connections
 | 
			
		||||
        self.maxConnections = 0
 | 
			
		||||
 | 
			
		||||
    def setProtocolOptions(self,
 | 
			
		||||
                           versions=None,
 | 
			
		||||
                           allowHixie76=None,
 | 
			
		||||
@@ -3544,7 +3556,8 @@ class WebSocketServerFactory(WebSocketFactory):
 | 
			
		||||
                           autoPingSize=None,
 | 
			
		||||
                           serveFlashSocketPolicy=None,
 | 
			
		||||
                           flashSocketPolicy=None,
 | 
			
		||||
                           allowedOrigins=None):
 | 
			
		||||
                           allowedOrigins=None,
 | 
			
		||||
                           maxConnections=None):
 | 
			
		||||
        """
 | 
			
		||||
        Set WebSocket protocol options used as defaults for new protocol instances.
 | 
			
		||||
 | 
			
		||||
@@ -3595,6 +3608,8 @@ class WebSocketServerFactory(WebSocketFactory):
 | 
			
		||||
        :type flashSocketPolicy: str or None
 | 
			
		||||
        :param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
 | 
			
		||||
        :type allowedOrigins: list or None
 | 
			
		||||
        :param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
 | 
			
		||||
        :type maxConnections: int or None
 | 
			
		||||
        """
 | 
			
		||||
        if allowHixie76 is not None and allowHixie76 != self.allowHixie76:
 | 
			
		||||
            self.allowHixie76 = allowHixie76
 | 
			
		||||
@@ -3671,6 +3686,11 @@ class WebSocketServerFactory(WebSocketFactory):
 | 
			
		||||
            self.allowedOrigins = allowedOrigins
 | 
			
		||||
            self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
 | 
			
		||||
 | 
			
		||||
        if maxConnections is not None and maxConnections != self.maxConnections:
 | 
			
		||||
            assert(type(maxConnections) in six.integer_types)
 | 
			
		||||
            assert(maxConnections >= 0)
 | 
			
		||||
            self.maxConnections = maxConnections
 | 
			
		||||
 | 
			
		||||
    def getConnectionCount(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get number of currently connected clients.
 | 
			
		||||
 
 | 
			
		||||
@@ -60,6 +60,7 @@ if __name__ == '__main__':
 | 
			
		||||
 | 
			
		||||
    factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
 | 
			
		||||
    factory.protocol = MyServerProtocol
 | 
			
		||||
    # factory.setProtocolOptions(maxConnections=2)
 | 
			
		||||
 | 
			
		||||
    reactor.listenTCP(9000, factory)
 | 
			
		||||
    reactor.run()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user