diff --git a/.travis.databases.json b/.travis.databases.json index 852209e..b700531 100644 --- a/.travis.databases.json +++ b/.travis.databases.json @@ -1,4 +1,5 @@ [ {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true}, - {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" } + {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }, + {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true} ] diff --git a/pymysql/connections.py b/pymysql/connections.py index b9449c9..2f16e9d 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -2,7 +2,6 @@ # http://dev.mysql.com/doc/internals/en/client-server-protocol.html # Error codes: # http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html - from __future__ import print_function from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON DEBUG = False @@ -12,6 +11,7 @@ from functools import partial import hashlib import io import os +import re import socket import struct import sys @@ -334,6 +334,9 @@ class MysqlPacket(object): field_count = ord(self._data[0:1]) return 1 <= field_count <= 250 + def is_load_local_packet(self): + return self._data[0:1] == b'\xfb' + def is_error_packet(self): return self._data[0:1] == b'\xff' @@ -454,6 +457,27 @@ class EOFPacketWrapper(object): return getattr(self.packet, key) +class LoadLocalPacketWrapper(object): + """ + Load Local Packet Wrapper. It uses an existing packet object, and wraps + around it, exposing useful variables while still providing access + to the original packet objects variables and methods. + """ + + def __init__(self, from_packet): + if not from_packet.is_load_local_packet(): + raise ValueError( + "Cannot create '{0}' object from invalid packet type".format( + self.__class__)) + + self.packet = from_packet + self.filename = self.packet.get_all_data()[1:] + if DEBUG: print("filename=", self.filename) + + def __getattr__(self, key): + return getattr(self.packet, key) + + class Connection(object): """ Representation of a socket with a mysql server. @@ -472,7 +496,7 @@ class Connection(object): client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, ssl=None, read_default_group=None, compress=None, named_pipe=None, no_delay=False, - autocommit=False, db=None, passwd=None): + autocommit=False, db=None, passwd=None, local_infile=False): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -505,6 +529,7 @@ class Connection(object): named_pipe: Not supported no_delay: Disable Nagle's algorithm on the socket autocommit: Autocommit mode. None means use server default. (default: False) + local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False) db: Alias for database. (for compatibility to MySQLdb) passwd: Alias for password. (for compatibility to MySQLdb) @@ -521,6 +546,9 @@ class Connection(object): if compress or named_pipe: raise NotImplementedError("compress and named_pipe arguments are not supported") + if local_infile: + client_flag |= CLIENT.LOCAL_FILES + if ssl and ('capath' in ssl or 'cipher' in ssl): raise NotImplementedError('ssl options capath and cipher are not supported') @@ -1057,6 +1085,7 @@ class MySQLResult(object): self.rows = None self.has_next = None self.unbuffered_active = False + self.filename = None def __del__(self): if self.unbuffered_active: @@ -1069,6 +1098,8 @@ class MySQLResult(object): # TODO: use classes for different packet types? if first_packet.is_ok_packet(): self._read_ok_packet(first_packet) + elif first_packet.is_load_local_packet(): + self._read_load_local_packet(first_packet) else: self._read_result_packet(first_packet) finally: @@ -1100,6 +1131,21 @@ class MySQLResult(object): self.message = ok_packet.message self.has_next = ok_packet.has_next + def _read_load_local_packet(self, first_packet): + load_packet = LoadLocalPacketWrapper(first_packet) + local_packet = LoadLocalFile(load_packet.filename, self.connection) + self.filename = load_packet.filename + local_packet.send_data() + + ok_packet = self.connection._read_packet() + if not ok_packet.is_ok_packet(): + raise OperationalError(2014, "Commands Out of Sync") + self._read_ok_packet(ok_packet) + + if self.warning_count > 0: + self._print_warnings() + self.filename = None + def _check_packet_is_eof(self, packet): if packet.is_eof_packet(): eof_packet = EOFPacketWrapper(packet) @@ -1108,6 +1154,16 @@ class MySQLResult(object): return True return False + def _print_warnings(self): + from warnings import warn + self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS') + self.read() + if self.rows: + message = "\n" + for db_warning in self.rows: + message += "{0} in file '{1}'\n".format(db_warning[2], self.filename.decode('utf-8')) + warn(message, Warning, 3) + def _read_result_packet(self, first_packet): self.field_count = first_packet.read_length_encoded_integer() self._get_descriptions() @@ -1191,4 +1247,40 @@ class MySQLResult(object): assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF' self.description = tuple(description) + +class LoadLocalFile(object): + def __init__(self, filename, connection): + self.filename = filename + self.connection = connection + + def send_data(self): + """Send data packets from the local file to the server""" + if not self.connection.socket: + raise InterfaceError("(0, '')") + + # sequence id is 2 as we already sent a query packet + seq_id = 2 + try: + with open(self.filename, 'rb') as open_file: + chunk_size = MAX_PACKET_LEN + prelude = b"" + packet = b"" + packet_size = 0 + + while True: + chunk = open_file.read(chunk_size) + if not chunk: + break + packet = struct.pack('