168 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			168 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import io
 | |
| 
 | |
| from ..codec import (has_gzip, has_snappy, has_lz4,
 | |
|                      gzip_decode, snappy_decode, lz4_decode)
 | |
| from . import pickle
 | |
| from .struct import Struct
 | |
| from .types import (
 | |
|     Int8, Int32, Int64, Bytes, Schema, AbstractType
 | |
| )
 | |
| from ..util import crc32
 | |
| 
 | |
| 
 | |
| class Message(Struct):
 | |
|     SCHEMA = Schema(
 | |
|         ('crc', Int32),
 | |
|         ('magic', Int8),
 | |
|         ('attributes', Int8),
 | |
|         ('key', Bytes),
 | |
|         ('value', Bytes)
 | |
|     )
 | |
|     CODEC_MASK = 0x03
 | |
|     CODEC_GZIP = 0x01
 | |
|     CODEC_SNAPPY = 0x02
 | |
|     CODEC_LZ4 = 0x03
 | |
|     HEADER_SIZE = 14 # crc(4), magic(1), attributes(1), key+value size(4*2)
 | |
| 
 | |
|     def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
 | |
|         assert value is None or isinstance(value, bytes), 'value must be bytes'
 | |
|         assert key is None or isinstance(key, bytes), 'key must be bytes'
 | |
|         self.crc = crc
 | |
|         self.magic = magic
 | |
|         self.attributes = attributes
 | |
|         self.key = key
 | |
|         self.value = value
 | |
|         self.encode = self._encode_self
 | |
| 
 | |
|     def _encode_self(self, recalc_crc=True):
 | |
|         message = Message.SCHEMA.encode(
 | |
|           (self.crc, self.magic, self.attributes, self.key, self.value)
 | |
|         )
 | |
|         if not recalc_crc:
 | |
|             return message
 | |
|         self.crc = crc32(message[4:])
 | |
|         return self.SCHEMA.fields[0].encode(self.crc) + message[4:]
 | |
| 
 | |
|     @classmethod
 | |
|     def decode(cls, data):
 | |
|         if isinstance(data, bytes):
 | |
|             data = io.BytesIO(data)
 | |
|         fields = [field.decode(data) for field in cls.SCHEMA.fields]
 | |
|         return cls(fields[4], key=fields[3],
 | |
|                    magic=fields[1], attributes=fields[2], crc=fields[0])
 | |
| 
 | |
|     def validate_crc(self):
 | |
|         raw_msg = self._encode_self(recalc_crc=False)
 | |
|         crc = crc32(raw_msg[4:])
 | |
|         if crc == self.crc:
 | |
|             return True
 | |
|         return False
 | |
| 
 | |
|     def is_compressed(self):
 | |
|         return self.attributes & self.CODEC_MASK != 0
 | |
| 
 | |
|     def decompress(self):
 | |
|         codec = self.attributes & self.CODEC_MASK
 | |
|         assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, self.CODEC_LZ4)
 | |
|         if codec == self.CODEC_GZIP:
 | |
|             assert has_gzip(), 'Gzip decompression unsupported'
 | |
|             raw_bytes = gzip_decode(self.value)
 | |
|         elif codec == self.CODEC_SNAPPY:
 | |
|             assert has_snappy(), 'Snappy decompression unsupported'
 | |
|             raw_bytes = snappy_decode(self.value)
 | |
|         elif codec == self.CODEC_LZ4:
 | |
|             assert has_lz4(), 'LZ4 decompression unsupported'
 | |
|             raw_bytes = lz4_decode(self.value)
 | |
|         else:
 | |
|           raise Exception('This should be impossible')
 | |
| 
 | |
|         return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))
 | |
| 
 | |
|     def __hash__(self):
 | |
|         return hash(self._encode_self(recalc_crc=False))
 | |
| 
 | |
| 
 | |
| class PartialMessage(bytes):
 | |
|     def __repr__(self):
 | |
|         return 'PartialMessage(%s)' % self
 | |
| 
 | |
| 
 | |
| class MessageSet(AbstractType):
 | |
|     ITEM = Schema(
 | |
|         ('offset', Int64),
 | |
|         ('message_size', Int32),
 | |
|         ('message', Message.SCHEMA)
 | |
|     )
 | |
|     HEADER_SIZE = 12 # offset + message_size
 | |
| 
 | |
|     @classmethod
 | |
|     def encode(cls, items, size=True, recalc_message_size=True):
 | |
|         # RecordAccumulator encodes messagesets internally
 | |
|         if isinstance(items, io.BytesIO):
 | |
|             size = Int32.decode(items)
 | |
|             # rewind and return all the bytes
 | |
|             items.seek(-4, 1)
 | |
|             return items.read(size + 4)
 | |
| 
 | |
|         encoded_values = []
 | |
|         for (offset, message_size, message) in items:
 | |
|             if isinstance(message, Message):
 | |
|                 encoded_message = message.encode()
 | |
|             else:
 | |
|                 encoded_message = cls.ITEM.fields[2].encode(message)
 | |
|             if recalc_message_size:
 | |
|                 message_size = len(encoded_message)
 | |
|             encoded_values.append(cls.ITEM.fields[0].encode(offset))
 | |
|             encoded_values.append(cls.ITEM.fields[1].encode(message_size))
 | |
|             encoded_values.append(encoded_message)
 | |
|         encoded = b''.join(encoded_values)
 | |
|         if not size:
 | |
|             return encoded
 | |
|         return Int32.encode(len(encoded)) + encoded
 | |
| 
 | |
|     @classmethod
 | |
|     def decode(cls, data, bytes_to_read=None):
 | |
|         """Compressed messages should pass in bytes_to_read (via message size)
 | |
|         otherwise, we decode from data as Int32
 | |
|         """
 | |
|         if isinstance(data, bytes):
 | |
|             data = io.BytesIO(data)
 | |
|         if bytes_to_read is None:
 | |
|             bytes_to_read = Int32.decode(data)
 | |
|         items = []
 | |
| 
 | |
|         # We need at least 8 + 4 + 14 bytes to read offset + message size + message
 | |
|         # (14 bytes is a message w/ null key and null value)
 | |
|         while bytes_to_read >= 26:
 | |
|             offset = Int64.decode(data)
 | |
|             bytes_to_read -= 8
 | |
| 
 | |
|             message_size = Int32.decode(data)
 | |
|             bytes_to_read -= 4
 | |
| 
 | |
|             # if FetchRequest max_bytes is smaller than the available message set
 | |
|             # the server returns partial data for the final message
 | |
|             if message_size > bytes_to_read:
 | |
|                 break
 | |
| 
 | |
|             message = Message.decode(data)
 | |
|             bytes_to_read -= message_size
 | |
| 
 | |
|             items.append((offset, message_size, message))
 | |
| 
 | |
|         # If any bytes are left over, clear them from the buffer
 | |
|         # and append a PartialMessage to signal that max_bytes may be too small
 | |
|         if bytes_to_read:
 | |
|             items.append((None, None, PartialMessage(data.read(bytes_to_read))))
 | |
| 
 | |
|         return items
 | |
| 
 | |
|     @classmethod
 | |
|     def repr(cls, messages):
 | |
|         if isinstance(messages, io.BytesIO):
 | |
|             offset = messages.tell()
 | |
|             decoded = cls.decode(messages)
 | |
|             messages.seek(offset)
 | |
|             messages = decoded
 | |
|         return '[' + ', '.join([cls.ITEM.repr(m) for m in messages]) + ']'
 | 
