Add basic connection plugin support
This extracts a plugin name from the server and exposes this as connection.get_plugin_name(). The client side here adds capabilities PLUGIN_AUTH and PLUGIN_AUTH_LENENC_CLIENT_DATA. Because of this the HandshakeResponse response has been altered to use the server_capabilities a bit mroe strictly. The immediate upshot is that plugins like unix_socket are immediately supported where they previously would of returned and error as without PLUGIN_AUTH being advertised on the client side, the server only accepts mysql_native_password and mysql_old_password. To support other plugins some more work is needed to implement these.
This commit is contained in:
@@ -19,7 +19,7 @@ import warnings
|
||||
from .charset import MBLENGTH, charset_by_name, charset_by_id
|
||||
from .cursors import Cursor
|
||||
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
|
||||
from .util import byte2int, int2byte
|
||||
from .util import byte2int, int2byte, lenenc_int
|
||||
from .converters import (
|
||||
escape_item, encoders, decoders, escape_string, through)
|
||||
from . import err
|
||||
@@ -139,7 +139,7 @@ def dump_packet(data):
|
||||
|
||||
def _scramble(password, message):
|
||||
if not password:
|
||||
return b'\0'
|
||||
return b''
|
||||
if DEBUG: print('password=' + str(password))
|
||||
stage1 = sha_new(password).digest()
|
||||
stage2 = sha_new(stage1).digest()
|
||||
@@ -152,7 +152,7 @@ def _scramble(password, message):
|
||||
|
||||
def _my_crypt(message1, message2):
|
||||
length = len(message1)
|
||||
result = struct.pack('B', length)
|
||||
result = ''
|
||||
for i in range_type(length):
|
||||
x = (struct.unpack('B', message1[i:i+1])[0] ^
|
||||
struct.unpack('B', message2[i:i+1])[0])
|
||||
@@ -985,6 +985,7 @@ class Connection(object):
|
||||
seq_id += 1
|
||||
|
||||
def _request_authentication(self):
|
||||
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
self.client_flag |= CLIENT.CAPABILITIES
|
||||
if self.server_version.startswith('5'):
|
||||
self.client_flag |= CLIENT.MULTI_RESULTS
|
||||
@@ -1000,7 +1001,7 @@ class Connection(object):
|
||||
|
||||
next_packet = 1
|
||||
|
||||
if self.ssl:
|
||||
if self.ssl and self.server_capabilities & CLIENT.SSL:
|
||||
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
|
||||
next_packet += 1
|
||||
|
||||
@@ -1015,14 +1016,30 @@ class Connection(object):
|
||||
ca_certs=self.ca)
|
||||
self._rfile = _makefile(self.socket, 'rb')
|
||||
|
||||
data = data_init + self.user + b'\0' + \
|
||||
_scramble(self.password.encode('latin1'), self.salt)
|
||||
data = data_init + self.user + b'\0'
|
||||
|
||||
if self.db:
|
||||
authresp = ''
|
||||
if self.plugin_name == 'mysql_native_password':
|
||||
authresp = _scramble(self.password.encode('latin1'), self.salt)
|
||||
|
||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
|
||||
data += lenenc_int(len(authresp))
|
||||
data += authresp
|
||||
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
|
||||
length = len(authresp)
|
||||
data += struct.pack('B', length)
|
||||
data += authresp
|
||||
else:
|
||||
data += authresp + int2byte(0)
|
||||
|
||||
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
|
||||
if isinstance(self.db, text_type):
|
||||
self.db = self.db.encode(self.encoding)
|
||||
data += self.db + int2byte(0)
|
||||
|
||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
|
||||
data += self.plugin_name.encode('latin1') + int2byte(0)
|
||||
|
||||
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
||||
next_packet += 2
|
||||
|
||||
@@ -1095,11 +1112,31 @@ class Connection(object):
|
||||
if len(data) >= i + salt_len:
|
||||
# salt_len includes auth_plugin_data_part_1 and filler
|
||||
self.salt += data[i:i+salt_len]
|
||||
# TODO: AUTH PLUGIN NAME may appeare here.
|
||||
i += salt_len
|
||||
|
||||
i+=1
|
||||
# AUTH PLUGIN NAME may appear here.
|
||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
|
||||
# Due to Bug#59453 the auth-plugin-name is missing the terminating
|
||||
# NUL-char in versions prior to 5.5.10 and 5.6.2.
|
||||
# ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
||||
# didn't use version checks as mariadb is corrected and reports
|
||||
# earlier than those two.
|
||||
server_end = data.find(int2byte(0), i)
|
||||
if server_end < 0:
|
||||
# not found \0 and last field so take it all
|
||||
self.plugin_name = data[i:].decode('latin1')
|
||||
else:
|
||||
self.plugin_name = data[i:server_end].decode('latin1')
|
||||
else:
|
||||
self.plugin_name = ''
|
||||
|
||||
def get_server_info(self):
|
||||
return self.server_version
|
||||
|
||||
def get_plugin_name(self):
|
||||
return self.plugin_name
|
||||
|
||||
Warning = err.Warning
|
||||
Error = err.Error
|
||||
InterfaceError = err.InterfaceError
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
|
||||
LONG_PASSWORD = 1
|
||||
FOUND_ROWS = 1 << 1
|
||||
LONG_FLAG = 1 << 2
|
||||
@@ -15,5 +16,14 @@ TRANSACTIONS = 1 << 13
|
||||
SECURE_CONNECTION = 1 << 15
|
||||
MULTI_STATEMENTS = 1 << 16
|
||||
MULTI_RESULTS = 1 << 17
|
||||
PS_MULTI_RESULTS = 1 << 18
|
||||
PLUGIN_AUTH = 1 << 19
|
||||
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
|
||||
CAPABILITIES = (LONG_PASSWORD | LONG_FLAG | TRANSACTIONS |
|
||||
PROTOCOL_41 | SECURE_CONNECTION)
|
||||
PROTOCOL_41 | SECURE_CONNECTION | PLUGIN_AUTH |
|
||||
PLUGIN_AUTH_LENENC_CLIENT_DATA)
|
||||
# Not done yet
|
||||
CONNECT_ATTRS = 1 << 20
|
||||
HANDLE_EXPIRED_PASSWORDS = 1 << 22
|
||||
SESSION_TRACK = 1 << 23
|
||||
CLIENT_DEPRECATE_EOF = 1 << 24
|
||||
|
||||
@@ -2,6 +2,8 @@ import datetime
|
||||
import decimal
|
||||
import pymysql
|
||||
import time
|
||||
import os
|
||||
import copy
|
||||
from pymysql.tests import base
|
||||
|
||||
|
||||
@@ -74,6 +76,37 @@ class TestConnection(base.PyMySQLTestCase):
|
||||
self.assertEqual(('foobar',), c.fetchone())
|
||||
conn.close()
|
||||
|
||||
def test_plugin(self):
|
||||
con = self.connections[0]
|
||||
self.assertEqual('mysql_native_password',con.get_plugin_name())
|
||||
|
||||
# attempt a unix socket test which is included in some versions
|
||||
# and doesn't require a client side handler
|
||||
user = os.environ.get('USER')
|
||||
if not user or self.databases[0]['host'] != 'localhost':
|
||||
return
|
||||
cur = con.cursor()
|
||||
cur.execute("SHOW PLUGINS")
|
||||
found = False
|
||||
for r in cur:
|
||||
if r == (u'unix_socket', u'ACTIVE', u'AUTHENTICATION', u'auth_socket.so', u'GPL'):
|
||||
found = True
|
||||
break
|
||||
# needs plugin. lets install it.
|
||||
if not found:
|
||||
cur.execute("install soname 'auth_socket'")
|
||||
|
||||
current_db = self.databases[0]['db']
|
||||
cur.execute("GRANT ALL ON %s TO %s@localhost IDENTIFIED VIA unix_socket" % ( current_db, user))
|
||||
db = copy.copy(self.databases[0])
|
||||
del db['user']
|
||||
c = pymysql.connect(user=user, **db)
|
||||
|
||||
if not found:
|
||||
cur.execute("uninstall soname 'auth_socket'")
|
||||
cur.execute("DROP USER %s@localhost" % user)
|
||||
|
||||
|
||||
|
||||
# A custom type and function to escape it
|
||||
class Foo(object):
|
||||
|
||||
@@ -17,3 +17,18 @@ def join_bytes(bs):
|
||||
for b in bs[1:]:
|
||||
rv += b
|
||||
return rv
|
||||
|
||||
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
|
||||
def lenenc_int(i):
|
||||
if (i < 0):
|
||||
raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
|
||||
elif (i < 0xfb):
|
||||
return int2byte(i)
|
||||
elif (i < (1 << 16)):
|
||||
return b'\xfc' + struct.pack('<H', i)
|
||||
elif (i < (1 << 24)):
|
||||
return b'\xfd' + struct.pack('<I', i)[:3]
|
||||
elif (i < (1 << 64)):
|
||||
return b'\xfe' + struct.pack('<Q', i)
|
||||
else:
|
||||
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
|
||||
|
||||
Reference in New Issue
Block a user