Add allowNullOrigin and fix allowed-origin checks
This commit is contained in:
parent
1335725243
commit
2ef13a6804
@ -29,11 +29,13 @@ from __future__ import absolute_import, print_function
|
||||
import unittest2 as unittest
|
||||
from mock import Mock
|
||||
|
||||
from autobahn.util import wildcards2patterns
|
||||
from autobahn.twisted.websocket import WebSocketServerFactory
|
||||
from autobahn.twisted.websocket import WebSocketServerProtocol
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.internet.error import ConnectionDone, ConnectionAborted, \
|
||||
ConnectionLost
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from autobahn.test import FakeTransport
|
||||
|
||||
|
||||
@ -120,3 +122,197 @@ class Hixie76RejectionTests(unittest.TestCase):
|
||||
p.processHandshake()
|
||||
self.assertIn(b"HTTP/1.1 400", t._written)
|
||||
self.assertIn(b"Hixie76 protocol not supported", t._written)
|
||||
|
||||
|
||||
class WebSocketOriginMatching(unittest.TestCase):
|
||||
"""
|
||||
Test that we match Origin: headers properly, when asked to
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.factory = WebSocketServerFactory()
|
||||
self.factory.setProtocolOptions(
|
||||
allowedOrigins=[u'127.0.0.1:*', u'*.example.com:*']
|
||||
)
|
||||
self.proto = WebSocketServerProtocol()
|
||||
self.proto.transport = StringTransport()
|
||||
self.proto.factory = self.factory
|
||||
self.proto.failHandshake = Mock()
|
||||
self.proto._connectionMade()
|
||||
|
||||
def tearDown(self):
|
||||
for call in [
|
||||
self.proto.autoPingPendingCall,
|
||||
self.proto.autoPingTimeoutCall,
|
||||
self.proto.openHandshakeTimeoutCall,
|
||||
self.proto.closeHandshakeTimeoutCall,
|
||||
]:
|
||||
if call is not None:
|
||||
call.cancel()
|
||||
|
||||
def test_match_full_origin(self):
|
||||
self.proto.data = b"\r\n".join([
|
||||
b'GET /ws HTTP/1.1',
|
||||
b'Host: www.example.com',
|
||||
b'Sec-WebSocket-Version: 13',
|
||||
b'Origin: http://www.example.com.malicious.com',
|
||||
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||
b'Connection: keep-alive, Upgrade',
|
||||
b'Upgrade: websocket',
|
||||
b'\r\n', # last string doesn't get a \r\n from join()
|
||||
])
|
||||
self.proto.consumeData()
|
||||
|
||||
self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
|
||||
arg = self.proto.failHandshake.mock_calls[0][1][0]
|
||||
self.assertTrue('not allowed' in arg)
|
||||
|
||||
def test_match_wrong_scheme_origin(self):
|
||||
# some monkey-business since we already did this in setUp, but
|
||||
# we want a different set of matching origins
|
||||
self.factory.setProtocolOptions(
|
||||
allowedOrigins=[u'http://*.example.com:*']
|
||||
)
|
||||
self.proto.allowedOriginsPatterns = self.factory.allowedOriginsPatterns
|
||||
self.proto.allowedOrigins = self.factory.allowedOrigins
|
||||
|
||||
# the actual test
|
||||
self.factory.isSecure = False
|
||||
self.proto.data = b"\r\n".join([
|
||||
b'GET /ws HTTP/1.1',
|
||||
b'Host: www.example.com',
|
||||
b'Sec-WebSocket-Version: 13',
|
||||
b'Origin: https://www.example.com',
|
||||
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||
b'Connection: keep-alive, Upgrade',
|
||||
b'Upgrade: websocket',
|
||||
b'\r\n', # last string doesn't get a \r\n from join()
|
||||
])
|
||||
self.proto.consumeData()
|
||||
|
||||
self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
|
||||
arg = self.proto.failHandshake.mock_calls[0][1][0]
|
||||
self.assertTrue('not allowed' in arg)
|
||||
|
||||
def test_match_origin_secure_scheme(self):
|
||||
self.factory.isSecure = True
|
||||
self.factory.port = 443
|
||||
self.proto.data = b"\r\n".join([
|
||||
b'GET /ws HTTP/1.1',
|
||||
b'Host: www.example.com',
|
||||
b'Sec-WebSocket-Version: 13',
|
||||
b'Origin: https://www.example.com',
|
||||
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||
b'Connection: keep-alive, Upgrade',
|
||||
b'Upgrade: websocket',
|
||||
b'\r\n', # last string doesn't get a \r\n from join()
|
||||
])
|
||||
self.proto.consumeData()
|
||||
|
||||
self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")
|
||||
|
||||
def test_match_origin_documentation_example(self):
|
||||
"""
|
||||
Test the examples from the docs
|
||||
"""
|
||||
self.factory.setProtocolOptions(
|
||||
allowedOrigins=['*://*.example.com:*']
|
||||
)
|
||||
self.factory.isSecure = True
|
||||
self.factory.port = 443
|
||||
self.proto.data = b"\r\n".join([
|
||||
b'GET /ws HTTP/1.1',
|
||||
b'Host: www.example.com',
|
||||
b'Sec-WebSocket-Version: 13',
|
||||
b'Origin: http://www.example.com',
|
||||
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||
b'Connection: keep-alive, Upgrade',
|
||||
b'Upgrade: websocket',
|
||||
b'\r\n', # last string doesn't get a \r\n from join()
|
||||
])
|
||||
self.proto.consumeData()
|
||||
|
||||
self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")
|
||||
|
||||
def test_match_origin_examples(self):
|
||||
"""
|
||||
All the example origins from RFC6454 (3.2.1)
|
||||
"""
|
||||
# we're just testing the low-level function here...
|
||||
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
|
||||
policy = wildcards2patterns(['*example.com:*'])
|
||||
|
||||
# should parametrize test ...
|
||||
for url in ['http://example.com/', 'http://example.com:80/',
|
||||
'http://example.com/path/file',
|
||||
'http://example.com/;semi=true',
|
||||
# 'http://example.com./',
|
||||
'//example.com/',
|
||||
'http://@example.com']:
|
||||
self.assertTrue(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)
|
||||
|
||||
def test_match_origin_counter_examples(self):
|
||||
"""
|
||||
All the example 'not-same' origins from RFC6454 (3.2.1)
|
||||
"""
|
||||
# we're just testing the low-level function here...
|
||||
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
|
||||
policy = wildcards2patterns(['example.com'])
|
||||
|
||||
for url in ['http://ietf.org/', 'http://example.org/',
|
||||
'https://example.com/', 'http://example.com:8080/',
|
||||
'http://www.example.com/']:
|
||||
self.assertFalse(_is_same_origin(_url_to_origin(url), 'http', 80, policy))
|
||||
|
||||
def test_match_origin_edge(self):
|
||||
# we're just testing the low-level function here...
|
||||
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
|
||||
policy = wildcards2patterns(['http://*example.com:80'])
|
||||
|
||||
self.assertTrue(
|
||||
_is_same_origin(_url_to_origin('http://example.com:80'), 'http', 80, policy)
|
||||
)
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('http://example.com:81'), 'http', 81, policy)
|
||||
)
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('https://example.com:80'), 'http', 80, policy)
|
||||
)
|
||||
|
||||
def test_origin_from_url(self):
|
||||
from autobahn.websocket.protocol import _url_to_origin
|
||||
|
||||
# basic function
|
||||
self.assertEqual(
|
||||
_url_to_origin('http://example.com'),
|
||||
('http', 'example.com', 80)
|
||||
)
|
||||
# should lower-case scheme
|
||||
self.assertEqual(
|
||||
_url_to_origin('hTTp://example.com'),
|
||||
('http', 'example.com', 80)
|
||||
)
|
||||
|
||||
def test_origin_file(self):
|
||||
from autobahn.websocket.protocol import _url_to_origin
|
||||
self.assertEqual('null', _url_to_origin('file:///etc/passwd'))
|
||||
|
||||
def test_origin_null(self):
|
||||
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
|
||||
self.assertEqual('null', _url_to_origin('null'))
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('null'), 'http', 80, [])
|
||||
)
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('null'), 'https', 80, [])
|
||||
)
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('null'), '', 80, [])
|
||||
)
|
||||
self.assertFalse(
|
||||
_is_same_origin(_url_to_origin('null'), None, 80, [])
|
||||
)
|
||||
|
@ -637,10 +637,15 @@ def wildcards2patterns(wildcards):
|
||||
|
||||
:param wildcards: List of wildcard strings to compute regular expression patterns for.
|
||||
:type wildcards: list of str
|
||||
|
||||
:returns: Computed regular expressions.
|
||||
:rtype: list of obj
|
||||
"""
|
||||
return [re.compile(wc.replace('.', '\.').replace('*', '.*')) for wc in wildcards]
|
||||
# note that we add the ^ and $ so that the *entire* string must
|
||||
# match. Without this, e.g. a prefix will match:
|
||||
# re.match('.*good\\.com', 'good.com.evil.com') # match!
|
||||
# re.match('.*good\\.com$', 'good.com.evil.com') # no match!
|
||||
return [re.compile('^' + wc.replace('.', '\.').replace('*', '.*') + '$') for wc in wildcards]
|
||||
|
||||
|
||||
class ObservableMixin(object):
|
||||
|
@ -74,6 +74,83 @@ __all__ = ("ConnectionRequest",
|
||||
"WebSocketClientFactory")
|
||||
|
||||
|
||||
def _url_to_origin(url):
|
||||
"""
|
||||
Given an RFC6455 Origin URL, this returns the (scheme, host, port)
|
||||
triple. If there's no port, and the scheme isn't http or https
|
||||
then port will be None
|
||||
"""
|
||||
if url.lower() == 'null':
|
||||
return 'null'
|
||||
|
||||
res = urllib.parse.urlsplit(url)
|
||||
scheme = res.scheme.lower()
|
||||
if scheme == 'file':
|
||||
# when browsing local files, Chrome sends file:// URLs,
|
||||
# Firefox sends 'null'
|
||||
return 'null'
|
||||
|
||||
host = res.hostname
|
||||
port = res.port
|
||||
if port is None:
|
||||
try:
|
||||
port = {'https': 443, 'http': 80}[scheme]
|
||||
except KeyError:
|
||||
port = None
|
||||
|
||||
if not host:
|
||||
raise ValueError("No host part in Origin '{}'".format(url))
|
||||
return scheme, host, port
|
||||
|
||||
|
||||
def _is_same_origin(websocket_origin, host_scheme, host_port, host_policy):
|
||||
"""
|
||||
Internal helper. Returns True if the provided websocket_origin
|
||||
triple should be considered valid given the provided policy and
|
||||
expected host_port.
|
||||
|
||||
Currently, the policy is just the list of allowedOriginsPatterns
|
||||
from the WebSocketProtocol instance. Schemes and ports are matched
|
||||
first, and only if there is not a mismatch do we compare each
|
||||
allowed-origin pattern against the host.
|
||||
"""
|
||||
|
||||
if websocket_origin == 'null':
|
||||
# nothing is the same as the null origin
|
||||
return False
|
||||
|
||||
if not isinstance(websocket_origin, tuple) or not len(websocket_origin) == 3:
|
||||
raise ValueError("'websocket_origin' must be a 3-tuple")
|
||||
|
||||
(origin_scheme, origin_host, origin_port) = websocket_origin
|
||||
|
||||
# so, theoretically we should match on the 3-tuple of (scheme,
|
||||
# origin, port) to follow the RFC. However, the existing API just
|
||||
# allows you to pass a list of regular expressions that match
|
||||
# against the Origin header -- so to keep that API working, we
|
||||
# just match a reconstituted/sanitized Origin line against the
|
||||
# regular expressions. We *do* explicitly match the scheme first,
|
||||
# however!
|
||||
|
||||
# therefore, the default of "*" will still match everything (even
|
||||
# if things are on weird ports). To be "actually secure" and pass
|
||||
# explicit ports, you can put it in your matcher
|
||||
# (e.g. "https://*.example.com:1234")
|
||||
|
||||
template = '{scheme}://{host}:{port}'
|
||||
origin_header = template.format(
|
||||
scheme=origin_scheme,
|
||||
host=origin_host,
|
||||
port=origin_port,
|
||||
)
|
||||
# so, this will be matching against e.g. "http://example.com:8080"
|
||||
for origin_pattern in host_policy:
|
||||
if origin_pattern.match(origin_header):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TrafficStats(object):
|
||||
|
||||
def __init__(self):
|
||||
@ -453,6 +530,7 @@ class WebSocketProtocol(object):
|
||||
'flashSocketPolicy',
|
||||
'allowedOrigins',
|
||||
'allowedOriginsPatterns',
|
||||
'allowNullOrigin',
|
||||
'maxConnections']
|
||||
"""
|
||||
Configuration attributes specific to servers.
|
||||
@ -2577,19 +2655,32 @@ class WebSocketServerProtocol(WebSocketProtocol):
|
||||
if http_headers_cnt[websocket_origin_header_key] > 1:
|
||||
return self.failHandshake("HTTP Origin header appears more than once in opening handshake request")
|
||||
self.websocket_origin = self.http_headers[websocket_origin_header_key].strip()
|
||||
try:
|
||||
origin_tuple = _url_to_origin(self.websocket_origin)
|
||||
except ValueError as e:
|
||||
return self.failHandshake(
|
||||
"HTTP Origin header invalid: {}".format(e)
|
||||
)
|
||||
have_origin = True
|
||||
else:
|
||||
# non-browser clients are allowed to omit this header
|
||||
pass
|
||||
have_origin = False
|
||||
|
||||
# check allowed WebSocket origins
|
||||
#
|
||||
origin_is_allowed = False
|
||||
for origin_pattern in self.allowedOriginsPatterns:
|
||||
if origin_pattern.match(self.websocket_origin):
|
||||
if have_origin:
|
||||
if origin_tuple == 'null' and self.factory.allowNullOrigin:
|
||||
origin_is_allowed = True
|
||||
break
|
||||
if not origin_is_allowed:
|
||||
return self.failHandshake("WebSocket connection denied: origin '{0}' not allowed".format(self.websocket_origin))
|
||||
else:
|
||||
origin_is_allowed = _is_same_origin(
|
||||
origin_tuple,
|
||||
'https' if self.factory.isSecure else 'http',
|
||||
self.factory.externalPort or self.factory.port,
|
||||
self.allowedOriginsPatterns,
|
||||
)
|
||||
if not origin_is_allowed:
|
||||
return self.failHandshake(
|
||||
"WebSocket connection denied: origin '{0}' "
|
||||
"not allowed".format(self.websocket_origin)
|
||||
)
|
||||
|
||||
# Sec-WebSocket-Key
|
||||
#
|
||||
@ -3079,6 +3170,7 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
# check WebSocket origin against this list
|
||||
self.allowedOrigins = ["*"]
|
||||
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
||||
self.allowNullOrigin = True
|
||||
|
||||
# maximum number of concurrent connections
|
||||
self.maxConnections = 0
|
||||
@ -3105,6 +3197,7 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
serveFlashSocketPolicy=None,
|
||||
flashSocketPolicy=None,
|
||||
allowedOrigins=None,
|
||||
allowNullOrigin=False,
|
||||
maxConnections=None):
|
||||
"""
|
||||
Set WebSocket protocol options used as defaults for new protocol instances.
|
||||
@ -3152,8 +3245,13 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
:param flashSocketPolicy: The flash socket policy to be served when we are serving the Flash Socket Policy on this protocol
|
||||
and when Flash tried to connect to the destination port. It must end with a null character.
|
||||
:type flashSocketPolicy: str or None
|
||||
|
||||
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
|
||||
:type allowedOrigins: list or None
|
||||
|
||||
:param allowNullOrigin: if True, allow WebSocket connections whose `Origin:` is `"null"`.
|
||||
:type allowNullOrigin: bool
|
||||
|
||||
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
|
||||
:type maxConnections: int or None
|
||||
"""
|
||||
@ -3227,6 +3325,10 @@ class WebSocketServerFactory(WebSocketFactory):
|
||||
self.allowedOrigins = allowedOrigins
|
||||
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
|
||||
|
||||
if allowNullOrigin not in [True, False]:
|
||||
raise ValueError('allowNullOrigin must be a bool')
|
||||
self.allowNullOrigin = allowNullOrigin
|
||||
|
||||
if maxConnections is not None and maxConnections != self.maxConnections:
|
||||
assert(type(maxConnections) in six.integer_types)
|
||||
assert(maxConnections >= 0)
|
||||
|
@ -158,6 +158,8 @@ if os.environ.get('USE_TWISTED', False):
|
||||
self.proto.connectionMade()
|
||||
|
||||
def tearDown(self):
|
||||
if self.proto.openHandshakeTimeoutCall:
|
||||
self.proto.openHandshakeTimeoutCall.cancel()
|
||||
self.factory.doStop()
|
||||
# not really necessary, but ...
|
||||
del self.factory
|
||||
|
@ -5,6 +5,11 @@
|
||||
Changelog
|
||||
=========
|
||||
|
||||
0.14.2
|
||||
------
|
||||
|
||||
* fix: `#691 <https://github.com/crossbario/autobahn-python/issues/691>`_ (**security**) If the `allowedOrigins` websocket option was set, the resulting matching was insufficient and would allow more origins than intended
|
||||
|
||||
0.14.1
|
||||
------
|
||||
|
||||
|
@ -527,7 +527,7 @@ Server-Only Options
|
||||
- perMessageCompressionAccept: if provided, a single-argument callable
|
||||
- serveFlashSocketPolicy: if True, server a flash policy file (default: False)
|
||||
- flashSocketPolicy: the actual flash policy to serve (default one allows everything)
|
||||
- allowedOrigins: a list of origins to allow, with embedded `*`'s for wildcards; these are turned into regular expressions (e.g. `*.example.com` becomes `^.*\.example\.com$`)
|
||||
- allowedOrigins: a list of origins to allow, with embedded `*`'s for wildcards; these are turned into regular expressions (e.g. `https://*.example.com:443` becomes `^https://.*\.example\.com:443$`). When doing the matching, the origin is **always** of the form `scheme://host:port` with an explicit port. By default, we match with `*` (that is, anything). To match all subdomains of `example.com` on any scheme and port, you'd need `*://*.example.com:*`
|
||||
- maxConnections: total concurrent connections allowed (default 0, unlimited)
|
||||
|
||||
|
||||
|
@ -27,10 +27,11 @@
|
||||
import sys
|
||||
|
||||
from twisted.internet import reactor, ssl
|
||||
from twisted.python import log
|
||||
from twisted.web.server import Site
|
||||
from twisted.web.static import File
|
||||
|
||||
import txaio
|
||||
|
||||
from autobahn.twisted.websocket import WebSocketServerFactory, \
|
||||
WebSocketServerProtocol, \
|
||||
listenWS
|
||||
@ -44,7 +45,7 @@ class EchoServerProtocol(WebSocketServerProtocol):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
log.startLogging(sys.stdout)
|
||||
txaio.start_logging(level='debug')
|
||||
|
||||
# SSL server context: load server key and certificate
|
||||
# We use this for both WS and Web!
|
||||
@ -52,6 +53,17 @@ if __name__ == '__main__':
|
||||
'keys/server.crt')
|
||||
|
||||
factory = WebSocketServerFactory(u"wss://127.0.0.1:9000")
|
||||
# by default, allowedOrigins is "*" and will work fine out of the
|
||||
# box, but we can do better and be more-explicit about what we
|
||||
# allow. We are serving the Web content on 8080, but our WebSocket
|
||||
# listener is on 9000 so the Origin sent by the browser will be
|
||||
# from port 8080...
|
||||
factory.setProtocolOptions(
|
||||
allowedOrigins=[
|
||||
"https://127.0.0.1:8080",
|
||||
"https://localhost:8080",
|
||||
]
|
||||
)
|
||||
|
||||
factory.protocol = EchoServerProtocol
|
||||
listenWS(factory, contextFactory)
|
||||
@ -59,7 +71,7 @@ if __name__ == '__main__':
|
||||
webdir = File(".")
|
||||
webdir.contentTypes['.crt'] = 'application/x-x509-ca-cert'
|
||||
web = Site(webdir)
|
||||
# reactor.listenSSL(8080, web, contextFactory)
|
||||
reactor.listenTCP(8080, web)
|
||||
reactor.listenSSL(8080, web, contextFactory)
|
||||
#reactor.listenTCP(8080, web)
|
||||
|
||||
reactor.run()
|
||||
|
Loading…
Reference in New Issue
Block a user