Merge pull request #279 from wraziens/load_local
Implement load data local infile command. Resolves #62
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
[
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true},
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" },
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true}
|
||||
]
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||||
# Error codes:
|
||||
# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
|
||||
|
||||
from __future__ import print_function
|
||||
from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
|
||||
DEBUG = False
|
||||
@@ -12,6 +11,7 @@ from functools import partial
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
@@ -334,6 +334,9 @@ class MysqlPacket(object):
|
||||
field_count = ord(self._data[0:1])
|
||||
return 1 <= field_count <= 250
|
||||
|
||||
def is_load_local_packet(self):
|
||||
return self._data[0:1] == b'\xfb'
|
||||
|
||||
def is_error_packet(self):
|
||||
return self._data[0:1] == b'\xff'
|
||||
|
||||
@@ -454,6 +457,27 @@ class EOFPacketWrapper(object):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
class LoadLocalPacketWrapper(object):
|
||||
"""
|
||||
Load Local Packet Wrapper. It uses an existing packet object, and wraps
|
||||
around it, exposing useful variables while still providing access
|
||||
to the original packet objects variables and methods.
|
||||
"""
|
||||
|
||||
def __init__(self, from_packet):
|
||||
if not from_packet.is_load_local_packet():
|
||||
raise ValueError(
|
||||
"Cannot create '{0}' object from invalid packet type".format(
|
||||
self.__class__))
|
||||
|
||||
self.packet = from_packet
|
||||
self.filename = self.packet.get_all_data()[1:]
|
||||
if DEBUG: print("filename=", self.filename)
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""
|
||||
Representation of a socket with a mysql server.
|
||||
@@ -472,7 +496,7 @@ class Connection(object):
|
||||
client_flag=0, cursorclass=Cursor, init_command=None,
|
||||
connect_timeout=None, ssl=None, read_default_group=None,
|
||||
compress=None, named_pipe=None, no_delay=False,
|
||||
autocommit=False, db=None, passwd=None):
|
||||
autocommit=False, db=None, passwd=None, local_infile=False):
|
||||
"""
|
||||
Establish a connection to the MySQL database. Accepts several
|
||||
arguments:
|
||||
@@ -505,6 +529,7 @@ class Connection(object):
|
||||
named_pipe: Not supported
|
||||
no_delay: Disable Nagle's algorithm on the socket
|
||||
autocommit: Autocommit mode. None means use server default. (default: False)
|
||||
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
|
||||
|
||||
db: Alias for database. (for compatibility to MySQLdb)
|
||||
passwd: Alias for password. (for compatibility to MySQLdb)
|
||||
@@ -521,6 +546,9 @@ class Connection(object):
|
||||
if compress or named_pipe:
|
||||
raise NotImplementedError("compress and named_pipe arguments are not supported")
|
||||
|
||||
if local_infile:
|
||||
client_flag |= CLIENT.LOCAL_FILES
|
||||
|
||||
if ssl and ('capath' in ssl or 'cipher' in ssl):
|
||||
raise NotImplementedError('ssl options capath and cipher are not supported')
|
||||
|
||||
@@ -1057,6 +1085,7 @@ class MySQLResult(object):
|
||||
self.rows = None
|
||||
self.has_next = None
|
||||
self.unbuffered_active = False
|
||||
self.filename = None
|
||||
|
||||
def __del__(self):
|
||||
if self.unbuffered_active:
|
||||
@@ -1069,6 +1098,8 @@ class MySQLResult(object):
|
||||
# TODO: use classes for different packet types?
|
||||
if first_packet.is_ok_packet():
|
||||
self._read_ok_packet(first_packet)
|
||||
elif first_packet.is_load_local_packet():
|
||||
self._read_load_local_packet(first_packet)
|
||||
else:
|
||||
self._read_result_packet(first_packet)
|
||||
finally:
|
||||
@@ -1100,6 +1131,21 @@ class MySQLResult(object):
|
||||
self.message = ok_packet.message
|
||||
self.has_next = ok_packet.has_next
|
||||
|
||||
def _read_load_local_packet(self, first_packet):
|
||||
load_packet = LoadLocalPacketWrapper(first_packet)
|
||||
local_packet = LoadLocalFile(load_packet.filename, self.connection)
|
||||
self.filename = load_packet.filename
|
||||
local_packet.send_data()
|
||||
|
||||
ok_packet = self.connection._read_packet()
|
||||
if not ok_packet.is_ok_packet():
|
||||
raise OperationalError(2014, "Commands Out of Sync")
|
||||
self._read_ok_packet(ok_packet)
|
||||
|
||||
if self.warning_count > 0:
|
||||
self._print_warnings()
|
||||
self.filename = None
|
||||
|
||||
def _check_packet_is_eof(self, packet):
|
||||
if packet.is_eof_packet():
|
||||
eof_packet = EOFPacketWrapper(packet)
|
||||
@@ -1108,6 +1154,16 @@ class MySQLResult(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _print_warnings(self):
|
||||
from warnings import warn
|
||||
self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')
|
||||
self.read()
|
||||
if self.rows:
|
||||
message = "\n"
|
||||
for db_warning in self.rows:
|
||||
message += "{0} in file '{1}'\n".format(db_warning[2], self.filename.decode('utf-8'))
|
||||
warn(message, Warning, 3)
|
||||
|
||||
def _read_result_packet(self, first_packet):
|
||||
self.field_count = first_packet.read_length_encoded_integer()
|
||||
self._get_descriptions()
|
||||
@@ -1191,4 +1247,40 @@ class MySQLResult(object):
|
||||
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
|
||||
self.description = tuple(description)
|
||||
|
||||
|
||||
class LoadLocalFile(object):
|
||||
def __init__(self, filename, connection):
|
||||
self.filename = filename
|
||||
self.connection = connection
|
||||
|
||||
def send_data(self):
|
||||
"""Send data packets from the local file to the server"""
|
||||
if not self.connection.socket:
|
||||
raise InterfaceError("(0, '')")
|
||||
|
||||
# sequence id is 2 as we already sent a query packet
|
||||
seq_id = 2
|
||||
try:
|
||||
with open(self.filename, 'rb') as open_file:
|
||||
chunk_size = MAX_PACKET_LEN
|
||||
prelude = b""
|
||||
packet = b""
|
||||
packet_size = 0
|
||||
|
||||
while True:
|
||||
chunk = open_file.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
|
||||
format_str = '!{0}s'.format(len(chunk))
|
||||
packet += struct.pack(format_str, chunk)
|
||||
self.connection._write_bytes(packet)
|
||||
seq_id += 1
|
||||
except IOError:
|
||||
raise OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
||||
finally:
|
||||
# send the empty packet to signify we are done sending data
|
||||
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
||||
self.connection._write_bytes(packet)
|
||||
|
||||
# g:khuno_ignore='E226,E301,E701'
|
||||
|
||||
@@ -4,6 +4,7 @@ from pymysql.tests.test_nextset import *
|
||||
from pymysql.tests.test_DictCursor import *
|
||||
from pymysql.tests.test_connection import TestConnection
|
||||
from pymysql.tests.test_SSCursor import *
|
||||
from pymysql.tests.test_load_local import *
|
||||
|
||||
from pymysql.tests.thirdparty import *
|
||||
|
||||
|
||||
22749
pymysql/tests/data/load_local_data.txt
Normal file
22749
pymysql/tests/data/load_local_data.txt
Normal file
File diff suppressed because it is too large
Load Diff
50
pymysql/tests/data/load_local_warn_data.txt
Normal file
50
pymysql/tests/data/load_local_warn_data.txt
Normal file
@@ -0,0 +1,50 @@
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
68
pymysql/tests/test_load_local.py
Normal file
68
pymysql/tests/test_load_local.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from pymysql.err import OperationalError
|
||||
from pymysql.tests import base
|
||||
|
||||
import os
|
||||
|
||||
__all__ = ["TestLoadLocal"]
|
||||
|
||||
|
||||
class TestLoadLocal(base.PyMySQLTestCase):
|
||||
def test_no_file(self):
|
||||
"""Test load local infile when the file does not exist"""
|
||||
conn = self.connections[2]
|
||||
c = conn.cursor()
|
||||
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||
try:
|
||||
self.assertRaises(
|
||||
OperationalError,
|
||||
c.execute,
|
||||
("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
|
||||
"test_load_local fields terminated by ','")
|
||||
)
|
||||
finally:
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
c.close()
|
||||
|
||||
def test_load_file(self):
|
||||
"""Test load local infile with a valid file"""
|
||||
conn = self.connections[2]
|
||||
c = conn.cursor()
|
||||
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'data',
|
||||
'load_local_data.txt')
|
||||
try:
|
||||
c.execute(
|
||||
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
|
||||
"test_load_local FIELDS TERMINATED BY ','").format(filename)
|
||||
)
|
||||
c.execute("SELECT COUNT(*) FROM test_load_local")
|
||||
self.assertEqual(22749, c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
|
||||
def test_load_warnings(self):
|
||||
"""Test load local infile produces the appropriate warnings"""
|
||||
import warnings
|
||||
conn = self.connections[2]
|
||||
c = conn.cursor()
|
||||
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'data',
|
||||
'load_local_warn_data.txt')
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
c.execute(
|
||||
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
|
||||
"test_load_local FIELDS TERMINATED BY ','").format(filename)
|
||||
)
|
||||
self.assertEqual(True, "Incorrect integer value" in str(w[-1].message))
|
||||
except Warning as w:
|
||||
self.assertLess(0, str(w).find("Incorrect integer value"))
|
||||
finally:
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user