Merge pull request #279 from wraziens/load_local
Implement load data local infile command. Resolves #62
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
[
|
[
|
||||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true},
|
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true},
|
||||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }
|
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" },
|
||||||
|
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||||||
# Error codes:
|
# Error codes:
|
||||||
# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
|
# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
|
from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
@@ -12,6 +11,7 @@ from functools import partial
|
|||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
@@ -334,6 +334,9 @@ class MysqlPacket(object):
|
|||||||
field_count = ord(self._data[0:1])
|
field_count = ord(self._data[0:1])
|
||||||
return 1 <= field_count <= 250
|
return 1 <= field_count <= 250
|
||||||
|
|
||||||
|
def is_load_local_packet(self):
|
||||||
|
return self._data[0:1] == b'\xfb'
|
||||||
|
|
||||||
def is_error_packet(self):
|
def is_error_packet(self):
|
||||||
return self._data[0:1] == b'\xff'
|
return self._data[0:1] == b'\xff'
|
||||||
|
|
||||||
@@ -454,6 +457,27 @@ class EOFPacketWrapper(object):
|
|||||||
return getattr(self.packet, key)
|
return getattr(self.packet, key)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLocalPacketWrapper(object):
|
||||||
|
"""
|
||||||
|
Load Local Packet Wrapper. It uses an existing packet object, and wraps
|
||||||
|
around it, exposing useful variables while still providing access
|
||||||
|
to the original packet objects variables and methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, from_packet):
|
||||||
|
if not from_packet.is_load_local_packet():
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot create '{0}' object from invalid packet type".format(
|
||||||
|
self.__class__))
|
||||||
|
|
||||||
|
self.packet = from_packet
|
||||||
|
self.filename = self.packet.get_all_data()[1:]
|
||||||
|
if DEBUG: print("filename=", self.filename)
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
return getattr(self.packet, key)
|
||||||
|
|
||||||
|
|
||||||
class Connection(object):
|
class Connection(object):
|
||||||
"""
|
"""
|
||||||
Representation of a socket with a mysql server.
|
Representation of a socket with a mysql server.
|
||||||
@@ -472,7 +496,7 @@ class Connection(object):
|
|||||||
client_flag=0, cursorclass=Cursor, init_command=None,
|
client_flag=0, cursorclass=Cursor, init_command=None,
|
||||||
connect_timeout=None, ssl=None, read_default_group=None,
|
connect_timeout=None, ssl=None, read_default_group=None,
|
||||||
compress=None, named_pipe=None, no_delay=False,
|
compress=None, named_pipe=None, no_delay=False,
|
||||||
autocommit=False, db=None, passwd=None):
|
autocommit=False, db=None, passwd=None, local_infile=False):
|
||||||
"""
|
"""
|
||||||
Establish a connection to the MySQL database. Accepts several
|
Establish a connection to the MySQL database. Accepts several
|
||||||
arguments:
|
arguments:
|
||||||
@@ -505,6 +529,7 @@ class Connection(object):
|
|||||||
named_pipe: Not supported
|
named_pipe: Not supported
|
||||||
no_delay: Disable Nagle's algorithm on the socket
|
no_delay: Disable Nagle's algorithm on the socket
|
||||||
autocommit: Autocommit mode. None means use server default. (default: False)
|
autocommit: Autocommit mode. None means use server default. (default: False)
|
||||||
|
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
|
||||||
|
|
||||||
db: Alias for database. (for compatibility to MySQLdb)
|
db: Alias for database. (for compatibility to MySQLdb)
|
||||||
passwd: Alias for password. (for compatibility to MySQLdb)
|
passwd: Alias for password. (for compatibility to MySQLdb)
|
||||||
@@ -521,6 +546,9 @@ class Connection(object):
|
|||||||
if compress or named_pipe:
|
if compress or named_pipe:
|
||||||
raise NotImplementedError("compress and named_pipe arguments are not supported")
|
raise NotImplementedError("compress and named_pipe arguments are not supported")
|
||||||
|
|
||||||
|
if local_infile:
|
||||||
|
client_flag |= CLIENT.LOCAL_FILES
|
||||||
|
|
||||||
if ssl and ('capath' in ssl or 'cipher' in ssl):
|
if ssl and ('capath' in ssl or 'cipher' in ssl):
|
||||||
raise NotImplementedError('ssl options capath and cipher are not supported')
|
raise NotImplementedError('ssl options capath and cipher are not supported')
|
||||||
|
|
||||||
@@ -1057,6 +1085,7 @@ class MySQLResult(object):
|
|||||||
self.rows = None
|
self.rows = None
|
||||||
self.has_next = None
|
self.has_next = None
|
||||||
self.unbuffered_active = False
|
self.unbuffered_active = False
|
||||||
|
self.filename = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.unbuffered_active:
|
if self.unbuffered_active:
|
||||||
@@ -1069,6 +1098,8 @@ class MySQLResult(object):
|
|||||||
# TODO: use classes for different packet types?
|
# TODO: use classes for different packet types?
|
||||||
if first_packet.is_ok_packet():
|
if first_packet.is_ok_packet():
|
||||||
self._read_ok_packet(first_packet)
|
self._read_ok_packet(first_packet)
|
||||||
|
elif first_packet.is_load_local_packet():
|
||||||
|
self._read_load_local_packet(first_packet)
|
||||||
else:
|
else:
|
||||||
self._read_result_packet(first_packet)
|
self._read_result_packet(first_packet)
|
||||||
finally:
|
finally:
|
||||||
@@ -1100,6 +1131,21 @@ class MySQLResult(object):
|
|||||||
self.message = ok_packet.message
|
self.message = ok_packet.message
|
||||||
self.has_next = ok_packet.has_next
|
self.has_next = ok_packet.has_next
|
||||||
|
|
||||||
|
def _read_load_local_packet(self, first_packet):
|
||||||
|
load_packet = LoadLocalPacketWrapper(first_packet)
|
||||||
|
local_packet = LoadLocalFile(load_packet.filename, self.connection)
|
||||||
|
self.filename = load_packet.filename
|
||||||
|
local_packet.send_data()
|
||||||
|
|
||||||
|
ok_packet = self.connection._read_packet()
|
||||||
|
if not ok_packet.is_ok_packet():
|
||||||
|
raise OperationalError(2014, "Commands Out of Sync")
|
||||||
|
self._read_ok_packet(ok_packet)
|
||||||
|
|
||||||
|
if self.warning_count > 0:
|
||||||
|
self._print_warnings()
|
||||||
|
self.filename = None
|
||||||
|
|
||||||
def _check_packet_is_eof(self, packet):
|
def _check_packet_is_eof(self, packet):
|
||||||
if packet.is_eof_packet():
|
if packet.is_eof_packet():
|
||||||
eof_packet = EOFPacketWrapper(packet)
|
eof_packet = EOFPacketWrapper(packet)
|
||||||
@@ -1108,6 +1154,16 @@ class MySQLResult(object):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _print_warnings(self):
|
||||||
|
from warnings import warn
|
||||||
|
self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')
|
||||||
|
self.read()
|
||||||
|
if self.rows:
|
||||||
|
message = "\n"
|
||||||
|
for db_warning in self.rows:
|
||||||
|
message += "{0} in file '{1}'\n".format(db_warning[2], self.filename.decode('utf-8'))
|
||||||
|
warn(message, Warning, 3)
|
||||||
|
|
||||||
def _read_result_packet(self, first_packet):
|
def _read_result_packet(self, first_packet):
|
||||||
self.field_count = first_packet.read_length_encoded_integer()
|
self.field_count = first_packet.read_length_encoded_integer()
|
||||||
self._get_descriptions()
|
self._get_descriptions()
|
||||||
@@ -1191,4 +1247,40 @@ class MySQLResult(object):
|
|||||||
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
|
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
|
||||||
self.description = tuple(description)
|
self.description = tuple(description)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLocalFile(object):
|
||||||
|
def __init__(self, filename, connection):
|
||||||
|
self.filename = filename
|
||||||
|
self.connection = connection
|
||||||
|
|
||||||
|
def send_data(self):
|
||||||
|
"""Send data packets from the local file to the server"""
|
||||||
|
if not self.connection.socket:
|
||||||
|
raise InterfaceError("(0, '')")
|
||||||
|
|
||||||
|
# sequence id is 2 as we already sent a query packet
|
||||||
|
seq_id = 2
|
||||||
|
try:
|
||||||
|
with open(self.filename, 'rb') as open_file:
|
||||||
|
chunk_size = MAX_PACKET_LEN
|
||||||
|
prelude = b""
|
||||||
|
packet = b""
|
||||||
|
packet_size = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
chunk = open_file.read(chunk_size)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
|
||||||
|
format_str = '!{0}s'.format(len(chunk))
|
||||||
|
packet += struct.pack(format_str, chunk)
|
||||||
|
self.connection._write_bytes(packet)
|
||||||
|
seq_id += 1
|
||||||
|
except IOError:
|
||||||
|
raise OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
||||||
|
finally:
|
||||||
|
# send the empty packet to signify we are done sending data
|
||||||
|
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
||||||
|
self.connection._write_bytes(packet)
|
||||||
|
|
||||||
# g:khuno_ignore='E226,E301,E701'
|
# g:khuno_ignore='E226,E301,E701'
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from pymysql.tests.test_nextset import *
|
|||||||
from pymysql.tests.test_DictCursor import *
|
from pymysql.tests.test_DictCursor import *
|
||||||
from pymysql.tests.test_connection import TestConnection
|
from pymysql.tests.test_connection import TestConnection
|
||||||
from pymysql.tests.test_SSCursor import *
|
from pymysql.tests.test_SSCursor import *
|
||||||
|
from pymysql.tests.test_load_local import *
|
||||||
|
|
||||||
from pymysql.tests.thirdparty import *
|
from pymysql.tests.thirdparty import *
|
||||||
|
|
||||||
|
|||||||
22749
pymysql/tests/data/load_local_data.txt
Normal file
22749
pymysql/tests/data/load_local_data.txt
Normal file
File diff suppressed because it is too large
Load Diff
50
pymysql/tests/data/load_local_warn_data.txt
Normal file
50
pymysql/tests/data/load_local_warn_data.txt
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
|
5,6,
|
||||||
|
7,8,
|
||||||
|
1,2,
|
||||||
|
3,4,
|
||||||
68
pymysql/tests/test_load_local.py
Normal file
68
pymysql/tests/test_load_local.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from pymysql.err import OperationalError
|
||||||
|
from pymysql.tests import base
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
__all__ = ["TestLoadLocal"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadLocal(base.PyMySQLTestCase):
|
||||||
|
def test_no_file(self):
|
||||||
|
"""Test load local infile when the file does not exist"""
|
||||||
|
conn = self.connections[2]
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||||
|
try:
|
||||||
|
self.assertRaises(
|
||||||
|
OperationalError,
|
||||||
|
c.execute,
|
||||||
|
("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
|
||||||
|
"test_load_local fields terminated by ','")
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
c.execute("DROP TABLE test_load_local")
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
def test_load_file(self):
|
||||||
|
"""Test load local infile with a valid file"""
|
||||||
|
conn = self.connections[2]
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||||
|
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||||
|
'data',
|
||||||
|
'load_local_data.txt')
|
||||||
|
try:
|
||||||
|
c.execute(
|
||||||
|
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
|
||||||
|
"test_load_local FIELDS TERMINATED BY ','").format(filename)
|
||||||
|
)
|
||||||
|
c.execute("SELECT COUNT(*) FROM test_load_local")
|
||||||
|
self.assertEqual(22749, c.fetchone()[0])
|
||||||
|
finally:
|
||||||
|
c.execute("DROP TABLE test_load_local")
|
||||||
|
|
||||||
|
def test_load_warnings(self):
|
||||||
|
"""Test load local infile produces the appropriate warnings"""
|
||||||
|
import warnings
|
||||||
|
conn = self.connections[2]
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||||
|
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||||
|
'data',
|
||||||
|
'load_local_warn_data.txt')
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
c.execute(
|
||||||
|
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
|
||||||
|
"test_load_local FIELDS TERMINATED BY ','").format(filename)
|
||||||
|
)
|
||||||
|
self.assertEqual(True, "Incorrect integer value" in str(w[-1].message))
|
||||||
|
except Warning as w:
|
||||||
|
self.assertLess(0, str(w).find("Incorrect integer value"))
|
||||||
|
finally:
|
||||||
|
c.execute("DROP TABLE test_load_local")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import unittest
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user