Looks like it works

This commit is contained in:
izderadicka 2016-05-21 10:16:49 +02:00
parent 1c86fefa73
commit 5a72a46275
5 changed files with 829 additions and 0 deletions

4
.gitignore vendored
View File

@ -30,3 +30,7 @@ htmlcov/
*.pid
.cache/
node.key
/.project
/.pydevproject
/.venv/
/my_tests/

View File

@ -0,0 +1,440 @@
#from __future__ import absolute_import
import asyncio #TODO: should we support trollious - it has been deprecated - http://trollius.readthedocs.io/deprecated.htm
import struct
import math
from autobahn.asyncio.util import _LazyHexFormatter
from autobahn.wamp.exception import ProtocolError, SerializationError, TransportLost
from autobahn.asyncio.util import peer2str, transport_channel_id, get_serializes
import txaio
__all__ = (
'WampRawSocketServerProtocol',
'WampRawSocketClientProtocol',
'WampRawSocketServerFactory',
'WampRawSocketClientFactory'
)
txaio.use_asyncio()
FRAME_TYPE_DATA=0
FRAME_TYPE_PING=1
FRAME_TYPE_PONG=2
class PrefixProtocol(asyncio.Protocol):
prefix_format='!L'
prefix_length=struct.calcsize(prefix_format)
max_length = 16 * 1024 *1024
max_length_send = max_length
log=txaio.make_logger() # @UndefinedVariable
def connection_made(self, transport):
self.transport=transport
peer = transport.get_extra_info('peername')
self.peer = peer2str(peer)
self.log.debug('RawSocker Asyncio: Connection made with peer {peer}'.format(peer=self.peer))
self._buffer=b''
self._header=None
def connection_lost(self, exc):
self.log.debug('RawSocker Asyncio: Connection lost')
self.transport=None
def protocol_error(self, msg):
self.log.error(msg)
self.transport.close()
def sendString(self, data):
l=len(data)
if l>self.max_length_send:
raise ValueError('Data too big')
header=struct.pack(self.prefix_format, len(data))
self.transport.write(header)
self.transport.write(data)
def ping(self, data):
raise NotImplementedError()
def pong(self,data):
raise NotImplementedError()
def data_received(self, data):
self._buffer+=data
pos=0
remaining=len(self._buffer)
while remaining >= self.prefix_length:
# do not recalculate header if available from previous call
if self._header:
frame_type, frame_length = self._header
else:
header=self._buffer[pos:pos+self.prefix_length]
frame_type=header[0]& 0b00000111
if frame_type > FRAME_TYPE_PONG:
self.protocol_error('Invalid frame type')
return
frame_length= struct.unpack(self.prefix_format, b'\0'+header[1:])[0]
if frame_length> self.max_length:
self.protocol_error('Frame too big')
return
if remaining-self.prefix_length >= frame_length:
self._header=None
pos+=self.prefix_length
remaining-=self.prefix_length
data=self._buffer[pos:pos+frame_length]
pos+=frame_length
remaining-=frame_length
if frame_type == FRAME_TYPE_DATA:
self.stringReceived(data)
elif frame_type == FRAME_TYPE_PING:
self.ping(data)
elif frame_type == FRAME_TYPE_PONG:
self.pong(data)
else:
# save heaader
self._header = frame_type, frame_length
break
self._buffer=self._buffer[:remaining]
def stringReceived(self, data):
raise NotImplementedError()
class RawSocketProtocol(PrefixProtocol):
def __init__(self, max_size=None):
if max_size:
exp=math.ceil(math.log2(max_size))-9
if exp>15:
raise ValueError('Maximum length is 16M')
self.max_length=2**(exp+9)
self._length_exp=exp
else:
self._length_exp=15
self.max_length=2**24
def connection_made(self, transport):
PrefixProtocol.connection_made(self, transport)
self._handshake_done=False
def _on_handshake_complete(self):
raise NotImplementedError()
def parse_handshake(self):
if self._buffer[0] != 0x7F:
raise HandshakeError('Invalid magic byte in handshake')
return
b1=self._buffer[1]
ser=b1 & 0x0F
lexp=b1>>4
self.max_length_send=2**((lexp)+9)
if self._buffer[2] !=0 or self._buffer[3]!=0:
raise HandshakeError('Reserved bytes must be zero')
return ser, lexp
def process_handshake(self):
raise NotImplementedError()
def data_received(self, data):
self.log.debug('RawSocker Asyncio: data received {data}', data=_LazyHexFormatter(data))
if self._handshake_done:
return PrefixProtocol.data_received(self, data)
else:
self._buffer+=data
if len(self._buffer)>=4:
try:
self.process_handshake()
except HandshakeError as e:
self.protocol_error('Handshake error : {err}'.format(err=e))
return
self._handshake_done=True
self._on_handshake_complete()
data=self._buffer[4:]
self._buffer=b''
if data:
PrefixProtocol.data_received(self, data)
ERR_SERIALIZER_UNSUPPORTED=1
ERRMAP = {
0: "illegal (must not be used)",
1: "serializer unsupported",
2: "maximum message length unacceptable",
3: "use of reserved bits (unsupported feature)",
4: "maximum connection count reached"
}
class HandshakeError(Exception):
def __init__(self,msg,code=0):
Exception.__init__(self, msg if not code else msg+' : %s' % ERRMAP.get(code))
class RawSocketClientProtocol(RawSocketProtocol):
def __init__(self, max_size=None):
RawSocketProtocol.__init__(self, max_size=max_size)
def check_serializer(self, ser_id):
return True
def process_handshake(self):
ser_id,err=self.parse_handshake()
if ser_id == 0:
raise HandshakeError('Server returned handshake error', err)
if self.serializer_id != ser_id:
raise HandshakeError('Server returned different serializer {0} then requested {1}'.format(ser_id, self.serializer_id))
@property
def serializer_id(self):
raise NotImplementedError()
def connection_made(self, transport):
RawSocketProtocol.connection_made(self, transport)
#start handsjake
hs=bytes([0x7F,
self._length_exp << 4 | self.serializer_id,
0, 0])
transport.write(hs)
self.log.debug('RawSocket Asyncio: Client handshake sent')
class RawSocketServerProtocol(RawSocketProtocol):
def __init__(self, max_size=None):
RawSocketProtocol.__init__(self, max_size=max_size)
def supports_serializer(self,ser_id):
raise NotImplementedError()
def process_handshake(self):
def send_response(lexp,ser_id):
b2=lexp<<4 | (ser_id & 0x0f)
self.transport.write(bytes( bytearray([0x7F, b2, 0, 0])))
ser_id,lexp=self.parse_handshake()
if not self.supports_serializer(ser_id):
send_response(ERR_SERIALIZER_UNSUPPORTED, 0)
raise HandshakeError('Serializer unsupported : {ser_id}'.format(ser_id=ser_id))
send_response(self._length_exp, ser_id)
# this is transport independent part of WAMP protocol
class WampRawSocketMixinGeneral(object):
def _on_handshake_complete(self):
self.log.debug("WampRawSocketProtocol: Handshake complete")
try:
self._session = self.factory._factory()
self._session.onOpen(self)
except Exception as e:
# Exceptions raised in onOpen are fatal ..
self.log.warn("WampRawSocketProtocol: ApplicationSession constructor / onOpen raised ({err})", err=e)
self.abort()
else:
self.log.info("ApplicationSession started.")
def stringReceived(self, payload):
self.log.debug("WampRawSocketProtocol: RX octets: {octets}", octets=_LazyHexFormatter(payload))
try:
for msg in self._serializer.unserialize(payload):
self.log.debug("WampRawSocketProtocol: RX WAMP message: {msg}", msg=msg)
self._session.onMessage(msg)
except ProtocolError as e:
self.log.warn("WampRawSocketProtocol: WAMP Protocol Error ({err}) - aborting connection", err=e)
self.abort()
except Exception as e:
self.log.warn("WampRawSocketProtocol: WAMP Internal Error ({err}) - aborting connection", err=e)
self.abort()
def send(self, msg):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.send`
"""
if self.isOpen():
self.log.debug("WampRawSocketProtocol: TX WAMP message: {msg}", msg=msg)
try:
payload, _ = self._serializer.serialize(msg)
except Exception as e:
# all exceptions raised from above should be serialization errors ..
raise SerializationError("WampRawSocketProtocol: unable to serialize WAMP application payload ({0})".format(e))
else:
self.sendString(payload)
self.log.debug("WampRawSocketProtocol: TX octets: {octets}", octets=_LazyHexFormatter(payload))
else:
raise TransportLost()
def isOpen(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.isOpen`
"""
return hasattr(self, '_session') and self._session is not None
# this is asyncio dependent part of WAMP protocol
class WampRawSocketMixinAsyncio():
"""
Base class for asyncio-based WAMP-over-RawSocket protocols.
"""
def connection_lost(self, exc):
try:
wasClean = exc is None
self._session.onClose(wasClean)
except Exception as e:
# silently ignore exceptions raised here ..
self.log.warn("WampRawSocketProtocol: ApplicationSession.onClose raised ({err})", err=e)
self._session = None
def close(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.close`
"""
if self.isOpen():
self.transport.close()
else:
raise TransportLost()
def abort(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.abort`
"""
if self.isOpen():
if hasattr(self.transport, 'abort'):
# ProcessProtocol lacks abortConnection()
self.transport.abort()
else:
self.transport.close()
else:
raise TransportLost()
class WampRawSocketServerProtocol(WampRawSocketMixinGeneral, WampRawSocketMixinAsyncio, RawSocketServerProtocol):
"""
Base class for Twisted-based WAMP-over-RawSocket server protocols.
"""
def supports_serializer(self, ser_id):
if ser_id in self.factory._serializers:
self._serializer = self.factory._serializers[ser_id]()
self.log.debug(
"WampRawSocketProtocol: client wants to use serializer '{serializer}'",
serializer=ser_id,
)
return True
else:
self.log.debug(
"WampRawSocketProtocol: opening handshake - no suitable serializer found (client requested {serializer}, and we have {serializers}",
serializer=ser_id,
serializers=self.factory._serializers.keys(),
)
self.abort()
return False
def get_channel_id(self, channel_id_type=u'tls-unique'):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.get_channel_id`
"""
return transport_channel_id(self.transport, is_server=True, channel_id_type=channel_id_type)
class WampRawSocketClientProtocol(WampRawSocketMixinGeneral, WampRawSocketMixinAsyncio, RawSocketClientProtocol):
"""
Base class for Twisted-based WAMP-over-RawSocket client protocols.
"""
@property
def serializer_id(self):
if not hasattr(self, '_serializer'):
self._serializer=self.factory._serializer()
return self._serializer.RAWSOCKET_SERIALIZER_ID
def get_channel_id(self, channel_id_type=u'tls-unique'):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.get_channel_id`
"""
return transport_channel_id(self.transport, is_server=False, channel_id_type=channel_id_type)
class WampRawSocketFactory(object):
"""
Adapter class for asyncio-based WebSocket client and server factories.def dataReceived(self, data):
"""
log = txaio.make_logger() # @UndefinedVariable
def __call__(self):
proto = self.protocol()
proto.factory = self
return proto
class WampRawSocketServerFactory(WampRawSocketFactory):
"""
Base class for Twisted-based WAMP-over-RawSocket server factories.
"""
protocol = WampRawSocketServerProtocol
def __init__(self, factory, serializers=None):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializers: A list of WAMP serializers to use (or None for default
serializers). Serializers must implement
:class:`autobahn.wamp.interfaces.ISerializer`.
:type serializers: list
"""
assert(callable(factory))
self._factory = factory
if serializers is None:
serializers = get_serializes()
if not serializers:
raise Exception("could not import any WAMP serializers")
self._serializers = {ser.RAWSOCKET_SERIALIZER_ID:ser for ser in serializers}
class WampRawSocketClientFactory(WampRawSocketFactory):
"""
Base class for Twisted-based WAMP-over-RawSocket client factories.
"""
protocol = WampRawSocketClientProtocol
def __init__(self, factory, serializer=None):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializer: The WAMP serializer to use (or None for default
serializer). Serializers must implement
:class:`autobahn.wamp.interfaces.ISerializer`.
:type serializer: obj
"""
assert(callable(factory))
self._factory = factory
if serializer is None:
serializers=get_serializes()
if serializers:
serializer=serializers[0]
if serializer is None:
raise Exception("could not import any WAMP serializer")
self._serializer = serializer

View File

@ -0,0 +1,205 @@
'''
Created on May 19, 2016
@author: ivan
'''
from unittest import TestCase
from unittest.mock import Mock, call, patch
from autobahn.asyncio.rawsocket import PrefixProtocol, RawSocketClientProtocol, RawSocketServerProtocol, \
WampRawSocketClientFactory, WampRawSocketServerFactory
from autobahn.asyncio.util import get_serializes
from autobahn.wamp import message
class Test(TestCase):
def test_sers(self):
serializers=get_serializes()
self.assertTrue(len(serializers)>0)
m=serializers[0]().serialize(message.Abort('close'))
print(m)
self.assertTrue(m)
def test_prefix(self):
p=PrefixProtocol()
transport = Mock()
receiver= Mock()
p.stringReceived=receiver
p.connection_made(transport)
small_msg=b'\x00\x00\x00\x04abcd'
p.data_received(small_msg)
receiver.assert_called_once_with(b'abcd')
self.assertEqual(len(p._buffer), 0)
p.sendString(b'abcd')
#print(transport.write.call_args_list)
transport.write.assert_has_calls([call(b'\x00\x00\x00\x04'), call(b'abcd')])
transport.reset_mock()
receiver.reset_mock()
big_msg=b'\x00\x00\x00\x0C'+b'0123456789AB'
p.data_received(big_msg[0:2])
self.assertFalse(receiver.called)
p.data_received(big_msg[2:6])
self.assertFalse(receiver.called)
p.data_received(big_msg[6:11])
self.assertFalse(receiver.called)
p.data_received(big_msg[11:16])
receiver.assert_called_once_with(b'0123456789AB')
transport.reset_mock()
receiver.reset_mock()
two_messages = b'\x00\x00\x00\x04'+b'abcd'+b'\x00\x00\x00\x05' +b'12345' +b'\x00'
p.data_received(two_messages)
receiver.assert_has_calls([call(b'abcd'), call(b'12345')])
self.assertEqual(p._buffer, b'\x00' )
def test_raw_socket_server1(self):
server=RawSocketServerProtocol(max_size=10000)
ser=Mock(return_value=True)
on_hs=Mock()
transport=Mock()
receiver=Mock()
server.supports_serializer=ser
server.stringReceived=receiver
server._on_handshake_complete=on_hs
server.stringReceived=receiver
server.connection_made(transport)
hs=b'\x7F\xF1\x00\x00'+b'\x00\x00\x00\x04abcd'
server.data_received(hs)
ser.assert_called_once_with(1)
on_hs.assert_called_once_with()
self.assertTrue(transport.write.called)
transport.write.assert_called_once_with(b'\x7F\x51\x00\x00')
self.assertFalse(transport.close.called)
receiver.assert_called_once_with(b'abcd')
def test_raw_socket_server_errors(self):
server=RawSocketServerProtocol(max_size=10000)
ser=Mock(return_value=True)
on_hs=Mock()
transport=Mock()
receiver=Mock()
server.supports_serializer=ser
server.stringReceived=receiver
server._on_handshake_complete=on_hs
server.stringReceived=receiver
server.connection_made(transport)
server.data_received(b'abcdef')
transport.close.assert_called_once_with()
server=RawSocketServerProtocol(max_size=10000)
ser=Mock(return_value=False)
on_hs=Mock()
transport=Mock(spec_set=('close','write', 'get_extra_info'))
receiver=Mock()
server.supports_serializer=ser
server.stringReceived=receiver
server._on_handshake_complete=on_hs
server.stringReceived=receiver
server.connection_made(transport)
server.data_received(b'\x7F\xF1\x00\x00')
transport.close.assert_called_once_with()
transport.write.assert_called_once_with(b'\x7F\x10\x00\x00')
def test_raw_socket_client1(self):
class CP(RawSocketClientProtocol):
@property
def serializer_id(self):
return 1
client=CP()
on_hs=Mock()
transport=Mock()
receiver=Mock()
client.stringReceived=receiver
client._on_handshake_complete=on_hs
client.connection_made(transport)
client.data_received(b'\x7F\xF1\x00\x00'+b'\x00\x00\x00\x04abcd')
on_hs.assert_called_once_with()
self.assertTrue(transport.write.called)
transport.write.called_one_with(b'\x7F\xF1\x00\x00')
self.assertFalse(transport.close.called)
receiver.assert_called_once_with(b'abcd')
def test_raw_socket_client_error(self):
class CP(RawSocketClientProtocol):
@property
def serializer_id(self):
return 1
client=CP()
on_hs=Mock()
transport=Mock(spec_set=('close','write', 'get_extra_info'))
receiver=Mock()
client.stringReceived=receiver
client._on_handshake_complete=on_hs
client.connection_made(transport)
client.data_received(b'\x7F\xF1\x00\x01')
transport.close.assert_called_once_with()
def test_wamp(self):
transport=Mock(spec_set=('abort','close','write', 'get_extra_info'))
transport.write=Mock(side_effect=lambda m: messages.append(m))
client=Mock(spec=['onOpen', 'onMessage'])
def fact():
return client
messages=[]
proto=WampRawSocketClientFactory(fact)()
proto.connection_made(transport)
self.assertTrue(proto._serializer)
s=proto._serializer.RAWSOCKET_SERIALIZER_ID
proto.data_received(bytes(bytearray([0x7F, 0xF0 |s, 0, 0])))
client.onOpen.assert_called_once_with(proto)
proto.send(message.Abort('close'))
for d in messages[1:]:
proto.data_received(d)
self.assertTrue(client.onMessage.called)
self.assertTrue(isinstance(client.onMessage.call_args[0][0], message.Abort))
# server
transport=Mock(spec_set=('abort','close','write', 'get_extra_info'))
transport.write=Mock(side_effect=lambda m: messages.append(m))
client=None
server=Mock(spec=['onOpen', 'onMessage'])
def fact_server():
return server
messages=[]
proto=WampRawSocketServerFactory(fact_server)()
proto.connection_made(transport)
self.assertTrue(proto.factory._serializers)
s=proto.factory._serializers[1].RAWSOCKET_SERIALIZER_ID
proto.data_received(bytes(bytearray([0x7F, 0xF0 |s, 0, 0])))
self.assertTrue(proto._serializer)
server.onOpen.assert_called_once_with(proto)
proto.send(message.Abort('close'))
for d in messages[1:]:
proto.data_received(d)
self.assertTrue(server.onMessage.called)
self.assertTrue(isinstance(server.onMessage.call_args[0][0], message.Abort))
if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.test_prefix']
unittest.main()

74
autobahn/asyncio/util.py Normal file
View File

@ -0,0 +1,74 @@
import binascii
class _LazyHexFormatter(object):
"""
This is used to avoid calling binascii.hexlify() on data given to
log.debug() calls unless debug is active (for example). Like::
self.log.debug(
"Some data: {octets}",
octets=_LazyHexFormatter(os.urandom(32)),
)
"""
__slots__ = ('obj',)
def __init__(self, obj):
self.obj = obj
def __str__(self):
return binascii.hexlify(self.obj).decode('ascii')
def peer2str(peer):
if isinstance(peer, tuple):
ip_ver=4 if len(peer)==2 else 6
return "tcp{2}:{0}:{1}".format(peer[0], peer[1], ip_ver)
elif isinstance(peer, str):
return "unix:{0}".format(peer)
else:
return "?:{0}".format(peer)
def get_serializes():
from autobahn.wamp import serializer
serializers=['CBORSerializer', 'MsgPackSerializer', 'UBJSONSerializer', 'JsonSerializer']
serializers=list(filter(lambda x:x, map(lambda s: getattr(serializer, s) if hasattr(serializer,s) else None,
serializers )))
return serializers
#TODO - check and modify for asyncio transport
def transport_channel_id(transport, is_server, channel_id_type):
"""
Application-layer user authentication protocols are vulnerable to generic
credential forwarding attacks, where an authentication credential sent by
a client C to a server M may then be used by M to impersonate C at another
server S. To prevent such credential forwarding attacks, modern authentication
protocols rely on channel bindings. For example, WAMP-cryptosign can use
the tls-unique channel identifier provided by the TLS layer to strongly bind
authentication credentials to the underlying channel, so that a credential
received on one TLS channel cannot be forwarded on another.
"""
if channel_id_type is None:
return None
if channel_id_type not in [u'tls-unique']:
raise Exception("invalid channel ID type {}".format(channel_id_type))
if hasattr(transport, '_tlsConnection'):
# Obtain latest TLS Finished message that we expected from peer, or None if handshake is not completed.
# http://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Connection.get_peer_finished
if is_server:
# for routers (=servers), the channel ID is based on the TLS Finished message we
# expected to receive from the client
tls_finished_msg = transport._tlsConnection.get_peer_finished()
else:
# for clients, the channel ID is based on the TLS Finished message we sent
# to the router (=server)
tls_finished_msg = transport._tlsConnection.get_finished()
m = hashlib.sha256()
m.update(tls_finished_msg)
return m.digest()
else:
return None

View File

@ -172,3 +172,109 @@ class ApplicationRunner(object):
loop.run_until_complete(protocol._session.leave())
loop.close()
#TODO - unify with previous class
class ApplicationRunnerRawSocket(object):
"""
This class is a convenience tool mainly for development and quick hosting
of WAMP application components.
It can host a WAMP application component in a WAMP-over-WebSocket client
connecting to a WAMP router.
"""
def __init__(self, url, realm, extra=None, serializer=None):
"""
:param url: Raw socket unicode - either path on local server to unix socket (or unix:/path)
or tcp://host:port for internet socket
:type url: unicode
:param realm: The WAMP realm to join the application session to.
:type realm: unicode
:param extra: Optional extra configuration to forward to the application component.
:type extra: dict
:param serializer: WAMP serializer to use (or None for default serializer).
:type serializer: `autobahn.wamp.interfaces.ISerializer`
"""
assert(type(url) == six.text_type)
assert(type(realm) == six.text_type)
assert(extra is None or type(extra) == dict)
self.url = url
self.realm = realm
self.extra = extra or dict()
self.serializer = serializer
def run(self, make, logging_level='info'):
"""
Run the application component.
:param make: A factory that produces instances of :class:`autobahn.asyncio.wamp.ApplicationSession`
when called with an instance of :class:`autobahn.wamp.types.ComponentConfig`.
:type make: callable
"""
#make imports local for now not to infere with rest of module
from six.moves.urllib.parse import urlparse
from autobahn.asyncio.rawsocket import WampRawSocketClientFactory
def create():
cfg = ComponentConfig(self.realm, self.extra)
try:
session = make(cfg)
except Exception:
self.log.failure("App session could not be created! ")
asyncio.get_event_loop().stop()
else:
return session
parsed_url=urlparse(self.url)
if parsed_url.scheme=='tcp':
is_unix=False
if not parsed_url.hostname or not parsed_url.port:
raise ValueError('Host and port is required in URL')
elif parsed_url.scheme=='unix' or parsed_url.scheme=='':
is_unix=True
if not parsed_url.path:
raise ValueError('Path to unix socket must be in URL')
transport_factory = WampRawSocketClientFactory(create, serializer=self.serializer)
# 3) start the client
loop = asyncio.get_event_loop()
txaio.use_asyncio()
txaio.config.loop = loop
if is_unix:
coro = loop.create_unix_connection(transport_factory, parsed_url.path)
else:
coro = loop.create_connection(transport_factory, parsed_url.hostname, parsed_url.port)
(transport, protocol) = loop.run_until_complete(coro)
txaio.start_logging(level=logging_level)
try:
loop.add_signal_handler(signal.SIGTERM, loop.stop)
except NotImplementedError:
# signals are not available on Windows
pass
try:
loop.run_forever()
except KeyboardInterrupt:
pass
# give Goodbye message a chance to go through, if we still
# have an active session
if protocol._session:
loop.run_until_complete(protocol._session.leave())
loop.close()