Restructure load local code

This commit is contained in:
Stacey Wrazien
2015-01-07 10:29:07 -05:00
parent b7551d25d0
commit 0ea120e29a

View File

@@ -334,12 +334,9 @@ class MysqlPacket(object):
field_count = ord(self._data[0:1])
return 1 <= field_count <= 250
def is_local_file_packet(self):
def is_load_local_packet(self):
return self._data[0:1] == b'\xfb'
def get_local_file_name(self):
return self._data[1:]
def is_error_packet(self):
return self._data[0:1] == b'\xff'
@@ -460,6 +457,29 @@ class EOFPacketWrapper(object):
return getattr(self.packet, key)
class LoadLocalPacketWrapper(object):
"""
Load Local Packet Wrapper. It uses an existing packet object, and wraps
around it, exposing useful variables while still providing access
to the original packet objects variables and methods.
"""
def __init__(self, from_packet):
if not from_packet.is_load_local_packet():
raise ValueError(
"Cannot create '{0}' object from invalid packet type".format(
self.__class__))
self.packet = from_packet
self.filename = self.packet.get_all_data()[1:]
if DEBUG: print("filename=", self.filename)
def __getattr__(self, key):
return getattr(self.packet, key)
class Connection(object):
"""
Representation of a socket with a mysql server.
@@ -1085,15 +1105,8 @@ class MySQLResult(object):
# TODO: use classes for different packet types?
if first_packet.is_ok_packet():
self._read_ok_packet(first_packet)
if first_packet.is_local_file_packet():
requested_file = first_packet.get_local_file_name()
# ensure the filename returned by the server matches the
# file we asked to load in the initial query
if self.connection.local_file == requested_file:
local_packet = LoadLocalFile(requested_file, self.connection)
local_packet.send_data()
else:
raise OperationalError(2014, "Command Out of Sync")
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
else:
self._read_result_packet(first_packet)
finally:
@@ -1125,6 +1138,16 @@ class MySQLResult(object):
self.message = ok_packet.message
self.has_next = ok_packet.has_next
def _read_load_local_packet(self, first_packet):
load_packet = LoadLocalPacketWrapper(first_packet)
# ensure the filename returned by the server matches the
# file we asked to load in the initial query
if self.connection.local_file == load_packet.filename:
local_packet = LoadLocalFile(load_packet.filename, self.connection)
local_packet.send_data()
else:
raise OperationalError(2014, "Command Out of Sync")
def _check_packet_is_eof(self, packet):
if packet.is_eof_packet():
eof_packet = EOFPacketWrapper(packet)
@@ -1227,41 +1250,45 @@ class LoadLocalFile(object):
if not self.connection.socket:
raise InterfaceError("(0, '')")
with open(self.filename, 'r') as open_file:
chunk_size = MAX_PACKET_LEN
prelude = ""
packet = ""
packet_size = 0
# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'r') as open_file:
chunk_size = MAX_PACKET_LEN
prelude = ""
packet = ""
packet_size = 0
# sequence id is 2 as we already sent a query packet
seq_id = 2
for line in open_file:
line_length = len(line)
format_str = '!{0}s'.format(line_length)
line = line.encode(self.connection.encoding)
if packet_size + len(line) < chunk_size:
packet += struct.pack(format_str, line)
packet_size += line_length
else:
# send the existing packet when we have reached the chunk size
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
packet = prelude + packet
self.connection._write_bytes(packet)
for line in open_file:
line_length = len(line)
format_str = '!{0}s'.format(line_length)
line = line.encode(self.connection.encoding)
if packet_size + len(line) < chunk_size:
packet += struct.pack(format_str, line)
packet_size += line_length
else:
# send the existing packet when we have reached the chunk size
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
packet = prelude + packet
self.connection._write_bytes(packet)
seq_id += 1
packet = struct.pack(format_str, line)
packet_size = line_length
seq_id += 1
packet = struct.pack(format_str, line)
packet_size = line_length
# send the last data packet
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
packet = prelude + packet
self.connection._write_bytes(packet)
# send the last data packet
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
packet = prelude + packet
self.connection._write_bytes(packet)
# send the empty packet to signify we are done sending data
seq_id += 1
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
# send the empty packet to signify we are done sending data
seq_id += 1
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
self.connection._read_ok_packet()
except IOError:
raise OperationalError(1017, "Can't find file '{}'".format(self.filename))
self.connection._read_ok_packet()
# g:khuno_ignore='E226,E301,E701'