Merge pull request #279 from wraziens/load_local

Implement load data local infile command. Resolves #62
This commit is contained in:
INADA Naoki
2015-01-14 11:03:59 +09:00
6 changed files with 22964 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
[ [
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true}, {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true},
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" } {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" },
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true}
] ]

View File

@@ -2,7 +2,6 @@
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html # http://dev.mysql.com/doc/internals/en/client-server-protocol.html
# Error codes: # Error codes:
# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html # http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
from __future__ import print_function from __future__ import print_function
from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
DEBUG = False DEBUG = False
@@ -12,6 +11,7 @@ from functools import partial
import hashlib import hashlib
import io import io
import os import os
import re
import socket import socket
import struct import struct
import sys import sys
@@ -334,6 +334,9 @@ class MysqlPacket(object):
field_count = ord(self._data[0:1]) field_count = ord(self._data[0:1])
return 1 <= field_count <= 250 return 1 <= field_count <= 250
def is_load_local_packet(self):
return self._data[0:1] == b'\xfb'
def is_error_packet(self): def is_error_packet(self):
return self._data[0:1] == b'\xff' return self._data[0:1] == b'\xff'
@@ -454,6 +457,27 @@ class EOFPacketWrapper(object):
return getattr(self.packet, key) return getattr(self.packet, key)
class LoadLocalPacketWrapper(object):
"""
Load Local Packet Wrapper. It uses an existing packet object, and wraps
around it, exposing useful variables while still providing access
to the original packet objects variables and methods.
"""
def __init__(self, from_packet):
if not from_packet.is_load_local_packet():
raise ValueError(
"Cannot create '{0}' object from invalid packet type".format(
self.__class__))
self.packet = from_packet
self.filename = self.packet.get_all_data()[1:]
if DEBUG: print("filename=", self.filename)
def __getattr__(self, key):
return getattr(self.packet, key)
class Connection(object): class Connection(object):
""" """
Representation of a socket with a mysql server. Representation of a socket with a mysql server.
@@ -472,7 +496,7 @@ class Connection(object):
client_flag=0, cursorclass=Cursor, init_command=None, client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None, connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None, no_delay=False, compress=None, named_pipe=None, no_delay=False,
autocommit=False, db=None, passwd=None): autocommit=False, db=None, passwd=None, local_infile=False):
""" """
Establish a connection to the MySQL database. Accepts several Establish a connection to the MySQL database. Accepts several
arguments: arguments:
@@ -505,6 +529,7 @@ class Connection(object):
named_pipe: Not supported named_pipe: Not supported
no_delay: Disable Nagle's algorithm on the socket no_delay: Disable Nagle's algorithm on the socket
autocommit: Autocommit mode. None means use server default. (default: False) autocommit: Autocommit mode. None means use server default. (default: False)
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
db: Alias for database. (for compatibility to MySQLdb) db: Alias for database. (for compatibility to MySQLdb)
passwd: Alias for password. (for compatibility to MySQLdb) passwd: Alias for password. (for compatibility to MySQLdb)
@@ -521,6 +546,9 @@ class Connection(object):
if compress or named_pipe: if compress or named_pipe:
raise NotImplementedError("compress and named_pipe arguments are not supported") raise NotImplementedError("compress and named_pipe arguments are not supported")
if local_infile:
client_flag |= CLIENT.LOCAL_FILES
if ssl and ('capath' in ssl or 'cipher' in ssl): if ssl and ('capath' in ssl or 'cipher' in ssl):
raise NotImplementedError('ssl options capath and cipher are not supported') raise NotImplementedError('ssl options capath and cipher are not supported')
@@ -1057,6 +1085,7 @@ class MySQLResult(object):
self.rows = None self.rows = None
self.has_next = None self.has_next = None
self.unbuffered_active = False self.unbuffered_active = False
self.filename = None
def __del__(self): def __del__(self):
if self.unbuffered_active: if self.unbuffered_active:
@@ -1069,6 +1098,8 @@ class MySQLResult(object):
# TODO: use classes for different packet types? # TODO: use classes for different packet types?
if first_packet.is_ok_packet(): if first_packet.is_ok_packet():
self._read_ok_packet(first_packet) self._read_ok_packet(first_packet)
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
else: else:
self._read_result_packet(first_packet) self._read_result_packet(first_packet)
finally: finally:
@@ -1100,6 +1131,21 @@ class MySQLResult(object):
self.message = ok_packet.message self.message = ok_packet.message
self.has_next = ok_packet.has_next self.has_next = ok_packet.has_next
def _read_load_local_packet(self, first_packet):
load_packet = LoadLocalPacketWrapper(first_packet)
local_packet = LoadLocalFile(load_packet.filename, self.connection)
self.filename = load_packet.filename
local_packet.send_data()
ok_packet = self.connection._read_packet()
if not ok_packet.is_ok_packet():
raise OperationalError(2014, "Commands Out of Sync")
self._read_ok_packet(ok_packet)
if self.warning_count > 0:
self._print_warnings()
self.filename = None
def _check_packet_is_eof(self, packet): def _check_packet_is_eof(self, packet):
if packet.is_eof_packet(): if packet.is_eof_packet():
eof_packet = EOFPacketWrapper(packet) eof_packet = EOFPacketWrapper(packet)
@@ -1108,6 +1154,16 @@ class MySQLResult(object):
return True return True
return False return False
def _print_warnings(self):
from warnings import warn
self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')
self.read()
if self.rows:
message = "\n"
for db_warning in self.rows:
message += "{0} in file '{1}'\n".format(db_warning[2], self.filename.decode('utf-8'))
warn(message, Warning, 3)
def _read_result_packet(self, first_packet): def _read_result_packet(self, first_packet):
self.field_count = first_packet.read_length_encoded_integer() self.field_count = first_packet.read_length_encoded_integer()
self._get_descriptions() self._get_descriptions()
@@ -1191,4 +1247,40 @@ class MySQLResult(object):
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF' assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
self.description = tuple(description) self.description = tuple(description)
class LoadLocalFile(object):
def __init__(self, filename, connection):
self.filename = filename
self.connection = connection
def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection.socket:
raise InterfaceError("(0, '')")
# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'rb') as open_file:
chunk_size = MAX_PACKET_LEN
prelude = b""
packet = b""
packet_size = 0
while True:
chunk = open_file.read(chunk_size)
if not chunk:
break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id += 1
except IOError:
raise OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally:
# send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
# g:khuno_ignore='E226,E301,E701' # g:khuno_ignore='E226,E301,E701'

View File

@@ -4,6 +4,7 @@ from pymysql.tests.test_nextset import *
from pymysql.tests.test_DictCursor import * from pymysql.tests.test_DictCursor import *
from pymysql.tests.test_connection import TestConnection from pymysql.tests.test_connection import TestConnection
from pymysql.tests.test_SSCursor import * from pymysql.tests.test_SSCursor import *
from pymysql.tests.test_load_local import *
from pymysql.tests.thirdparty import * from pymysql.tests.thirdparty import *

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,
5,6,
7,8,
1,2,
3,4,

View File

@@ -0,0 +1,68 @@
from pymysql.err import OperationalError
from pymysql.tests import base
import os
__all__ = ["TestLoadLocal"]
class TestLoadLocal(base.PyMySQLTestCase):
def test_no_file(self):
"""Test load local infile when the file does not exist"""
conn = self.connections[2]
c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
try:
self.assertRaises(
OperationalError,
c.execute,
("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
"test_load_local fields terminated by ','")
)
finally:
c.execute("DROP TABLE test_load_local")
c.close()
def test_load_file(self):
"""Test load local infile with a valid file"""
conn = self.connections[2]
c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'data',
'load_local_data.txt')
try:
c.execute(
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
"test_load_local FIELDS TERMINATED BY ','").format(filename)
)
c.execute("SELECT COUNT(*) FROM test_load_local")
self.assertEqual(22749, c.fetchone()[0])
finally:
c.execute("DROP TABLE test_load_local")
def test_load_warnings(self):
"""Test load local infile produces the appropriate warnings"""
import warnings
conn = self.connections[2]
c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'data',
'load_local_warn_data.txt')
try:
with warnings.catch_warnings(record=True) as w:
c.execute(
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
"test_load_local FIELDS TERMINATED BY ','").format(filename)
)
self.assertEqual(True, "Incorrect integer value" in str(w[-1].message))
except Warning as w:
self.assertLess(0, str(w).find("Incorrect integer value"))
finally:
c.execute("DROP TABLE test_load_local")
if __name__ == "__main__":
import unittest
unittest.main()