145 lines
4.6 KiB
Python
145 lines
4.6 KiB
Python
import io
|
|
|
|
from ..codec import gzip_decode, snappy_decode
|
|
from . import pickle
|
|
from .struct import Struct
|
|
from .types import (
|
|
Int8, Int32, Int64, Bytes, Schema, AbstractType
|
|
)
|
|
from ..util import crc32
|
|
|
|
|
|
class Message(Struct):
    """A single message made of crc, magic, attributes, key and value fields.

    The bits of ``attributes`` selected by ``CODEC_MASK`` identify the
    compression codec applied to ``value``; zero means the message is not
    compressed.
    """

    SCHEMA = Schema(
        ('crc', Int32),
        ('magic', Int8),
        ('attributes', Int8),
        ('key', Bytes),
        ('value', Bytes)
    )

    # Codec identifiers carried in the masked bits of ``attributes``.
    CODEC_MASK = 0x03
    CODEC_GZIP = 0x01
    CODEC_SNAPPY = 0x02

    def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
        """Create a message from its field values.

        ``value`` and ``key`` must each be ``bytes`` or ``None``.  ``crc`` is
        stored as given; it is only recomputed when the message is encoded
        with ``recalc_crc=True``.
        """
        optional_bytes = (bytes, type(None))
        assert isinstance(value, optional_bytes), 'value must be bytes'
        assert isinstance(key, optional_bytes), 'key must be bytes'

        self.value = value
        self.key = key
        self.attributes = attributes
        self.magic = magic
        self.crc = crc

        # Bind ``encode`` on the instance so that ``message.encode()`` (as
        # called by MessageSet.encode) serializes this particular message.
        self.encode = self._encode_self

    def _encode_self(self, recalc_crc=True):
        """Serialize this message to bytes.

        When ``recalc_crc`` is true the crc is computed over everything
        after the first four encoded bytes, stored on ``self.crc`` and
        written as the leading field.  Otherwise the currently stored crc is
        emitted unchanged.
        """
        field_values = (
            self.crc, self.magic, self.attributes, self.key, self.value
        )
        encoded = Message.SCHEMA.encode(field_values)
        if recalc_crc:
            # The first four bytes hold the crc field itself.
            body = encoded[4:]
            self.crc = crc32(body)
            encoded = self.SCHEMA.fields[0].encode(self.crc) + body
        return encoded

    @classmethod
    def decode(cls, data):
        """Build a message from ``data`` (``bytes`` or a readable stream).

        Each schema field is decoded in order from the stream; the decoded
        crc is kept as-is and is not verified here.
        """
        stream = io.BytesIO(data) if isinstance(data, bytes) else data
        decoded = []
        for field in cls.SCHEMA.fields:
            decoded.append(field.decode(stream))
        return cls(decoded[4], key=decoded[3], magic=decoded[1],
                   attributes=decoded[2], crc=decoded[0])

    def validate_crc(self):
        """Return True if the stored crc matches the crc of the other fields."""
        encoded = self._encode_self(recalc_crc=False)
        return crc32(encoded[4:]) == self.crc

    def is_compressed(self):
        """Return True if any codec bit is set in ``attributes``."""
        codec_bits = self.attributes & self.CODEC_MASK
        return codec_bits != 0

    def decompress(self):
        """Decompress ``value`` and decode it as a nested message set.

        Only the gzip and snappy codecs are accepted; the return value is
        whatever ``MessageSet.decode`` yields for the decompressed bytes.
        """
        codec = self.attributes & self.CODEC_MASK
        assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY)
        decoder = gzip_decode if codec == self.CODEC_GZIP else snappy_decode
        raw_bytes = decoder(self.value)
        return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))

    def __hash__(self):
        """Hash the encoded form of the message, using the stored crc."""
        return hash(self._encode_self(recalc_crc=False))
|
|
|
|
|
|
class PartialMessage(bytes):
    """Raw bytes of a message that could not be decoded in full.

    MessageSet.decode wraps leftover bytes in this type to signal that a
    fetch's max_bytes may be too small.
    """

    def __repr__(self):
        """Return the bytes wrapped in a ``PartialMessage(...)`` label."""
        return 'PartialMessage(%s)' % (self,)
|
|
|
|
|
|
class MessageSet(AbstractType):
    """Codec for a sequence of ``(offset, message_size, message)`` items."""

    ITEM = Schema(
        ('offset', Int64),
        ('message_size', Int32),
        ('message', Message.SCHEMA)
    )

    @classmethod
    def encode(cls, items, size=True, recalc_message_size=True):
        """Serialize ``items`` to bytes.

        Arguments:
            items: iterable of ``(offset, message_size, message)`` tuples.
                ``message`` is either a ``Message`` instance or a value
                accepted by the message schema's ``encode``.
            size: when true (default), prefix the result with its total
                length encoded as an Int32.
            recalc_message_size: when true (default), ignore each supplied
                ``message_size`` and write the actual encoded length.

        Returns:
            The encoded message set as bytes.
        """
        offset_field = cls.ITEM.fields[0]
        size_field = cls.ITEM.fields[1]
        message_field = cls.ITEM.fields[2]

        encoded_values = []
        for (offset, message_size, message) in items:
            if isinstance(message, Message):
                encoded_message = message.encode()
            else:
                encoded_message = message_field.encode(message)
            if recalc_message_size:
                message_size = len(encoded_message)
            encoded_values.append(offset_field.encode(offset))
            encoded_values.append(size_field.encode(message_size))
            encoded_values.append(encoded_message)

        encoded = b''.join(encoded_values)
        if not size:
            return encoded
        return Int32.encode(len(encoded)) + encoded

    @classmethod
    def decode(cls, data, bytes_to_read=None):
        """Decode a message set into a list of items.

        Compressed messages should pass in bytes_to_read (via message size);
        otherwise it is decoded from data as a leading Int32.

        Arguments:
            data: ``bytes`` or a readable stream positioned at the set.
            bytes_to_read: number of bytes the set occupies, or None to read
                that count from ``data`` first.

        Returns:
            A list of ``(offset, message_size, Message)`` tuples.  If bytes
            remain that do not form a complete item, a final
            ``(None, None, PartialMessage)`` entry holds them.
        """
        if isinstance(data, bytes):
            data = io.BytesIO(data)
        if bytes_to_read is None:
            bytes_to_read = Int32.decode(data)

        # Every item starts with an Int64 offset and an Int32 message size.
        header_size = 8 + 4
        # 14 bytes is a message w/ null key and null value.
        min_message_size = 14

        items = []
        while bytes_to_read >= header_size + min_message_size:
            offset = Int64.decode(data)
            message_size = Int32.decode(data)
            bytes_to_read -= header_size

            # if FetchRequest max_bytes is smaller than the available message
            # set the server returns partial data for the final message
            if message_size > bytes_to_read:
                break

            # Decode from exactly message_size bytes.  Reading the message
            # straight off the shared stream would consume however many
            # bytes its own length fields dictate, which can differ from
            # message_size and would misalign every following item.
            message = Message.decode(data.read(message_size))
            bytes_to_read -= message_size

            items.append((offset, message_size, message))

        # If any bytes are left over, clear them from the buffer
        # and append a PartialMessage to signal that max_bytes may be too small
        if bytes_to_read:
            items.append(
                (None, None, PartialMessage(data.read(bytes_to_read)))
            )

        return items

    @classmethod
    def repr(cls, messages):
        """Return a bracketed, comma-separated rendering of ``messages``.

        Each element is rendered with the item schema's ``repr``.
        """
        return '[' + ', '.join([cls.ITEM.repr(m) for m in messages]) + ']'
|