diff --git a/pymysql/connections.py b/pymysql/connections.py index ca8ae2f..dbc8a6b 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -19,7 +19,7 @@ import warnings from .charset import MBLENGTH, charset_by_name, charset_by_id from .cursors import Cursor from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS -from .util import byte2int, int2byte +from .util import byte2int, int2byte, lenenc_int from .converters import ( escape_item, encoders, decoders, escape_string, through) from . import err @@ -139,7 +139,7 @@ def dump_packet(data): def _scramble(password, message): if not password: - return b'\0' + return b'' if DEBUG: print('password=' + str(password)) stage1 = sha_new(password).digest() stage2 = sha_new(stage1).digest() @@ -152,7 +152,7 @@ def _scramble(password, message): def _my_crypt(message1, message2): length = len(message1) - result = struct.pack('B', length) + result = '' for i in range_type(length): x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0]) @@ -985,6 +985,7 @@ class Connection(object): seq_id += 1 def _request_authentication(self): + # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse self.client_flag |= CLIENT.CAPABILITIES if self.server_version.startswith('5'): self.client_flag |= CLIENT.MULTI_RESULTS @@ -1000,7 +1001,7 @@ class Connection(object): next_packet = 1 - if self.ssl: + if self.ssl and self.server_capabilities & CLIENT.SSL: data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init next_packet += 1 @@ -1015,14 +1016,30 @@ class Connection(object): ca_certs=self.ca) self._rfile = _makefile(self.socket, 'rb') - data = data_init + self.user + b'\0' + \ - _scramble(self.password.encode('latin1'), self.salt) + data = data_init + self.user + b'\0' - if self.db: + authresp = '' + if self.plugin_name == 'mysql_native_password': + authresp = _scramble(self.password.encode('latin1'), self.salt) + + if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: + data += lenenc_int(len(authresp)) + data += authresp + elif self.server_capabilities & CLIENT.SECURE_CONNECTION: + length = len(authresp) + data += struct.pack('B', length) + data += authresp + else: + data += authresp + int2byte(0) + + if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: if isinstance(self.db, text_type): self.db = self.db.encode(self.encoding) data += self.db + int2byte(0) + if self.server_capabilities & CLIENT.PLUGIN_AUTH: + data += self.plugin_name.encode('latin1') + int2byte(0) + data = pack_int24(len(data)) + int2byte(next_packet) + data next_packet += 2 @@ -1095,11 +1112,31 @@ class Connection(object): if len(data) >= i + salt_len: # salt_len includes auth_plugin_data_part_1 and filler self.salt += data[i:i+salt_len] - # TODO: AUTH PLUGIN NAME may appeare here. + i += salt_len + + i+=1 + # AUTH PLUGIN NAME may appear here. + if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: + # Due to Bug#59453 the auth-plugin-name is missing the terminating + # NUL-char in versions prior to 5.5.10 and 5.6.2. + # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake + # didn't use version checks as mariadb is corrected and reports + # earlier than those two. + server_end = data.find(int2byte(0), i) + if server_end < 0: + # not found \0 and last field so take it all + self.plugin_name = data[i:].decode('latin1') + else: + self.plugin_name = data[i:server_end].decode('latin1') + else: + self.plugin_name = '' def get_server_info(self): return self.server_version + def get_plugin_name(self): + return self.plugin_name + Warning = err.Warning Error = err.Error InterfaceError = err.InterfaceError diff --git a/pymysql/constants/CLIENT.py b/pymysql/constants/CLIENT.py index 1396cff..7f8c590 100644 --- a/pymysql/constants/CLIENT.py +++ b/pymysql/constants/CLIENT.py @@ -1,3 +1,4 @@ +# https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags LONG_PASSWORD = 1 FOUND_ROWS = 1 << 1 LONG_FLAG = 1 << 2 @@ -15,5 +16,14 @@ TRANSACTIONS = 1 << 13 SECURE_CONNECTION = 1 << 15 MULTI_STATEMENTS = 1 << 16 MULTI_RESULTS = 1 << 17 +PS_MULTI_RESULTS = 1 << 18 +PLUGIN_AUTH = 1 << 19 +PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21 CAPABILITIES = (LONG_PASSWORD | LONG_FLAG | TRANSACTIONS | - PROTOCOL_41 | SECURE_CONNECTION) + PROTOCOL_41 | SECURE_CONNECTION | PLUGIN_AUTH | + PLUGIN_AUTH_LENENC_CLIENT_DATA) +# Not done yet +CONNECT_ATTRS = 1 << 20 +HANDLE_EXPIRED_PASSWORDS = 1 << 22 +SESSION_TRACK = 1 << 23 +CLIENT_DEPRECATE_EOF = 1 << 24 diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index c1e1934..f9bf82d 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -2,6 +2,8 @@ import datetime import decimal import pymysql import time +import os +import copy from pymysql.tests import base @@ -74,6 +76,37 @@ class TestConnection(base.PyMySQLTestCase): self.assertEqual(('foobar',), c.fetchone()) conn.close() + def test_plugin(self): + con = self.connections[0] + self.assertEqual('mysql_native_password',con.get_plugin_name()) + + # attempt a unix socket test which is included in some versions + # and doesn't require a client side handler + user = os.environ.get('USER') + if not user or self.databases[0]['host'] != 'localhost': + return + cur = con.cursor() + cur.execute("SHOW PLUGINS") + found = False + for r in cur: + if r == (u'unix_socket', u'ACTIVE', u'AUTHENTICATION', u'auth_socket.so', u'GPL'): + found = True + break + # needs plugin. lets install it. + if not found: + cur.execute("install soname 'auth_socket'") + + current_db = self.databases[0]['db'] + cur.execute("GRANT ALL ON %s TO %s@localhost IDENTIFIED VIA unix_socket" % ( current_db, user)) + db = copy.copy(self.databases[0]) + del db['user'] + c = pymysql.connect(user=user, **db) + + if not found: + cur.execute("uninstall soname 'auth_socket'") + cur.execute("DROP USER %s@localhost" % user) + + # A custom type and function to escape it class Foo(object): diff --git a/pymysql/util.py b/pymysql/util.py index cc622e5..9342d62 100644 --- a/pymysql/util.py +++ b/pymysql/util.py @@ -17,3 +17,18 @@ def join_bytes(bs): for b in bs[1:]: rv += b return rv + +# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger +def lenenc_int(i): + if (i < 0): + raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i) + elif (i < 0xfb): + return int2byte(i) + elif (i < (1 << 16)): + return b'\xfc' + struct.pack('