Initial implementation of python 3 decoder, for CQL protocol
This commit is contained in:
		
							
								
								
									
										0
									
								
								cassandra23/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								cassandra23/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										409
									
								
								cassandra23/decoder.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										409
									
								
								cassandra23/decoder.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,409 @@ | |||||||
|  | import struct | ||||||
|  | import six | ||||||
|  | from six.moves import range | ||||||
|  | import uuid | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Low level byte pack and unpack methods. | ||||||
|  | def _make_pack_unpack_field(format): | ||||||
|  |     s = struct.Struct(format) | ||||||
|  |     return ( | ||||||
|  |         s.pack, | ||||||
|  |         lambda b: s.unpack(b)[0] | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  | _header_struct = struct.Struct('!BBBBL') | ||||||
|  | pack_cql_header = _header_struct.pack | ||||||
|  | unpack_cql_header = _header_struct.unpack | ||||||
|  | pack_cql_byte, unpack_cql_byte = _make_pack_unpack_field('!B') | ||||||
|  | pack_cql_int, unpack_cql_int = _make_pack_unpack_field('!i') | ||||||
|  | pack_cql_short, unpack_cql_short = _make_pack_unpack_field('!H') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Maximum values for these data types. | ||||||
|  | MAX_INT = 0x7FFFFFFF | ||||||
|  | MAX_SHORT = 0xFFFF | ||||||
|  |  | ||||||
|  |      | ||||||
|  | def read_header(f): | ||||||
|  |     """ | ||||||
|  |     Read a CQL protocol frame header. | ||||||
|  |  | ||||||
|  |     A frame header consists of 4 bytes for the fields version, flags, stream and opcode. This is followed by a 4 | ||||||
|  |     byte length field, reading a total of 8 bytes. | ||||||
|  |  | ||||||
|  |     :returns: tuple consisting of the version, flags, stream, opcode and length fields. | ||||||
|  |  | ||||||
|  |     """ | ||||||
|  |     return unpack_cql_header(f.read(8)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_header(f, version, flags, stream_id, opcode, length): | ||||||
|  |     """ | ||||||
|  |     Write a CQL protocol frame header. | ||||||
|  |     """ | ||||||
|  |     f.write(pack_cql_header(version, flags, stream_id, opcode, length)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_byte(f): | ||||||
|  |     return f.read() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_byte(f, v): | ||||||
|  |     f.write(pack_cql_byte(v)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_int(f): | ||||||
|  |     return unpack_cql_int(f.read(4)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_int(f, v): | ||||||
|  |     f.write(pack_cql_int(v)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_short(f): | ||||||
|  |     return unpack_cql_short(f.read(2)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_short(f, v): | ||||||
|  |     f.write(pack_cql_short(v)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_string(f): | ||||||
|  |     """ | ||||||
|  |     :returns: Python 3 returns a str; Python 2 returns a unicode string. | ||||||
|  |     """ | ||||||
|  |     n = f.read_short() | ||||||
|  |     return f.read(n).decode('UTF8') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_string(f, v): | ||||||
|  |     # TODO: Should really check that a short string isn't longer than a 2^2. | ||||||
|  |     if isinstance(v, six.text_type): | ||||||
|  |         b = v.encode('UTF8') | ||||||
|  |         write_short(f, len(b)) | ||||||
|  |         f.write(b) | ||||||
|  |     elif isinstance(v, str): | ||||||
|  |         # This assumes that str will be caught by the previous if statement with Python 3. | ||||||
|  |         write_short(f, len(v)) | ||||||
|  |         f.write(v) | ||||||
|  |     else: | ||||||
|  |         write_string(f, str(v)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_long_string(f): | ||||||
|  |     """ | ||||||
|  |     :returns: Python 3 returns a str; Python 2 returns a unicode string. | ||||||
|  |     """ | ||||||
|  |     n = read_int(f) | ||||||
|  |     return f.read(n).decode('UTF8') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_long_string(f, v): | ||||||
|  |     # TODO: Should really check that a long string isn't longer than a 2^4 / 2. | ||||||
|  |     if isinstance(v, six.text_type): | ||||||
|  |         b = v.encode('UTF8') | ||||||
|  |         write_int(f, len(b)) | ||||||
|  |         f.write(b) | ||||||
|  |     elif isinstance(v, str): | ||||||
|  |         # This assumes that str will be caught by the previous if statement with Python 3. | ||||||
|  |         write_int(f, len(v)) | ||||||
|  |         f.write(v) | ||||||
|  |     else: | ||||||
|  |         write_long_string(f, str(v)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_uuid(f): | ||||||
|  |     return uuid.UUID(bytes=f.read(16)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_uuid(f, v): | ||||||
|  |     assert isinstance(v, uuid.UUID) | ||||||
|  |     f.write(v.bytes) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_string_list(f): | ||||||
|  |     n = read_short(f) | ||||||
|  |     return [read_string(f) for _ in range(n)] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_string_list(f, v): | ||||||
|  |     n = len(v) | ||||||
|  |     for idx in range(n): | ||||||
|  |         write_string(f, v[idx]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_bytes(f): | ||||||
|  |     n = read_int(f) | ||||||
|  |     return None if n < 0 else f.read(n) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_bytes(f, v): | ||||||
|  |     if v is None: | ||||||
|  |         write_int(f, -1) | ||||||
|  |         f.write(v) | ||||||
|  |     else: | ||||||
|  |         write_int(f, len(v)) | ||||||
|  |         f.write(v) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_short_bytes(f): | ||||||
|  |     n = read_short(f) | ||||||
|  |     return f.read(n) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_short_bytes(f, v): | ||||||
|  |     if v is None: | ||||||
|  |         write_short(f, 0) | ||||||
|  |     else: | ||||||
|  |         n = len(v) | ||||||
|  |         assert n <= MAX_SHORT | ||||||
|  |         write_short(f, n) | ||||||
|  |         f.write(v) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_inet(f): | ||||||
|  |     n = f.read(1) | ||||||
|  |     values = f.read(n) | ||||||
|  |     raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_inet(f, v): | ||||||
|  |     raise NotImplementedError | ||||||
|  |  | ||||||
|  | read_consistency = read_short | ||||||
|  | write_consistency = write_short | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_string_map(f): | ||||||
|  |     n = read_short(f) | ||||||
|  |     return dict((read_string(f), read_string(f)) for _ in range(n)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_string_map(f, v): | ||||||
|  |     write_short(f, len(v)) | ||||||
|  |     for key, value in six.iteritems(v): | ||||||
|  |         write_string(f, key) | ||||||
|  |         write_string(f, value) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read_string_multimap(f): | ||||||
|  |     n = read_short(f) | ||||||
|  |     return dict((read_string(f), read_string_list(f)) for _ in range(n)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_string_multimap(f, v): | ||||||
|  |     write_short(f, len(v)) | ||||||
|  |     for key, value in six.iteritems(v): | ||||||
|  |         write_string(f, key) | ||||||
|  |         write_string_list(f, value) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## Define messages ############################## | ||||||
|  |  | ||||||
|  | HEADER_DIRECTION_FROM_CLIENT = 0x00 | ||||||
|  | HEADER_DIRECTION_TO_CLIENT = 0x80 | ||||||
|  | HEADER_DIRECTION_MASK = 0x80 | ||||||
|  |  | ||||||
|  | COMPRESSED_FLAG = 0x01 | ||||||
|  | TRACING_FLAG = 0x02 | ||||||
|  |  | ||||||
|  | _message_types_by_name = {} | ||||||
|  | _message_types_by_opcode = {} | ||||||
|  | _error_classes = {} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class _RegisterMessageType(type): | ||||||
|  |     def __init__(cls, what, *args, **kwargs): | ||||||
|  |         if what not in ('_MessageType', 'NewBase'): | ||||||
|  |             _message_types_by_name[cls.name] = cls | ||||||
|  |             _message_types_by_opcode[cls.opcode] = cls | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_params(message_obj): | ||||||
|  |     base_attrs = dir(_MessageType) | ||||||
|  |     return ( | ||||||
|  |         (n, a) for n, a in message_obj.__dict__.items() | ||||||
|  |         if n not in base_attrs and not n.startswith('_') and not callable(a) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class _MessageType(six.with_metaclass(_RegisterMessageType, object)): | ||||||
|  |     opcode = None | ||||||
|  |     name = None | ||||||
|  |     tracing = False | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self))) | ||||||
|  |  | ||||||
|  |     def send_body(self, buf, protocol_version): | ||||||
|  |         """ | ||||||
|  |         Encode the body of this message for sending. | ||||||
|  |  | ||||||
|  |         :param buf: An instance of `ByteBuffer`. | ||||||
|  |         :param protocol_version: Version of the protocol currently being used. | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def to_binary(self, stream_id, protocol_version, compression=None): | ||||||
|  |         """ | ||||||
|  |         Pack this message into it's binary format. | ||||||
|  |         """ | ||||||
|  |         body = six.BytesIO() | ||||||
|  |         self.send_body(body, protocol_version) | ||||||
|  |         body = body.getvalue() | ||||||
|  |  | ||||||
|  |         flags = 0 | ||||||
|  |         if compression and len(body) > 0: | ||||||
|  |             body = compression(body) | ||||||
|  |             flags |= COMPRESSED_FLAG | ||||||
|  |         if self.tracing: | ||||||
|  |             flags |= TRACING_FLAG | ||||||
|  |  | ||||||
|  |         msg = six.BytesIO() | ||||||
|  |         write_header( | ||||||
|  |             msg, | ||||||
|  |             protocol_version | HEADER_DIRECTION_FROM_CLIENT, | ||||||
|  |             flags, stream_id, self.opcode, len(body) | ||||||
|  |         ) | ||||||
|  |         msg.write(body) | ||||||
|  |  | ||||||
|  |         return msg.getvalue() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def decode_response(stream_id, flags, opcode, body, decompressor=None): | ||||||
|  |     """ | ||||||
|  |     Build msg class. | ||||||
|  |     """ | ||||||
|  |     if flags & COMPRESSED_FLAG: | ||||||
|  |         if callable(decompressor): | ||||||
|  |             body = decompressor(body) | ||||||
|  |             flags ^= COMPRESSED_FLAG | ||||||
|  |         else: | ||||||
|  |             raise TypeError("De-compressor not available for compressed frame!") | ||||||
|  |  | ||||||
|  |     body = six.BytesIO(body) | ||||||
|  |     if flags & TRACING_FLAG: | ||||||
|  |         trace_id = read_uuid(body) | ||||||
|  |         flags ^= TRACING_FLAG | ||||||
|  |     else: | ||||||
|  |         trace_id = None | ||||||
|  |  | ||||||
|  |     if flags: | ||||||
|  |         # TODO: log.warn("Unknown protocol flags set: %02x. May cause problems.", flags) | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     msg_class = _message_types_by_opcode[opcode] | ||||||
|  |     msg = msg_class.recv_body(body) | ||||||
|  |     msg.stream_id = stream_id | ||||||
|  |     msg.trace_id = trace_id | ||||||
|  |     return msg | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class StartupMessage(_MessageType): | ||||||
|  |     opcode = 0x01 | ||||||
|  |     name = 'STARTUP' | ||||||
|  |  | ||||||
|  |     KNOWN_OPTION_KEYS = set(('CQL_VERSION', 'COMPRESSION',)) | ||||||
|  |  | ||||||
|  |     def __init__(self, cqlversion, options): | ||||||
|  |         self.cqlversion = cqlversion | ||||||
|  |         self.options = options | ||||||
|  |  | ||||||
|  |     def send_body(self, f, protocol_version): | ||||||
|  |         opt_map = self.options.copy() | ||||||
|  |         opt_map['CQL_VERSION'] = self.cqlversion | ||||||
|  |         write_string_map(f, opt_map) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ReadyMessage(_MessageType): | ||||||
|  |     opcode = 0x02 | ||||||
|  |     name = 'READY' | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def recv_body(cls, f): | ||||||
|  |         return cls() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AuthenticateMessage(_MessageType): | ||||||
|  |     opcode = 0x03 | ||||||
|  |     name = 'AUTHENTICATE' | ||||||
|  |  | ||||||
|  |     def __init__(self, authenticator): | ||||||
|  |         self.authenticator = authenticator | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def recv_body(cls, f): | ||||||
|  |         authenticator = read_string(f) | ||||||
|  |         return cls(authenticator) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CredentialsMessage(_MessageType): | ||||||
|  |     opcode = 0x04 | ||||||
|  |     name = 'CREDENTIALS' | ||||||
|  |  | ||||||
|  |     def __init__(self, credentials): | ||||||
|  |         self.credentials = credentials | ||||||
|  |  | ||||||
|  |     def send_body(self, f, protocol_version): | ||||||
|  |         write_string_map(f, self.credentials) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OptionsMessage(_MessageType): | ||||||
|  |     opcode = 0x05 | ||||||
|  |     name = 'OPTIONS' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SupportedMessage(_MessageType): | ||||||
|  |     opcode = 0x06 | ||||||
|  |     name = 'SUPPORTED' | ||||||
|  |  | ||||||
|  |     def __init__(self, cql_versions, options): | ||||||
|  |         self.cql_versions = cql_versions | ||||||
|  |         self.options = options | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def recv_body(cls, f): | ||||||
|  |         options = read_string_multimap(f) | ||||||
|  |         cql_versions = options.pop('CQL_VERSION') | ||||||
|  |         return cls(cql_versions, options) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QueryMessage(_MessageType): | ||||||
|  |     opcode = 0x07 | ||||||
|  |     name = 'QUERY' | ||||||
|  |  | ||||||
|  |     def __init__(self, query, consistency_level): | ||||||
|  |         self.query = query | ||||||
|  |         self.consistency_level = consistency_level | ||||||
|  |  | ||||||
|  |     def send_body(self, f, protocol_version): | ||||||
|  |         write_long_string(f, self.query) | ||||||
|  |         write_consistency(f, self.consistency_level) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def recv_body(cls, f): | ||||||
|  |         query = read_long_string(f) | ||||||
|  |         consistency_level = read_consistency(f) | ||||||
|  |         return cls(query, consistency_level) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ResultMessage(_MessageType): | ||||||
|  |     opcode = 0x08 | ||||||
|  |     name = 'RESULT' | ||||||
|  |  | ||||||
|  |     def __init__(self, kind, results): | ||||||
|  |         self.kind = kind | ||||||
|  |         self.results = results | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PrepareMessage(_MessageType): | ||||||
|  |     opcode = 0x09 | ||||||
|  |     name = 'PREPARE' | ||||||
|  |  | ||||||
|  |     def __init__(self, query): | ||||||
|  |         self.query = query | ||||||
|  |  | ||||||
|  |     def send_body(self, f, protocol_version): | ||||||
|  |         write_long_string(f, self.query) | ||||||
		Reference in New Issue
	
	Block a user
	 Tim Savage
					Tim Savage