Add allowNullOrigin and fix allowed-origin checks

This commit is contained in:
meejah 2016-06-15 12:42:43 -06:00
parent 1335725243
commit 2ef13a6804
7 changed files with 337 additions and 15 deletions

View File

@ -29,11 +29,13 @@ from __future__ import absolute_import, print_function
import unittest2 as unittest
from mock import Mock
from autobahn.util import wildcards2patterns
from autobahn.twisted.websocket import WebSocketServerFactory
from autobahn.twisted.websocket import WebSocketServerProtocol
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone, ConnectionAborted, \
ConnectionLost
from twisted.test.proto_helpers import StringTransport
from autobahn.test import FakeTransport
@ -120,3 +122,197 @@ class Hixie76RejectionTests(unittest.TestCase):
p.processHandshake()
self.assertIn(b"HTTP/1.1 400", t._written)
self.assertIn(b"Hixie76 protocol not supported", t._written)
class WebSocketOriginMatching(unittest.TestCase):
"""
Test that we match Origin: headers properly, when asked to
"""
def setUp(self):
self.factory = WebSocketServerFactory()
self.factory.setProtocolOptions(
allowedOrigins=[u'127.0.0.1:*', u'*.example.com:*']
)
self.proto = WebSocketServerProtocol()
self.proto.transport = StringTransport()
self.proto.factory = self.factory
self.proto.failHandshake = Mock()
self.proto._connectionMade()
def tearDown(self):
for call in [
self.proto.autoPingPendingCall,
self.proto.autoPingTimeoutCall,
self.proto.openHandshakeTimeoutCall,
self.proto.closeHandshakeTimeoutCall,
]:
if call is not None:
call.cancel()
def test_match_full_origin(self):
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: http://www.example.com.malicious.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
arg = self.proto.failHandshake.mock_calls[0][1][0]
self.assertTrue('not allowed' in arg)
def test_match_wrong_scheme_origin(self):
# some monkey-business since we already did this in setUp, but
# we want a different set of matching origins
self.factory.setProtocolOptions(
allowedOrigins=[u'http://*.example.com:*']
)
self.proto.allowedOriginsPatterns = self.factory.allowedOriginsPatterns
self.proto.allowedOrigins = self.factory.allowedOrigins
# the actual test
self.factory.isSecure = False
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: https://www.example.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
arg = self.proto.failHandshake.mock_calls[0][1][0]
self.assertTrue('not allowed' in arg)
def test_match_origin_secure_scheme(self):
self.factory.isSecure = True
self.factory.port = 443
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: https://www.example.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")
def test_match_origin_documentation_example(self):
"""
Test the examples from the docs
"""
self.factory.setProtocolOptions(
allowedOrigins=['*://*.example.com:*']
)
self.factory.isSecure = True
self.factory.port = 443
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: http://www.example.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")
def test_match_origin_examples(self):
"""
All the example origins from RFC6454 (3.2.1)
"""
# we're just testing the low-level function here...
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
policy = wildcards2patterns(['*example.com:*'])
# should parametrize test ...
for url in ['http://example.com/', 'http://example.com:80/',
'http://example.com/path/file',
'http://example.com/;semi=true',
# 'http://example.com./',
'//example.com/',
'http://@example.com']:
self.assertTrue(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)
def test_match_origin_counter_examples(self):
"""
All the example 'not-same' origins from RFC6454 (3.2.1)
"""
# we're just testing the low-level function here...
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
policy = wildcards2patterns(['example.com'])
for url in ['http://ietf.org/', 'http://example.org/',
'https://example.com/', 'http://example.com:8080/',
'http://www.example.com/']:
self.assertFalse(_is_same_origin(_url_to_origin(url), 'http', 80, policy))
def test_match_origin_edge(self):
# we're just testing the low-level function here...
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
policy = wildcards2patterns(['http://*example.com:80'])
self.assertTrue(
_is_same_origin(_url_to_origin('http://example.com:80'), 'http', 80, policy)
)
self.assertFalse(
_is_same_origin(_url_to_origin('http://example.com:81'), 'http', 81, policy)
)
self.assertFalse(
_is_same_origin(_url_to_origin('https://example.com:80'), 'http', 80, policy)
)
def test_origin_from_url(self):
from autobahn.websocket.protocol import _url_to_origin
# basic function
self.assertEqual(
_url_to_origin('http://example.com'),
('http', 'example.com', 80)
)
# should lower-case scheme
self.assertEqual(
_url_to_origin('hTTp://example.com'),
('http', 'example.com', 80)
)
def test_origin_file(self):
from autobahn.websocket.protocol import _url_to_origin
self.assertEqual('null', _url_to_origin('file:///etc/passwd'))
def test_origin_null(self):
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
self.assertEqual('null', _url_to_origin('null'))
self.assertFalse(
_is_same_origin(_url_to_origin('null'), 'http', 80, [])
)
self.assertFalse(
_is_same_origin(_url_to_origin('null'), 'https', 80, [])
)
self.assertFalse(
_is_same_origin(_url_to_origin('null'), '', 80, [])
)
self.assertFalse(
_is_same_origin(_url_to_origin('null'), None, 80, [])
)

View File

@ -637,10 +637,15 @@ def wildcards2patterns(wildcards):
:param wildcards: List of wildcard strings to compute regular expression patterns for.
:type wildcards: list of str
:returns: Computed regular expressions.
:rtype: list of obj
"""
return [re.compile(wc.replace('.', '\.').replace('*', '.*')) for wc in wildcards]
# note that we add the ^ and $ so that the *entire* string must
# match. Without this, e.g. a prefix will match:
# re.match('.*good\\.com', 'good.com.evil.com') # match!
# re.match('.*good\\.com$', 'good.com.evil.com') # no match!
return [re.compile('^' + wc.replace('.', '\.').replace('*', '.*') + '$') for wc in wildcards]
class ObservableMixin(object):

View File

@ -74,6 +74,83 @@ __all__ = ("ConnectionRequest",
"WebSocketClientFactory")
def _url_to_origin(url):
"""
Given an RFC6455 Origin URL, this returns the (scheme, host, port)
triple. If there's no port, and the scheme isn't http or https
then port will be None
"""
if url.lower() == 'null':
return 'null'
res = urllib.parse.urlsplit(url)
scheme = res.scheme.lower()
if scheme == 'file':
# when browsing local files, Chrome sends file:// URLs,
# Firefox sends 'null'
return 'null'
host = res.hostname
port = res.port
if port is None:
try:
port = {'https': 443, 'http': 80}[scheme]
except KeyError:
port = None
if not host:
raise ValueError("No host part in Origin '{}'".format(url))
return scheme, host, port
def _is_same_origin(websocket_origin, host_scheme, host_port, host_policy):
"""
Internal helper. Returns True if the provided websocket_origin
triple should be considered valid given the provided policy and
expected host_port.
Currently, the policy is just the list of allowedOriginsPatterns
from the WebSocketProtocol instance. Schemes and ports are matched
first, and only if there is not a mismatch do we compare each
allowed-origin pattern against the host.
"""
if websocket_origin == 'null':
# nothing is the same as the null origin
return False
if not isinstance(websocket_origin, tuple) or not len(websocket_origin) == 3:
raise ValueError("'websocket_origin' must be a 3-tuple")
(origin_scheme, origin_host, origin_port) = websocket_origin
# so, theoretically we should match on the 3-tuple of (scheme,
# origin, port) to follow the RFC. However, the existing API just
# allows you to pass a list of regular expressions that match
# against the Origin header -- so to keep that API working, we
# just match a reconstituted/sanitized Origin line against the
# regular expressions. We *do* explicitly match the scheme first,
# however!
# therefore, the default of "*" will still match everything (even
# if things are on weird ports). To be "actually secure" and pass
# explicit ports, you can put it in your matcher
# (e.g. "https://*.example.com:1234")
template = '{scheme}://{host}:{port}'
origin_header = template.format(
scheme=origin_scheme,
host=origin_host,
port=origin_port,
)
# so, this will be matching against e.g. "http://example.com:8080"
for origin_pattern in host_policy:
if origin_pattern.match(origin_header):
return True
return False
class TrafficStats(object):
def __init__(self):
@ -453,6 +530,7 @@ class WebSocketProtocol(object):
'flashSocketPolicy',
'allowedOrigins',
'allowedOriginsPatterns',
'allowNullOrigin',
'maxConnections']
"""
Configuration attributes specific to servers.
@ -2577,19 +2655,32 @@ class WebSocketServerProtocol(WebSocketProtocol):
if http_headers_cnt[websocket_origin_header_key] > 1:
return self.failHandshake("HTTP Origin header appears more than once in opening handshake request")
self.websocket_origin = self.http_headers[websocket_origin_header_key].strip()
try:
origin_tuple = _url_to_origin(self.websocket_origin)
except ValueError as e:
return self.failHandshake(
"HTTP Origin header invalid: {}".format(e)
)
have_origin = True
else:
# non-browser clients are allowed to omit this header
pass
have_origin = False
# check allowed WebSocket origins
#
origin_is_allowed = False
for origin_pattern in self.allowedOriginsPatterns:
if origin_pattern.match(self.websocket_origin):
if have_origin:
if origin_tuple == 'null' and self.factory.allowNullOrigin:
origin_is_allowed = True
break
if not origin_is_allowed:
return self.failHandshake("WebSocket connection denied: origin '{0}' not allowed".format(self.websocket_origin))
else:
origin_is_allowed = _is_same_origin(
origin_tuple,
'https' if self.factory.isSecure else 'http',
self.factory.externalPort or self.factory.port,
self.allowedOriginsPatterns,
)
if not origin_is_allowed:
return self.failHandshake(
"WebSocket connection denied: origin '{0}' "
"not allowed".format(self.websocket_origin)
)
# Sec-WebSocket-Key
#
@ -3079,6 +3170,7 @@ class WebSocketServerFactory(WebSocketFactory):
# check WebSocket origin against this list
self.allowedOrigins = ["*"]
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
self.allowNullOrigin = True
# maximum number of concurrent connections
self.maxConnections = 0
@ -3105,6 +3197,7 @@ class WebSocketServerFactory(WebSocketFactory):
serveFlashSocketPolicy=None,
flashSocketPolicy=None,
allowedOrigins=None,
allowNullOrigin=False,
maxConnections=None):
"""
Set WebSocket protocol options used as defaults for new protocol instances.
@ -3152,8 +3245,13 @@ class WebSocketServerFactory(WebSocketFactory):
:param flashSocketPolicy: The flash socket policy to be served when we are serving the Flash Socket Policy on this protocol
and when Flash tried to connect to the destination port. It must end with a null character.
:type flashSocketPolicy: str or None
:param allowedOrigins: A list of allowed WebSocket origins (with '*' as a wildcard character).
:type allowedOrigins: list or None
:param allowNullOrigin: if True, allow WebSocket connections whose `Origin:` is `"null"`.
:type allowNullOrigin: bool
:param maxConnections: Maximum number of concurrent connections. Set to `0` to disable (default: `0`).
:type maxConnections: int or None
"""
@ -3227,6 +3325,10 @@ class WebSocketServerFactory(WebSocketFactory):
self.allowedOrigins = allowedOrigins
self.allowedOriginsPatterns = wildcards2patterns(self.allowedOrigins)
if allowNullOrigin not in [True, False]:
raise ValueError('allowNullOrigin must be a bool')
self.allowNullOrigin = allowNullOrigin
if maxConnections is not None and maxConnections != self.maxConnections:
assert(type(maxConnections) in six.integer_types)
assert(maxConnections >= 0)

View File

@ -158,6 +158,8 @@ if os.environ.get('USE_TWISTED', False):
self.proto.connectionMade()
def tearDown(self):
if self.proto.openHandshakeTimeoutCall:
self.proto.openHandshakeTimeoutCall.cancel()
self.factory.doStop()
# not really necessary, but ...
del self.factory

View File

@ -5,6 +5,11 @@
Changelog
=========
0.14.2
------
* fix: `#691 <https://github.com/crossbario/autobahn-python/issues/691>`_ (**security**) If the `allowedOrigins` websocket option was set, the resulting matching was insufficient and would allow more origins than intended
0.14.1
------

View File

@ -527,7 +527,7 @@ Server-Only Options
- perMessageCompressionAccept: if provided, a single-argument callable
- serveFlashSocketPolicy: if True, server a flash policy file (default: False)
- flashSocketPolicy: the actual flash policy to serve (default one allows everything)
- allowedOrigins: a list of origins to allow, with embedded `*`'s for wildcards; these are turned into regular expressions (e.g. `*.example.com` becomes `^.*\.example\.com$`)
- allowedOrigins: a list of origins to allow, with embedded `*`'s for wildcards; these are turned into regular expressions (e.g. `https://*.example.com:443` becomes `^https://.*\.example\.com:443$`). When doing the matching, the origin is **always** of the form `scheme://host:port` with an explicit port. By default, we match with `*` (that is, anything). To match all subdomains of `example.com` on any scheme and port, you'd need `*://*.example.com:*`
- maxConnections: total concurrent connections allowed (default 0, unlimited)

View File

@ -27,10 +27,11 @@
import sys
from twisted.internet import reactor, ssl
from twisted.python import log
from twisted.web.server import Site
from twisted.web.static import File
import txaio
from autobahn.twisted.websocket import WebSocketServerFactory, \
WebSocketServerProtocol, \
listenWS
@ -44,7 +45,7 @@ class EchoServerProtocol(WebSocketServerProtocol):
if __name__ == '__main__':
log.startLogging(sys.stdout)
txaio.start_logging(level='debug')
# SSL server context: load server key and certificate
# We use this for both WS and Web!
@ -52,6 +53,17 @@ if __name__ == '__main__':
'keys/server.crt')
factory = WebSocketServerFactory(u"wss://127.0.0.1:9000")
# by default, allowedOrigins is "*" and will work fine out of the
# box, but we can do better and be more-explicit about what we
# allow. We are serving the Web content on 8080, but our WebSocket
# listener is on 9000 so the Origin sent by the browser will be
# from port 8080...
factory.setProtocolOptions(
allowedOrigins=[
"https://127.0.0.1:8080",
"https://localhost:8080",
]
)
factory.protocol = EchoServerProtocol
listenWS(factory, contextFactory)
@ -59,7 +71,7 @@ if __name__ == '__main__':
webdir = File(".")
webdir.contentTypes['.crt'] = 'application/x-x509-ca-cert'
web = Site(webdir)
# reactor.listenSSL(8080, web, contextFactory)
reactor.listenTCP(8080, web)
reactor.listenSSL(8080, web, contextFactory)
#reactor.listenTCP(8080, web)
reactor.run()