Merge pull request #279 from wraziens/load_local

Implement load data local infile command. Resolves #62
This commit is contained in:
INADA Naoki
2015-01-14 11:03:59 +09:00
6 changed files with 22964 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
[
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true},
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" },
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true}
]

View File

@@ -2,7 +2,6 @@
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
# Error codes:
# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
from __future__ import print_function
from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
DEBUG = False
@@ -12,6 +11,7 @@ from functools import partial
import hashlib
import io
import os
import re
import socket
import struct
import sys
@@ -334,6 +334,9 @@ class MysqlPacket(object):
field_count = ord(self._data[0:1])
return 1 <= field_count <= 250
def is_load_local_packet(self):
    """Tell whether this packet is a LOAD DATA LOCAL request.

    Returns True when the first byte of the packet data is 0xfb.
    """
    leading_byte = self._data[0:1]
    return leading_byte == b'\xfb'
def is_error_packet(self):
    """Tell whether this packet is an error packet.

    Returns True when the first byte of the packet data is 0xff.
    """
    leading_byte = self._data[0:1]
    return leading_byte == b'\xff'
@@ -454,6 +457,27 @@ class EOFPacketWrapper(object):
return getattr(self.packet, key)
class LoadLocalPacketWrapper(object):
    """
    Wrapper around a LOAD DATA LOCAL request packet.

    Keeps a reference to the wrapped packet as ``packet``, exposes the
    requested file name as ``filename`` and forwards every other attribute
    lookup to the wrapped packet.
    """

    def __init__(self, from_packet):
        # Only a packet that identifies itself as load-local may be wrapped.
        accepted = from_packet.is_load_local_packet()
        if not accepted:
            template = "Cannot create '{0}' object from invalid packet type"
            raise ValueError(template.format(self.__class__))

        self.packet = from_packet
        # Everything after the first byte of the packet data is the file name.
        payload = self.packet.get_all_data()
        self.filename = payload[1:]
        if DEBUG:
            print("filename=", self.filename)

    def __getattr__(self, key):
        # Only reached for names that are not found on the wrapper itself.
        wrapped = self.packet
        return getattr(wrapped, key)
class Connection(object):
"""
Representation of a socket with a mysql server.
@@ -472,7 +496,7 @@ class Connection(object):
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None, no_delay=False,
autocommit=False, db=None, passwd=None):
autocommit=False, db=None, passwd=None, local_infile=False):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
@@ -505,6 +529,7 @@ class Connection(object):
named_pipe: Not supported
no_delay: Disable Nagle's algorithm on the socket
autocommit: Autocommit mode. None means use server default. (default: False)
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
db: Alias for database. (for compatibility to MySQLdb)
passwd: Alias for password. (for compatibility to MySQLdb)
@@ -521,6 +546,9 @@ class Connection(object):
if compress or named_pipe:
raise NotImplementedError("compress and named_pipe arguments are not supported")
if local_infile:
client_flag |= CLIENT.LOCAL_FILES
if ssl and ('capath' in ssl or 'cipher' in ssl):
raise NotImplementedError('ssl options capath and cipher are not supported')
@@ -1057,6 +1085,7 @@ class MySQLResult(object):
self.rows = None
self.has_next = None
self.unbuffered_active = False
self.filename = None
def __del__(self):
if self.unbuffered_active:
@@ -1069,6 +1098,8 @@ class MySQLResult(object):
# TODO: use classes for different packet types?
if first_packet.is_ok_packet():
self._read_ok_packet(first_packet)
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
else:
self._read_result_packet(first_packet)
finally:
@@ -1100,6 +1131,21 @@ class MySQLResult(object):
self.message = ok_packet.message
self.has_next = ok_packet.has_next
def _read_load_local_packet(self, first_packet):
    """Answer a LOAD DATA LOCAL request received from the server.

    Sends the requested local file through ``LoadLocalFile``, then reads
    the server's reply and stores it with ``_read_ok_packet``.  When the
    reply reports warnings they are emitted through ``_print_warnings``.

    Raises OperationalError (2014) when the reply is not an OK packet and
    re-raises any OperationalError raised while sending the file.
    """
    load_packet = LoadLocalPacketWrapper(first_packet)
    sender = LoadLocalFile(load_packet.filename, self.connection)
    # _print_warnings() reads self.filename to build its message.
    self.filename = load_packet.filename
    try:
        try:
            sender.send_data()
        except OperationalError:
            # send_data() still sends the terminating empty packet when the
            # file cannot be read, so the server answers it.  Consume that
            # reply here; otherwise it would be mistaken for the response
            # to the next command issued on this connection.
            self.connection._read_packet()
            raise

        ok_packet = self.connection._read_packet()
        if not ok_packet.is_ok_packet():
            raise OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)
        if self.warning_count > 0:
            self._print_warnings()
    finally:
        # Never leave a stale file name behind, even when an error propagates.
        self.filename = None
def _check_packet_is_eof(self, packet):
if packet.is_eof_packet():
eof_packet = EOFPacketWrapper(packet)
@@ -1108,6 +1154,16 @@ class MySQLResult(object):
return True
return False
def _print_warnings(self):
    """Fetch the server's warnings and re-emit them as one Python warning.

    Runs ``SHOW WARNINGS`` on the connection and reads the reply into this
    result object.  When rows come back, a single ``Warning`` is issued
    whose text has one line per row, each naming the file being loaded.
    """
    import warnings

    self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')
    self.read()
    if not self.rows:
        return

    # Index 2 of each row is used as the text -- presumably the message
    # column of SHOW WARNINGS; confirm against the server's column order.
    lines = [
        "{0} in file '{1}'\n".format(row[2], self.filename.decode('utf-8'))
        for row in self.rows
    ]
    # stacklevel 3 attributes the warning to a frame above the caller.
    warnings.warn("\n" + "".join(lines), Warning, 3)
def _read_result_packet(self, first_packet):
self.field_count = first_packet.read_length_encoded_integer()
self._get_descriptions()
@@ -1191,4 +1247,40 @@ class MySQLResult(object):
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
self.description = tuple(description)
class LoadLocalFile(object):
    """Sends the contents of a local file to the server in packets.

    filename: path of the local file to send.
    connection: object providing ``socket`` and ``_write_bytes``.
    """

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    @staticmethod
    def _packet_header(payload_len, seq_id):
        """Return the 4-byte header: 3-byte little-endian length + sequence id."""
        return struct.pack('<I', payload_len)[:3] + struct.pack('<B', seq_id)

    def send_data(self):
        """Send data packets from the local file to the server.

        The file is read in chunks of at most MAX_PACKET_LEN bytes and each
        chunk is written as one packet.  An empty packet is always written
        last, even when reading fails, to signal the end of the data.

        Raises InterfaceError when the connection has no socket and
        OperationalError (1017) when the file cannot be opened or read.
        """
        if not self.connection.socket:
            raise InterfaceError("(0, '')")

        # sequence id is 2 as we already sent a query packet
        seq_id = 2
        try:
            with open(self.filename, 'rb') as open_file:
                while True:
                    chunk = open_file.read(MAX_PACKET_LEN)
                    if not chunk:
                        break
                    header = self._packet_header(len(chunk), seq_id)
                    self.connection._write_bytes(header + chunk)
                    # The sequence id occupies a single byte, so it has to
                    # wrap around after 255 instead of growing unbounded.
                    seq_id = (seq_id + 1) % 256
        except IOError:
            raise OperationalError(1017, "Can't find file '{0}'".format(self.filename))
        finally:
            # send the empty packet to signify we are done sending data
            self.connection._write_bytes(self._packet_header(0, seq_id))
# g:khuno_ignore='E226,E301,E701'

View File

@@ -4,6 +4,7 @@ from pymysql.tests.test_nextset import *
from pymysql.tests.test_DictCursor import *
from pymysql.tests.test_connection import TestConnection
from pymysql.tests.test_SSCursor import *
from pymysql.tests.test_load_local import *
from pymysql.tests.thirdparty import *

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,

View File

@@ -0,0 +1,68 @@
from pymysql.err import OperationalError
from pymysql.tests import base
import os
__all__ = ["TestLoadLocal"]
class TestLoadLocal(base.PyMySQLTestCase):
    """Tests for LOAD DATA LOCAL INFILE, run on the connection at index 2."""

    @staticmethod
    def _data_file(name):
        """Return the path of *name* in the ``data`` directory next to this file."""
        here = os.path.dirname(os.path.realpath(__file__))
        return os.path.join(here, 'data', name)

    def test_no_file(self):
        """Test load local infile when the file does not exist"""
        conn = self.connections[2]
        c = conn.cursor()
        c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
        try:
            self.assertRaises(
                OperationalError,
                c.execute,
                ("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
                 "test_load_local fields terminated by ','")
            )
        finally:
            c.execute("DROP TABLE test_load_local")
            c.close()

    def test_load_file(self):
        """Test load local infile with a valid file"""
        conn = self.connections[2]
        c = conn.cursor()
        c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
        filename = self._data_file('load_local_data.txt')
        try:
            c.execute(
                ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
                 "test_load_local FIELDS TERMINATED BY ','").format(filename)
            )
            c.execute("SELECT COUNT(*) FROM test_load_local")
            self.assertEqual(22749, c.fetchone()[0])
        finally:
            c.execute("DROP TABLE test_load_local")
            # Close the cursor like test_no_file does.
            c.close()

    def test_load_warnings(self):
        """Test load local infile produces the appropriate warnings"""
        import warnings
        conn = self.connections[2]
        c = conn.cursor()
        c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
        filename = self._data_file('load_local_warn_data.txt')
        try:
            with warnings.catch_warnings(record=True) as w:
                # Without this, an active filter (or an "error" filter) can
                # keep the warning out of the recorded list.
                warnings.simplefilter("always")
                c.execute(
                    ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
                     "test_load_local FIELDS TERMINATED BY ','").format(filename)
                )
            # Fail with a clear message instead of an IndexError on w[-1].
            self.assertTrue(w, "no warning was recorded")
            self.assertIn("Incorrect integer value", str(w[-1].message))
        finally:
            c.execute("DROP TABLE test_load_local")
            c.close()
if __name__ == "__main__":
    # Run this module's tests when the file is executed as a script.
    from unittest import main as run_tests
    run_tests()