Restructure load local code
This commit is contained in:
@@ -334,12 +334,9 @@ class MysqlPacket(object):
|
||||
field_count = ord(self._data[0:1])
|
||||
return 1 <= field_count <= 250
|
||||
|
||||
def is_local_file_packet(self):
|
||||
def is_load_local_packet(self):
|
||||
return self._data[0:1] == b'\xfb'
|
||||
|
||||
def get_local_file_name(self):
|
||||
return self._data[1:]
|
||||
|
||||
def is_error_packet(self):
|
||||
return self._data[0:1] == b'\xff'
|
||||
|
||||
@@ -460,6 +457,29 @@ class EOFPacketWrapper(object):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
|
||||
class LoadLocalPacketWrapper(object):
|
||||
"""
|
||||
Load Local Packet Wrapper. It uses an existing packet object, and wraps
|
||||
around it, exposing useful variables while still providing access
|
||||
to the original packet objects variables and methods.
|
||||
"""
|
||||
|
||||
def __init__(self, from_packet):
|
||||
if not from_packet.is_load_local_packet():
|
||||
raise ValueError(
|
||||
"Cannot create '{0}' object from invalid packet type".format(
|
||||
self.__class__))
|
||||
|
||||
self.packet = from_packet
|
||||
self.filename = self.packet.get_all_data()[1:]
|
||||
if DEBUG: print("filename=", self.filename)
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""
|
||||
Representation of a socket with a mysql server.
|
||||
@@ -1085,15 +1105,8 @@ class MySQLResult(object):
|
||||
# TODO: use classes for different packet types?
|
||||
if first_packet.is_ok_packet():
|
||||
self._read_ok_packet(first_packet)
|
||||
if first_packet.is_local_file_packet():
|
||||
requested_file = first_packet.get_local_file_name()
|
||||
# ensure the filename returned by the server matches the
|
||||
# file we asked to load in the initial query
|
||||
if self.connection.local_file == requested_file:
|
||||
local_packet = LoadLocalFile(requested_file, self.connection)
|
||||
local_packet.send_data()
|
||||
else:
|
||||
raise OperationalError(2014, "Command Out of Sync")
|
||||
elif first_packet.is_load_local_packet():
|
||||
self._read_load_local_packet(first_packet)
|
||||
else:
|
||||
self._read_result_packet(first_packet)
|
||||
finally:
|
||||
@@ -1125,6 +1138,16 @@ class MySQLResult(object):
|
||||
self.message = ok_packet.message
|
||||
self.has_next = ok_packet.has_next
|
||||
|
||||
def _read_load_local_packet(self, first_packet):
|
||||
load_packet = LoadLocalPacketWrapper(first_packet)
|
||||
# ensure the filename returned by the server matches the
|
||||
# file we asked to load in the initial query
|
||||
if self.connection.local_file == load_packet.filename:
|
||||
local_packet = LoadLocalFile(load_packet.filename, self.connection)
|
||||
local_packet.send_data()
|
||||
else:
|
||||
raise OperationalError(2014, "Command Out of Sync")
|
||||
|
||||
def _check_packet_is_eof(self, packet):
|
||||
if packet.is_eof_packet():
|
||||
eof_packet = EOFPacketWrapper(packet)
|
||||
@@ -1227,41 +1250,45 @@ class LoadLocalFile(object):
|
||||
if not self.connection.socket:
|
||||
raise InterfaceError("(0, '')")
|
||||
|
||||
with open(self.filename, 'r') as open_file:
|
||||
chunk_size = MAX_PACKET_LEN
|
||||
prelude = ""
|
||||
packet = ""
|
||||
packet_size = 0
|
||||
# sequence id is 2 as we already sent a query packet
|
||||
seq_id = 2
|
||||
try:
|
||||
with open(self.filename, 'r') as open_file:
|
||||
chunk_size = MAX_PACKET_LEN
|
||||
prelude = ""
|
||||
packet = ""
|
||||
packet_size = 0
|
||||
# sequence id is 2 as we already sent a query packet
|
||||
seq_id = 2
|
||||
|
||||
for line in open_file:
|
||||
line_length = len(line)
|
||||
format_str = '!{0}s'.format(line_length)
|
||||
line = line.encode(self.connection.encoding)
|
||||
if packet_size + len(line) < chunk_size:
|
||||
packet += struct.pack(format_str, line)
|
||||
packet_size += line_length
|
||||
else:
|
||||
# send the existing packet when we have reached the chunk size
|
||||
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
|
||||
packet = prelude + packet
|
||||
self.connection._write_bytes(packet)
|
||||
for line in open_file:
|
||||
line_length = len(line)
|
||||
format_str = '!{0}s'.format(line_length)
|
||||
line = line.encode(self.connection.encoding)
|
||||
if packet_size + len(line) < chunk_size:
|
||||
packet += struct.pack(format_str, line)
|
||||
packet_size += line_length
|
||||
else:
|
||||
# send the existing packet when we have reached the chunk size
|
||||
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
|
||||
packet = prelude + packet
|
||||
self.connection._write_bytes(packet)
|
||||
|
||||
seq_id += 1
|
||||
packet = struct.pack(format_str, line)
|
||||
packet_size = line_length
|
||||
seq_id += 1
|
||||
packet = struct.pack(format_str, line)
|
||||
packet_size = line_length
|
||||
|
||||
# send the last data packet
|
||||
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
|
||||
packet = prelude + packet
|
||||
self.connection._write_bytes(packet)
|
||||
# send the last data packet
|
||||
prelude = struct.pack('<i', packet_size)[:3] + int2byte(seq_id)
|
||||
packet = prelude + packet
|
||||
self.connection._write_bytes(packet)
|
||||
|
||||
# send the empty packet to signify we are done sending data
|
||||
seq_id += 1
|
||||
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
||||
self.connection._write_bytes(packet)
|
||||
# send the empty packet to signify we are done sending data
|
||||
seq_id += 1
|
||||
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
||||
self.connection._write_bytes(packet)
|
||||
|
||||
self.connection._read_ok_packet()
|
||||
except IOError:
|
||||
raise OperationalError(1017, "Can't find file '{}'".format(self.filename))
|
||||
|
||||
self.connection._read_ok_packet()
|
||||
|
||||
# g:khuno_ignore='E226,E301,E701'
|
||||
|
||||
Reference in New Issue
Block a user