Merge pull request #258 from methane/simple-mysql-packet

Simplify MysqlPacket
This commit is contained in:
INADA Naoki
2014-08-23 10:54:47 +09:00

View File

@@ -232,37 +232,15 @@ def unpack_int64(n):
class MysqlPacket(object): class MysqlPacket(object):
"""Representation of a MySQL response packet. Reads in the packet """Representation of a MySQL response packet.
from the network socket, removes packet header and provides an interface
for reading/parsing the packet results."""
__slots__ = ('_position', '_data', '_packet_number')
def __init__(self, connection): Provides an interface for reading/parsing the packet results.
"""
__slots__ = ('_position', '_data')
def __init__(self, data, encoding):
self._position = 0 self._position = 0
self._recv_packet(connection) self._data = data
def _recv_packet(self, connection):
"""Parse the packet header and read entire packet payload into buffer."""
buff = b''
while True:
packet_header = connection._read_bytes(4)
if DEBUG: dump_packet(packet_header)
packet_length_bin = packet_header[:3]
#TODO: check sequence id
self._packet_number = byte2int(packet_header[3])
bin_length = packet_length_bin + b'\0' # pad little-endian number
bytes_to_read = struct.unpack('<I', bin_length)[0]
recv_data = connection._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
if bytes_to_read < MAX_PACKET_LEN:
break
self._data = buff
def packet_number(self):
return self._packet_number
def get_all_data(self): def get_all_data(self):
return self._data return self._data
@@ -380,9 +358,9 @@ class FieldDescriptorPacket(MysqlPacket):
attributes on the class such as: db, table_name, name, length, type_code. attributes on the class such as: db, table_name, name, length, type_code.
""" """
def __init__(self, connection): def __init__(self, data, encoding):
MysqlPacket.__init__(self, connection) MysqlPacket.__init__(self, data, encoding)
self.__parse_field_descriptor(connection.encoding) self.__parse_field_descriptor(encoding)
def __parse_field_descriptor(self, encoding): def __parse_field_descriptor(self, encoding):
"""Parse the 'Field Descriptor' (Metadata) packet. """Parse the 'Field Descriptor' (Metadata) packet.
@@ -402,27 +380,19 @@ class FieldDescriptorPacket(MysqlPacket):
self.flags = struct.unpack('<H', self.read(2))[0] self.flags = struct.unpack('<H', self.read(2))[0]
self.scale = byte2int(self.read(1)) # "decimals" self.scale = byte2int(self.read(1)) # "decimals"
self.advance(2) # filler (always 0x00) self.advance(2) # filler (always 0x00)
# 'default' is a length coded binary and is still in the buffer? # 'default' is a length coded binary and is still in the buffer?
# not used for normal result sets... # not used for normal result sets...
def description(self): def description(self):
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec.""" """Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
desc = [] return (
desc.append(self.name) self.name,
desc.append(self.type_code) self.type_code,
desc.append(None) # TODO: display_length; should this be self.length? None, # TODO: display_length; should this be self.length?
desc.append(self.get_column_length()) # 'internal_size' self.get_column_length(), # 'internal_size'
desc.append(self.get_column_length()) # 'precision' # TODO: why!?!? self.get_column_length(), # 'precision' # TODO: why!?!?
desc.append(self.scale) self.scale,
self.flags % 2 == 0)
# 'null_ok' -- can this be True/False rather than 1/0?
# if so just do: desc.append(bool(self.flags % 2 == 0))
if self.flags % 2 == 0:
desc.append(1)
else:
desc.append(0)
return tuple(desc)
def get_column_length(self): def get_column_length(self):
if self.type_code == FIELD_TYPE.VAR_STRING: if self.type_code == FIELD_TYPE.VAR_STRING:
@@ -840,7 +810,24 @@ class Connection(object):
"""Read an entire "mysql packet" in its entirety from the network """Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results. and return a MysqlPacket type that represents the results.
""" """
packet = packet_type(self) buff = b''
while True:
packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header)
packet_length_bin = packet_header[:3]
#TODO: check sequence id
# packet_number
byte2int(packet_header[3])
bin_length = packet_length_bin + b'\0' # pad little-endian number
bytes_to_read = struct.unpack('<I', bin_length)[0]
recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
if bytes_to_read < MAX_PACKET_LEN:
break
packet = packet_type(buff, self.encoding)
packet.check_error() packet.check_error()
return packet return packet
@@ -968,9 +955,7 @@ class Connection(object):
self._write_bytes(data) self._write_bytes(data)
auth_packet = MysqlPacket(self) auth_packet = self._read_packet()
auth_packet.check_error()
if DEBUG: auth_packet.dump()
# if old_passwords is enabled the packet will be 1 byte long and # if old_passwords is enabled the packet will be 1 byte long and
# have the octet 254 # have the octet 254
@@ -979,11 +964,8 @@ class Connection(object):
# send legacy handshake # send legacy handshake
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0' data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
data = pack_int24(len(data)) + int2byte(next_packet) + data data = pack_int24(len(data)) + int2byte(next_packet) + data
self._write_bytes(data) self._write_bytes(data)
auth_packet = MysqlPacket(self) auth_packet = self._read_packet()
auth_packet.check_error()
if DEBUG: auth_packet.dump()
# _mysql support # _mysql support
def thread_id(self): def thread_id(self):
@@ -1000,8 +982,7 @@ class Connection(object):
def _get_server_information(self): def _get_server_information(self):
i = 0 i = 0
packet = MysqlPacket(self) packet = self._read_packet()
packet.check_error()
data = packet.get_all_data() data = packet.get_all_data()
if DEBUG: dump_packet(data) if DEBUG: dump_packet(data)