diff --git a/.travis.yml b/.travis.yml index ec3bd3f..71bc179 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,9 @@ sudo: false language: python -python: "3.4" +python: "3.5" cache: - pip: true + directories: + - $HOME/.cache/pip env: matrix: @@ -10,6 +11,7 @@ env: - TOX_ENV=py27 - TOX_ENV=py33 - TOX_ENV=py34 + - TOX_ENV=py35 - TOX_ENV=pypy - TOX_ENV=pypy3 @@ -36,7 +38,7 @@ matrix: sudo: required - env: - TOX_ENV=py34 - - DB=5.6.26 + - DB=5.6.28 addons: apt: packages: diff --git a/pymysql/connections.py b/pymysql/connections.py index 737fb51..3f6702e 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -525,6 +525,7 @@ class Connection(object): """ socket = None + _auth_plugin_name = '' def __init__(self, host=None, user=None, password="", database=None, port=3306, unix_socket=None, @@ -535,7 +536,7 @@ class Connection(object): compress=None, named_pipe=None, no_delay=None, autocommit=False, db=None, passwd=None, local_infile=False, max_allowed_packet=16*1024*1024, defer_connect=False, - plugin_map={}): + auth_plugin_map={}): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -571,12 +572,11 @@ class Connection(object): max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB) defer_connect: Don't explicitly connect on contruction - wait for connect call. (default: False) - plugin_map: Map of plugin names to a class that processes that plugin. The class - will take the Connection object as the argument to the constructor. The class - needs an authenticate method taking an authentication packet as an argument. - For the dialog plugin, a prompt(echo, prompt) method can be used (if no - authenticate method) for returning a string from the user. - + auth_plugin_map: A dict of plugin names to a class that processes that plugin. + The class will take the Connection object as the argument to the constructor. + The class needs an authenticate method taking an authentication packet as + an argument. For the dialog plugin, a prompt(echo, prompt) method can be used + (if no authenticate method) for returning a string from the user. (experimental) db: Alias for database. (for compatibility to MySQLdb) passwd: Alias for password. (for compatibility to MySQLdb) """ @@ -672,7 +672,7 @@ class Connection(object): self.sql_mode = sql_mode self.init_command = init_command self.max_allowed_packet = max_allowed_packet - self.plugin_map = plugin_map + self._auth_plugin_map = auth_plugin_map if defer_connect: self.socket = None else: @@ -726,7 +726,6 @@ class Connection(object): def autocommit(self, value): self.autocommit_mode = bool(value) current = self.get_autocommit() - self.next_packet = 1 if value != current: self._send_autocommit_mode() @@ -816,7 +815,6 @@ class Connection(object): "You may not close previous cursor.") # if DEBUG: # print("DEBUG: sending query:", sql) - self.next_packet = 1 if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON): if PY2: sql = sql.encode(self.encoding) @@ -824,7 +822,6 @@ class Connection(object): sql = sql.encode(self.encoding, 'surrogateescape') self._execute_command(COMMAND.COM_QUERY, sql) self._affected_rows = self._read_query_result(unbuffered=unbuffered) - self.next_packet = 1 return self._affected_rows def next_result(self, unbuffered=False): @@ -892,7 +889,7 @@ class Connection(object): sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) self.socket = sock self._rfile = _makefile(sock, 'rb') - self.next_packet = 0 + self._next_seq_id = 0 self._get_server_information() self._request_authentication() @@ -933,16 +930,16 @@ class Connection(object): # So just reraise it. raise - def write_packet(self, data): + def write_packet(self, payload): """Writes an entire "mysql packet" in its entirety to the network - addings its length and sequence number. Intended for use by plugins - only. + addings its length and sequence number. """ - data = pack_int24(len(data)) + int2byte(self.next_packet) + data + # Internal note: when you build packet manualy and calls _write_bytes() + # directly, you should set self._next_seq_id properly. + data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload if DEBUG: dump_packet(data) - self._write_bytes(data) - self.next_packet = (self.next_packet + 1) % 256 + self._next_seq_id = (self._next_seq_id + 1) % 256 def _read_packet(self, packet_type=MysqlPacket): """Read an entire "mysql packet" in its entirety from the network @@ -952,8 +949,14 @@ class Connection(object): while True: packet_header = self._read_bytes(4) if DEBUG: dump_packet(packet_header) + btrl, btrh, packet_number = struct.unpack('=5.0) + data += struct.pack('B', len(authresp)) + authresp + else: # pragma: no cover - not testing against servers without secure auth (>=5.0) data += authresp + b'\0' if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: @@ -1098,15 +1091,16 @@ class Connection(object): data += self.db + b'\0' if self.server_capabilities & CLIENT.PLUGIN_AUTH: - data += self.plugin_name.encode('latin1') + b'\0' + name = self._auth_plugin_name + if isinstance(name, text_type): + name = name.encode('ascii') + data += name + b'\0' self.write_packet(data) - auth_packet = self._read_packet() # if authentication method isn't accepted the first byte # will have the octet 254 - if auth_packet.is_auth_switch_request(): # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest auth_packet.read_uint8() # 0xfe packet identifier @@ -1119,8 +1113,13 @@ class Connection(object): self.write_packet(data) auth_packet = self._read_packet() + #TODO: ok packet or error packet? + + def _process_auth(self, plugin_name, auth_packet): - plugin_class = self.plugin_map.get(plugin_name) + plugin_class = self._auth_plugin_map.get(plugin_name) + if not plugin_class: + plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) if plugin_class: try: handler = plugin_class(self) @@ -1246,18 +1245,13 @@ class Connection(object): server_end = data.find(b'\0', i) if server_end < 0: # pragma: no cover - very specific upstream bug # not found \0 and last field so take it all - self.plugin_name = data[i:].decode('latin1') + self._auth_plugin_name = data[i:].decode('latin1') else: - self.plugin_name = data[i:server_end].decode('latin1') - else: # pragma: no cover - not testing against any plugin uncapable servers - self.plugin_name = '' + self._auth_plugin_name = data[i:server_end].decode('latin1') def get_server_info(self): return self.server_version - def get_plugin_name(self): - return self.plugin_name - Warning = err.Warning Error = err.Error InterfaceError = err.InterfaceError @@ -1331,7 +1325,11 @@ class MySQLResult(object): def _read_load_local_packet(self, first_packet): load_packet = LoadLocalPacketWrapper(first_packet) sender = LoadLocalFile(load_packet.filename, self.connection) - sender.send_data() + try: + sender.send_data() + except: + self.connection._read_packet() # skip ok packet + raise ok_packet = self.connection._read_packet() if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error @@ -1448,27 +1446,20 @@ class LoadLocalFile(object): """Send data packets from the local file to the server""" if not self.connection.socket: raise err.InterfaceError("(0, '')") + conn = self.connection - # sequence id is 2 as we already sent a query packet - seq_id = 2 try: with open(self.filename, 'rb') as open_file: - chunk_size = self.connection.max_allowed_packet + chunk_size = conn.max_allowed_packet packet = b"" while True: chunk = open_file.read(chunk_size) if not chunk: break - packet = struct.pack('