Looks like it works
This commit is contained in:
parent
1c86fefa73
commit
5a72a46275
4
.gitignore
vendored
4
.gitignore
vendored
@ -30,3 +30,7 @@ htmlcov/
|
||||
*.pid
|
||||
.cache/
|
||||
node.key
|
||||
/.project
|
||||
/.pydevproject
|
||||
/.venv/
|
||||
/my_tests/
|
||||
|
440
autobahn/asyncio/rawsocket.py
Normal file
440
autobahn/asyncio/rawsocket.py
Normal file
@ -0,0 +1,440 @@
|
||||
#from __future__ import absolute_import
|
||||
|
||||
import asyncio #TODO: should we support trollious - it has been deprecated - http://trollius.readthedocs.io/deprecated.htm
|
||||
import struct
|
||||
import math
|
||||
from autobahn.asyncio.util import _LazyHexFormatter
|
||||
from autobahn.wamp.exception import ProtocolError, SerializationError, TransportLost
|
||||
from autobahn.asyncio.util import peer2str, transport_channel_id, get_serializes
|
||||
import txaio
|
||||
|
||||
__all__ = (
|
||||
'WampRawSocketServerProtocol',
|
||||
'WampRawSocketClientProtocol',
|
||||
'WampRawSocketServerFactory',
|
||||
'WampRawSocketClientFactory'
|
||||
)
|
||||
|
||||
txaio.use_asyncio()
|
||||
|
||||
FRAME_TYPE_DATA=0
|
||||
FRAME_TYPE_PING=1
|
||||
FRAME_TYPE_PONG=2
|
||||
|
||||
class PrefixProtocol(asyncio.Protocol):
|
||||
|
||||
prefix_format='!L'
|
||||
prefix_length=struct.calcsize(prefix_format)
|
||||
max_length = 16 * 1024 *1024
|
||||
max_length_send = max_length
|
||||
log=txaio.make_logger() # @UndefinedVariable
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport=transport
|
||||
peer = transport.get_extra_info('peername')
|
||||
self.peer = peer2str(peer)
|
||||
self.log.debug('RawSocker Asyncio: Connection made with peer {peer}'.format(peer=self.peer))
|
||||
self._buffer=b''
|
||||
self._header=None
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.log.debug('RawSocker Asyncio: Connection lost')
|
||||
self.transport=None
|
||||
|
||||
def protocol_error(self, msg):
|
||||
self.log.error(msg)
|
||||
self.transport.close()
|
||||
|
||||
def sendString(self, data):
|
||||
l=len(data)
|
||||
if l>self.max_length_send:
|
||||
raise ValueError('Data too big')
|
||||
header=struct.pack(self.prefix_format, len(data))
|
||||
self.transport.write(header)
|
||||
self.transport.write(data)
|
||||
|
||||
def ping(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def pong(self,data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def data_received(self, data):
|
||||
self._buffer+=data
|
||||
pos=0
|
||||
remaining=len(self._buffer)
|
||||
while remaining >= self.prefix_length:
|
||||
# do not recalculate header if available from previous call
|
||||
if self._header:
|
||||
frame_type, frame_length = self._header
|
||||
else:
|
||||
header=self._buffer[pos:pos+self.prefix_length]
|
||||
frame_type=header[0]& 0b00000111
|
||||
if frame_type > FRAME_TYPE_PONG:
|
||||
self.protocol_error('Invalid frame type')
|
||||
return
|
||||
frame_length= struct.unpack(self.prefix_format, b'\0'+header[1:])[0]
|
||||
if frame_length> self.max_length:
|
||||
self.protocol_error('Frame too big')
|
||||
return
|
||||
|
||||
|
||||
if remaining-self.prefix_length >= frame_length:
|
||||
self._header=None
|
||||
pos+=self.prefix_length
|
||||
remaining-=self.prefix_length
|
||||
data=self._buffer[pos:pos+frame_length]
|
||||
pos+=frame_length
|
||||
remaining-=frame_length
|
||||
|
||||
if frame_type == FRAME_TYPE_DATA:
|
||||
self.stringReceived(data)
|
||||
elif frame_type == FRAME_TYPE_PING:
|
||||
self.ping(data)
|
||||
elif frame_type == FRAME_TYPE_PONG:
|
||||
self.pong(data)
|
||||
else:
|
||||
# save heaader
|
||||
self._header = frame_type, frame_length
|
||||
break
|
||||
|
||||
self._buffer=self._buffer[:remaining]
|
||||
|
||||
def stringReceived(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RawSocketProtocol(PrefixProtocol):
|
||||
|
||||
def __init__(self, max_size=None):
|
||||
if max_size:
|
||||
exp=math.ceil(math.log2(max_size))-9
|
||||
if exp>15:
|
||||
raise ValueError('Maximum length is 16M')
|
||||
self.max_length=2**(exp+9)
|
||||
self._length_exp=exp
|
||||
else:
|
||||
self._length_exp=15
|
||||
self.max_length=2**24
|
||||
|
||||
def connection_made(self, transport):
|
||||
PrefixProtocol.connection_made(self, transport)
|
||||
self._handshake_done=False
|
||||
|
||||
def _on_handshake_complete(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def parse_handshake(self):
|
||||
if self._buffer[0] != 0x7F:
|
||||
raise HandshakeError('Invalid magic byte in handshake')
|
||||
return
|
||||
b1=self._buffer[1]
|
||||
ser=b1 & 0x0F
|
||||
lexp=b1>>4
|
||||
self.max_length_send=2**((lexp)+9)
|
||||
if self._buffer[2] !=0 or self._buffer[3]!=0:
|
||||
raise HandshakeError('Reserved bytes must be zero')
|
||||
return ser, lexp
|
||||
|
||||
def process_handshake(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def data_received(self, data):
|
||||
self.log.debug('RawSocker Asyncio: data received {data}', data=_LazyHexFormatter(data))
|
||||
if self._handshake_done:
|
||||
return PrefixProtocol.data_received(self, data)
|
||||
else:
|
||||
self._buffer+=data
|
||||
if len(self._buffer)>=4:
|
||||
try:
|
||||
self.process_handshake()
|
||||
except HandshakeError as e:
|
||||
self.protocol_error('Handshake error : {err}'.format(err=e))
|
||||
return
|
||||
self._handshake_done=True
|
||||
self._on_handshake_complete()
|
||||
data=self._buffer[4:]
|
||||
self._buffer=b''
|
||||
if data:
|
||||
PrefixProtocol.data_received(self, data)
|
||||
|
||||
ERR_SERIALIZER_UNSUPPORTED=1
|
||||
ERRMAP = {
|
||||
0: "illegal (must not be used)",
|
||||
1: "serializer unsupported",
|
||||
2: "maximum message length unacceptable",
|
||||
3: "use of reserved bits (unsupported feature)",
|
||||
4: "maximum connection count reached"
|
||||
}
|
||||
|
||||
|
||||
|
||||
class HandshakeError(Exception):
|
||||
def __init__(self,msg,code=0):
|
||||
Exception.__init__(self, msg if not code else msg+' : %s' % ERRMAP.get(code))
|
||||
|
||||
|
||||
class RawSocketClientProtocol(RawSocketProtocol):
|
||||
|
||||
def __init__(self, max_size=None):
|
||||
RawSocketProtocol.__init__(self, max_size=max_size)
|
||||
|
||||
def check_serializer(self, ser_id):
|
||||
|
||||
return True
|
||||
|
||||
def process_handshake(self):
|
||||
ser_id,err=self.parse_handshake()
|
||||
if ser_id == 0:
|
||||
raise HandshakeError('Server returned handshake error', err)
|
||||
if self.serializer_id != ser_id:
|
||||
raise HandshakeError('Server returned different serializer {0} then requested {1}'.format(ser_id, self.serializer_id))
|
||||
|
||||
|
||||
@property
|
||||
def serializer_id(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def connection_made(self, transport):
|
||||
RawSocketProtocol.connection_made(self, transport)
|
||||
#start handsjake
|
||||
hs=bytes([0x7F,
|
||||
self._length_exp << 4 | self.serializer_id,
|
||||
0, 0])
|
||||
transport.write(hs)
|
||||
self.log.debug('RawSocket Asyncio: Client handshake sent')
|
||||
|
||||
class RawSocketServerProtocol(RawSocketProtocol):
|
||||
|
||||
def __init__(self, max_size=None):
|
||||
RawSocketProtocol.__init__(self, max_size=max_size)
|
||||
|
||||
def supports_serializer(self,ser_id):
|
||||
raise NotImplementedError()
|
||||
|
||||
def process_handshake(self):
|
||||
def send_response(lexp,ser_id):
|
||||
b2=lexp<<4 | (ser_id & 0x0f)
|
||||
self.transport.write(bytes( bytearray([0x7F, b2, 0, 0])))
|
||||
ser_id,lexp=self.parse_handshake()
|
||||
if not self.supports_serializer(ser_id):
|
||||
send_response(ERR_SERIALIZER_UNSUPPORTED, 0)
|
||||
raise HandshakeError('Serializer unsupported : {ser_id}'.format(ser_id=ser_id))
|
||||
send_response(self._length_exp, ser_id)
|
||||
|
||||
|
||||
# this is transport independent part of WAMP protocol
|
||||
class WampRawSocketMixinGeneral(object):
|
||||
|
||||
def _on_handshake_complete(self):
|
||||
self.log.debug("WampRawSocketProtocol: Handshake complete")
|
||||
try:
|
||||
self._session = self.factory._factory()
|
||||
self._session.onOpen(self)
|
||||
except Exception as e:
|
||||
# Exceptions raised in onOpen are fatal ..
|
||||
self.log.warn("WampRawSocketProtocol: ApplicationSession constructor / onOpen raised ({err})", err=e)
|
||||
self.abort()
|
||||
else:
|
||||
self.log.info("ApplicationSession started.")
|
||||
|
||||
def stringReceived(self, payload):
|
||||
self.log.debug("WampRawSocketProtocol: RX octets: {octets}", octets=_LazyHexFormatter(payload))
|
||||
try:
|
||||
for msg in self._serializer.unserialize(payload):
|
||||
self.log.debug("WampRawSocketProtocol: RX WAMP message: {msg}", msg=msg)
|
||||
self._session.onMessage(msg)
|
||||
|
||||
except ProtocolError as e:
|
||||
self.log.warn("WampRawSocketProtocol: WAMP Protocol Error ({err}) - aborting connection", err=e)
|
||||
self.abort()
|
||||
|
||||
except Exception as e:
|
||||
self.log.warn("WampRawSocketProtocol: WAMP Internal Error ({err}) - aborting connection", err=e)
|
||||
self.abort()
|
||||
|
||||
def send(self, msg):
|
||||
"""
|
||||
Implements :func:`autobahn.wamp.interfaces.ITransport.send`
|
||||
"""
|
||||
if self.isOpen():
|
||||
self.log.debug("WampRawSocketProtocol: TX WAMP message: {msg}", msg=msg)
|
||||
try:
|
||||
payload, _ = self._serializer.serialize(msg)
|
||||
except Exception as e:
|
||||
# all exceptions raised from above should be serialization errors ..
|
||||
raise SerializationError("WampRawSocketProtocol: unable to serialize WAMP application payload ({0})".format(e))
|
||||
else:
|
||||
self.sendString(payload)
|
||||
self.log.debug("WampRawSocketProtocol: TX octets: {octets}", octets=_LazyHexFormatter(payload))
|
||||
else:
|
||||
raise TransportLost()
|
||||
|
||||
def isOpen(self):
|
||||
"""
|
||||
Implements :func:`autobahn.wamp.interfaces.ITransport.isOpen`
|
||||
"""
|
||||
return hasattr(self, '_session') and self._session is not None
|
||||
|
||||
|
||||
# this is asyncio dependent part of WAMP protocol
|
||||
class WampRawSocketMixinAsyncio():
|
||||
"""
|
||||
Base class for asyncio-based WAMP-over-RawSocket protocols.
|
||||
"""
|
||||
|
||||
def connection_lost(self, exc):
|
||||
|
||||
try:
|
||||
wasClean = exc is None
|
||||
self._session.onClose(wasClean)
|
||||
except Exception as e:
|
||||
# silently ignore exceptions raised here ..
|
||||
self.log.warn("WampRawSocketProtocol: ApplicationSession.onClose raised ({err})", err=e)
|
||||
self._session = None
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Implements :func:`autobahn.wamp.interfaces.ITransport.close`
|
||||
"""
|
||||
if self.isOpen():
|
||||
self.transport.close()
|
||||
else:
|
||||
raise TransportLost()
|
||||
|
||||
def abort(self):
|
||||
"""
|
||||
Implements :func:`autobahn.wamp.interfaces.ITransport.abort`
|
||||
"""
|
||||
if self.isOpen():
|
||||
if hasattr(self.transport, 'abort'):
|
||||
# ProcessProtocol lacks abortConnection()
|
||||
self.transport.abort()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
raise TransportLost()
|
||||
|
||||
|
||||
|
||||
class WampRawSocketServerProtocol(WampRawSocketMixinGeneral, WampRawSocketMixinAsyncio, RawSocketServerProtocol):
|
||||
"""
|
||||
Base class for Twisted-based WAMP-over-RawSocket server protocols.
|
||||
"""
|
||||
|
||||
def supports_serializer(self, ser_id):
|
||||
if ser_id in self.factory._serializers:
|
||||
self._serializer = self.factory._serializers[ser_id]()
|
||||
self.log.debug(
|
||||
"WampRawSocketProtocol: client wants to use serializer '{serializer}'",
|
||||
serializer=ser_id,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self.log.debug(
|
||||
"WampRawSocketProtocol: opening handshake - no suitable serializer found (client requested {serializer}, and we have {serializers}",
|
||||
serializer=ser_id,
|
||||
serializers=self.factory._serializers.keys(),
|
||||
)
|
||||
self.abort()
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def get_channel_id(self, channel_id_type=u'tls-unique'):
|
||||
"""
|
||||
Implements :func:`autobahn.wamp.interfaces.ITransport.get_channel_id`
|
||||
"""
|
||||
return transport_channel_id(self.transport, is_server=True, channel_id_type=channel_id_type)
|
||||
|
||||
|
||||
class WampRawSocketClientProtocol(WampRawSocketMixinGeneral, WampRawSocketMixinAsyncio, RawSocketClientProtocol):
|
||||
"""
|
||||
Base class for Twisted-based WAMP-over-RawSocket client protocols.
|
||||
"""
|
||||
@property
|
||||
def serializer_id(self):
|
||||
if not hasattr(self, '_serializer'):
|
||||
self._serializer=self.factory._serializer()
|
||||
return self._serializer.RAWSOCKET_SERIALIZER_ID
|
||||
|
||||
|
||||
def get_channel_id(self, channel_id_type=u'tls-unique'):
|
||||
"""
|
||||
Implements :func:`autobahn.wamp.interfaces.ITransport.get_channel_id`
|
||||
"""
|
||||
return transport_channel_id(self.transport, is_server=False, channel_id_type=channel_id_type)
|
||||
|
||||
|
||||
class WampRawSocketFactory(object):
|
||||
"""
|
||||
Adapter class for asyncio-based WebSocket client and server factories.def dataReceived(self, data):
|
||||
"""
|
||||
|
||||
log = txaio.make_logger() # @UndefinedVariable
|
||||
|
||||
def __call__(self):
|
||||
proto = self.protocol()
|
||||
proto.factory = self
|
||||
return proto
|
||||
|
||||
|
||||
class WampRawSocketServerFactory(WampRawSocketFactory):
|
||||
"""
|
||||
Base class for Twisted-based WAMP-over-RawSocket server factories.
|
||||
"""
|
||||
protocol = WampRawSocketServerProtocol
|
||||
|
||||
def __init__(self, factory, serializers=None):
|
||||
"""
|
||||
|
||||
:param factory: A callable that produces instances that implement
|
||||
:class:`autobahn.wamp.interfaces.ITransportHandler`
|
||||
:type factory: callable
|
||||
:param serializers: A list of WAMP serializers to use (or None for default
|
||||
serializers). Serializers must implement
|
||||
:class:`autobahn.wamp.interfaces.ISerializer`.
|
||||
:type serializers: list
|
||||
"""
|
||||
assert(callable(factory))
|
||||
self._factory = factory
|
||||
|
||||
if serializers is None:
|
||||
serializers = get_serializes()
|
||||
|
||||
if not serializers:
|
||||
raise Exception("could not import any WAMP serializers")
|
||||
|
||||
self._serializers = {ser.RAWSOCKET_SERIALIZER_ID:ser for ser in serializers}
|
||||
|
||||
|
||||
|
||||
class WampRawSocketClientFactory(WampRawSocketFactory):
|
||||
"""
|
||||
Base class for Twisted-based WAMP-over-RawSocket client factories.
|
||||
"""
|
||||
protocol = WampRawSocketClientProtocol
|
||||
|
||||
def __init__(self, factory, serializer=None):
|
||||
"""
|
||||
|
||||
:param factory: A callable that produces instances that implement
|
||||
:class:`autobahn.wamp.interfaces.ITransportHandler`
|
||||
:type factory: callable
|
||||
:param serializer: The WAMP serializer to use (or None for default
|
||||
serializer). Serializers must implement
|
||||
:class:`autobahn.wamp.interfaces.ISerializer`.
|
||||
:type serializer: obj
|
||||
"""
|
||||
assert(callable(factory))
|
||||
self._factory = factory
|
||||
|
||||
if serializer is None:
|
||||
serializers=get_serializes()
|
||||
if serializers:
|
||||
serializer=serializers[0]
|
||||
|
||||
if serializer is None:
|
||||
raise Exception("could not import any WAMP serializer")
|
||||
|
||||
self._serializer = serializer
|
205
autobahn/asyncio/test/test_rawsocket.py
Normal file
205
autobahn/asyncio/test/test_rawsocket.py
Normal file
@ -0,0 +1,205 @@
|
||||
'''
|
||||
Created on May 19, 2016
|
||||
|
||||
@author: ivan
|
||||
'''
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock, call, patch
|
||||
from autobahn.asyncio.rawsocket import PrefixProtocol, RawSocketClientProtocol, RawSocketServerProtocol, \
|
||||
WampRawSocketClientFactory, WampRawSocketServerFactory
|
||||
from autobahn.asyncio.util import get_serializes
|
||||
from autobahn.wamp import message
|
||||
|
||||
|
||||
class Test(TestCase):
|
||||
|
||||
def test_sers(self):
|
||||
serializers=get_serializes()
|
||||
self.assertTrue(len(serializers)>0)
|
||||
m=serializers[0]().serialize(message.Abort('close'))
|
||||
print(m)
|
||||
self.assertTrue(m)
|
||||
|
||||
def test_prefix(self):
|
||||
p=PrefixProtocol()
|
||||
transport = Mock()
|
||||
receiver= Mock()
|
||||
p.stringReceived=receiver
|
||||
p.connection_made(transport)
|
||||
small_msg=b'\x00\x00\x00\x04abcd'
|
||||
p.data_received(small_msg)
|
||||
receiver.assert_called_once_with(b'abcd')
|
||||
self.assertEqual(len(p._buffer), 0)
|
||||
|
||||
p.sendString(b'abcd')
|
||||
|
||||
#print(transport.write.call_args_list)
|
||||
transport.write.assert_has_calls([call(b'\x00\x00\x00\x04'), call(b'abcd')])
|
||||
|
||||
transport.reset_mock()
|
||||
receiver.reset_mock()
|
||||
|
||||
big_msg=b'\x00\x00\x00\x0C'+b'0123456789AB'
|
||||
|
||||
p.data_received(big_msg[0:2])
|
||||
self.assertFalse(receiver.called)
|
||||
|
||||
p.data_received(big_msg[2:6])
|
||||
self.assertFalse(receiver.called)
|
||||
|
||||
p.data_received(big_msg[6:11])
|
||||
self.assertFalse(receiver.called)
|
||||
|
||||
p.data_received(big_msg[11:16])
|
||||
receiver.assert_called_once_with(b'0123456789AB')
|
||||
|
||||
transport.reset_mock()
|
||||
receiver.reset_mock()
|
||||
|
||||
two_messages = b'\x00\x00\x00\x04'+b'abcd'+b'\x00\x00\x00\x05' +b'12345' +b'\x00'
|
||||
p.data_received(two_messages)
|
||||
receiver.assert_has_calls([call(b'abcd'), call(b'12345')])
|
||||
self.assertEqual(p._buffer, b'\x00' )
|
||||
|
||||
|
||||
def test_raw_socket_server1(self):
|
||||
|
||||
server=RawSocketServerProtocol(max_size=10000)
|
||||
ser=Mock(return_value=True)
|
||||
on_hs=Mock()
|
||||
transport=Mock()
|
||||
receiver=Mock()
|
||||
server.supports_serializer=ser
|
||||
server.stringReceived=receiver
|
||||
server._on_handshake_complete=on_hs
|
||||
server.stringReceived=receiver
|
||||
|
||||
server.connection_made(transport)
|
||||
hs=b'\x7F\xF1\x00\x00'+b'\x00\x00\x00\x04abcd'
|
||||
server.data_received(hs)
|
||||
|
||||
ser.assert_called_once_with(1)
|
||||
on_hs.assert_called_once_with()
|
||||
self.assertTrue(transport.write.called)
|
||||
transport.write.assert_called_once_with(b'\x7F\x51\x00\x00')
|
||||
self.assertFalse(transport.close.called)
|
||||
receiver.assert_called_once_with(b'abcd')
|
||||
|
||||
def test_raw_socket_server_errors(self):
|
||||
|
||||
server=RawSocketServerProtocol(max_size=10000)
|
||||
ser=Mock(return_value=True)
|
||||
on_hs=Mock()
|
||||
transport=Mock()
|
||||
receiver=Mock()
|
||||
server.supports_serializer=ser
|
||||
server.stringReceived=receiver
|
||||
server._on_handshake_complete=on_hs
|
||||
server.stringReceived=receiver
|
||||
server.connection_made(transport)
|
||||
server.data_received(b'abcdef')
|
||||
transport.close.assert_called_once_with()
|
||||
|
||||
server=RawSocketServerProtocol(max_size=10000)
|
||||
ser=Mock(return_value=False)
|
||||
on_hs=Mock()
|
||||
transport=Mock(spec_set=('close','write', 'get_extra_info'))
|
||||
receiver=Mock()
|
||||
server.supports_serializer=ser
|
||||
server.stringReceived=receiver
|
||||
server._on_handshake_complete=on_hs
|
||||
server.stringReceived=receiver
|
||||
server.connection_made(transport)
|
||||
server.data_received(b'\x7F\xF1\x00\x00')
|
||||
transport.close.assert_called_once_with()
|
||||
transport.write.assert_called_once_with(b'\x7F\x10\x00\x00')
|
||||
|
||||
|
||||
|
||||
def test_raw_socket_client1(self):
|
||||
class CP(RawSocketClientProtocol):
|
||||
@property
|
||||
def serializer_id(self):
|
||||
return 1
|
||||
client=CP()
|
||||
|
||||
on_hs=Mock()
|
||||
transport=Mock()
|
||||
receiver=Mock()
|
||||
client.stringReceived=receiver
|
||||
client._on_handshake_complete=on_hs
|
||||
|
||||
client.connection_made(transport)
|
||||
client.data_received(b'\x7F\xF1\x00\x00'+b'\x00\x00\x00\x04abcd')
|
||||
on_hs.assert_called_once_with()
|
||||
self.assertTrue(transport.write.called)
|
||||
transport.write.called_one_with(b'\x7F\xF1\x00\x00')
|
||||
self.assertFalse(transport.close.called)
|
||||
receiver.assert_called_once_with(b'abcd')
|
||||
|
||||
def test_raw_socket_client_error(self):
|
||||
class CP(RawSocketClientProtocol):
|
||||
@property
|
||||
def serializer_id(self):
|
||||
return 1
|
||||
client=CP()
|
||||
|
||||
on_hs=Mock()
|
||||
transport=Mock(spec_set=('close','write', 'get_extra_info'))
|
||||
receiver=Mock()
|
||||
client.stringReceived=receiver
|
||||
client._on_handshake_complete=on_hs
|
||||
client.connection_made(transport)
|
||||
|
||||
client.data_received(b'\x7F\xF1\x00\x01')
|
||||
transport.close.assert_called_once_with()
|
||||
|
||||
|
||||
def test_wamp(self):
|
||||
transport=Mock(spec_set=('abort','close','write', 'get_extra_info'))
|
||||
transport.write=Mock(side_effect=lambda m: messages.append(m))
|
||||
client=Mock(spec=['onOpen', 'onMessage'])
|
||||
|
||||
def fact():
|
||||
return client
|
||||
messages=[]
|
||||
proto=WampRawSocketClientFactory(fact)()
|
||||
proto.connection_made(transport)
|
||||
self.assertTrue(proto._serializer)
|
||||
s=proto._serializer.RAWSOCKET_SERIALIZER_ID
|
||||
proto.data_received(bytes(bytearray([0x7F, 0xF0 |s, 0, 0])))
|
||||
client.onOpen.assert_called_once_with(proto)
|
||||
proto.send(message.Abort('close'))
|
||||
for d in messages[1:]:
|
||||
proto.data_received(d)
|
||||
self.assertTrue(client.onMessage.called)
|
||||
self.assertTrue(isinstance(client.onMessage.call_args[0][0], message.Abort))
|
||||
|
||||
# server
|
||||
transport=Mock(spec_set=('abort','close','write', 'get_extra_info'))
|
||||
transport.write=Mock(side_effect=lambda m: messages.append(m))
|
||||
|
||||
client=None
|
||||
server=Mock(spec=['onOpen', 'onMessage'])
|
||||
|
||||
def fact_server():
|
||||
return server
|
||||
messages=[]
|
||||
proto=WampRawSocketServerFactory(fact_server)()
|
||||
proto.connection_made(transport)
|
||||
self.assertTrue(proto.factory._serializers)
|
||||
s=proto.factory._serializers[1].RAWSOCKET_SERIALIZER_ID
|
||||
proto.data_received(bytes(bytearray([0x7F, 0xF0 |s, 0, 0])))
|
||||
self.assertTrue(proto._serializer)
|
||||
server.onOpen.assert_called_once_with(proto)
|
||||
proto.send(message.Abort('close'))
|
||||
for d in messages[1:]:
|
||||
proto.data_received(d)
|
||||
self.assertTrue(server.onMessage.called)
|
||||
self.assertTrue(isinstance(server.onMessage.call_args[0][0], message.Abort))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#import sys;sys.argv = ['', 'Test.test_prefix']
|
||||
unittest.main()
|
74
autobahn/asyncio/util.py
Normal file
74
autobahn/asyncio/util.py
Normal file
@ -0,0 +1,74 @@
|
||||
import binascii
|
||||
class _LazyHexFormatter(object):
|
||||
"""
|
||||
This is used to avoid calling binascii.hexlify() on data given to
|
||||
log.debug() calls unless debug is active (for example). Like::
|
||||
|
||||
self.log.debug(
|
||||
"Some data: {octets}",
|
||||
octets=_LazyHexFormatter(os.urandom(32)),
|
||||
)
|
||||
"""
|
||||
__slots__ = ('obj',)
|
||||
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
|
||||
def __str__(self):
|
||||
return binascii.hexlify(self.obj).decode('ascii')
|
||||
|
||||
def peer2str(peer):
|
||||
if isinstance(peer, tuple):
|
||||
ip_ver=4 if len(peer)==2 else 6
|
||||
return "tcp{2}:{0}:{1}".format(peer[0], peer[1], ip_ver)
|
||||
elif isinstance(peer, str):
|
||||
return "unix:{0}".format(peer)
|
||||
else:
|
||||
return "?:{0}".format(peer)
|
||||
|
||||
def get_serializes():
|
||||
from autobahn.wamp import serializer
|
||||
|
||||
serializers=['CBORSerializer', 'MsgPackSerializer', 'UBJSONSerializer', 'JsonSerializer']
|
||||
serializers=list(filter(lambda x:x, map(lambda s: getattr(serializer, s) if hasattr(serializer,s) else None,
|
||||
serializers )))
|
||||
return serializers
|
||||
|
||||
#TODO - check and modify for asyncio transport
|
||||
def transport_channel_id(transport, is_server, channel_id_type):
|
||||
"""
|
||||
Application-layer user authentication protocols are vulnerable to generic
|
||||
credential forwarding attacks, where an authentication credential sent by
|
||||
a client C to a server M may then be used by M to impersonate C at another
|
||||
server S. To prevent such credential forwarding attacks, modern authentication
|
||||
protocols rely on channel bindings. For example, WAMP-cryptosign can use
|
||||
the tls-unique channel identifier provided by the TLS layer to strongly bind
|
||||
authentication credentials to the underlying channel, so that a credential
|
||||
received on one TLS channel cannot be forwarded on another.
|
||||
|
||||
"""
|
||||
if channel_id_type is None:
|
||||
return None
|
||||
|
||||
if channel_id_type not in [u'tls-unique']:
|
||||
raise Exception("invalid channel ID type {}".format(channel_id_type))
|
||||
|
||||
if hasattr(transport, '_tlsConnection'):
|
||||
# Obtain latest TLS Finished message that we expected from peer, or None if handshake is not completed.
|
||||
# http://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Connection.get_peer_finished
|
||||
|
||||
if is_server:
|
||||
# for routers (=servers), the channel ID is based on the TLS Finished message we
|
||||
# expected to receive from the client
|
||||
tls_finished_msg = transport._tlsConnection.get_peer_finished()
|
||||
else:
|
||||
# for clients, the channel ID is based on the TLS Finished message we sent
|
||||
# to the router (=server)
|
||||
tls_finished_msg = transport._tlsConnection.get_finished()
|
||||
|
||||
m = hashlib.sha256()
|
||||
m.update(tls_finished_msg)
|
||||
return m.digest()
|
||||
|
||||
else:
|
||||
return None
|
@ -172,3 +172,109 @@ class ApplicationRunner(object):
|
||||
loop.run_until_complete(protocol._session.leave())
|
||||
|
||||
loop.close()
|
||||
|
||||
|
||||
#TODO - unify with previous class
|
||||
class ApplicationRunnerRawSocket(object):
|
||||
"""
|
||||
This class is a convenience tool mainly for development and quick hosting
|
||||
of WAMP application components.
|
||||
|
||||
It can host a WAMP application component in a WAMP-over-WebSocket client
|
||||
connecting to a WAMP router.
|
||||
"""
|
||||
|
||||
def __init__(self, url, realm, extra=None, serializer=None):
|
||||
"""
|
||||
:param url: Raw socket unicode - either path on local server to unix socket (or unix:/path)
|
||||
or tcp://host:port for internet socket
|
||||
:type url: unicode
|
||||
|
||||
:param realm: The WAMP realm to join the application session to.
|
||||
:type realm: unicode
|
||||
|
||||
:param extra: Optional extra configuration to forward to the application component.
|
||||
:type extra: dict
|
||||
|
||||
:param serializer: WAMP serializer to use (or None for default serializer).
|
||||
:type serializer: `autobahn.wamp.interfaces.ISerializer`
|
||||
|
||||
|
||||
"""
|
||||
assert(type(url) == six.text_type)
|
||||
assert(type(realm) == six.text_type)
|
||||
assert(extra is None or type(extra) == dict)
|
||||
self.url = url
|
||||
self.realm = realm
|
||||
self.extra = extra or dict()
|
||||
self.serializer = serializer
|
||||
|
||||
|
||||
def run(self, make, logging_level='info'):
|
||||
"""
|
||||
Run the application component.
|
||||
|
||||
:param make: A factory that produces instances of :class:`autobahn.asyncio.wamp.ApplicationSession`
|
||||
when called with an instance of :class:`autobahn.wamp.types.ComponentConfig`.
|
||||
:type make: callable
|
||||
"""
|
||||
|
||||
#make imports local for now not to infere with rest of module
|
||||
from six.moves.urllib.parse import urlparse
|
||||
from autobahn.asyncio.rawsocket import WampRawSocketClientFactory
|
||||
|
||||
def create():
|
||||
cfg = ComponentConfig(self.realm, self.extra)
|
||||
try:
|
||||
session = make(cfg)
|
||||
except Exception:
|
||||
self.log.failure("App session could not be created! ")
|
||||
asyncio.get_event_loop().stop()
|
||||
else:
|
||||
return session
|
||||
|
||||
parsed_url=urlparse(self.url)
|
||||
|
||||
if parsed_url.scheme=='tcp':
|
||||
is_unix=False
|
||||
if not parsed_url.hostname or not parsed_url.port:
|
||||
raise ValueError('Host and port is required in URL')
|
||||
elif parsed_url.scheme=='unix' or parsed_url.scheme=='':
|
||||
is_unix=True
|
||||
if not parsed_url.path:
|
||||
raise ValueError('Path to unix socket must be in URL')
|
||||
|
||||
|
||||
|
||||
transport_factory = WampRawSocketClientFactory(create, serializer=self.serializer)
|
||||
|
||||
# 3) start the client
|
||||
loop = asyncio.get_event_loop()
|
||||
txaio.use_asyncio()
|
||||
txaio.config.loop = loop
|
||||
|
||||
if is_unix:
|
||||
coro = loop.create_unix_connection(transport_factory, parsed_url.path)
|
||||
else:
|
||||
coro = loop.create_connection(transport_factory, parsed_url.hostname, parsed_url.port)
|
||||
(transport, protocol) = loop.run_until_complete(coro)
|
||||
|
||||
txaio.start_logging(level=logging_level)
|
||||
|
||||
try:
|
||||
loop.add_signal_handler(signal.SIGTERM, loop.stop)
|
||||
except NotImplementedError:
|
||||
# signals are not available on Windows
|
||||
pass
|
||||
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# give Goodbye message a chance to go through, if we still
|
||||
# have an active session
|
||||
if protocol._session:
|
||||
loop.run_until_complete(protocol._session.leave())
|
||||
|
||||
loop.close()
|
||||
|
Loading…
Reference in New Issue
Block a user