Add basic connection plugin support

This extracts a plugin name from the server and exposes this as
connection.get_plugin_name().

The client side here adds capabilities PLUGIN_AUTH and
PLUGIN_AUTH_LENENC_CLIENT_DATA.

Because of this the HandshakeResponse response has been altered to use
the server_capabilities a bit mroe strictly.

The immediate upshot is that plugins like unix_socket are immediately
supported where they previously would of returned and error as without
PLUGIN_AUTH being advertised on the client side, the server only accepts
mysql_native_password and mysql_old_password.

To support other plugins some more work is needed to implement these.
This commit is contained in:
Daniel Black
2015-06-28 15:43:57 +10:00
parent 4fe1073c4d
commit 2d1da09f07
4 changed files with 104 additions and 9 deletions

View File

@@ -19,7 +19,7 @@ import warnings
from .charset import MBLENGTH, charset_by_name, charset_by_id
from .cursors import Cursor
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
from .util import byte2int, int2byte
from .util import byte2int, int2byte, lenenc_int
from .converters import (
escape_item, encoders, decoders, escape_string, through)
from . import err
@@ -139,7 +139,7 @@ def dump_packet(data):
def _scramble(password, message):
if not password:
return b'\0'
return b''
if DEBUG: print('password=' + str(password))
stage1 = sha_new(password).digest()
stage2 = sha_new(stage1).digest()
@@ -152,7 +152,7 @@ def _scramble(password, message):
def _my_crypt(message1, message2):
length = len(message1)
result = struct.pack('B', length)
result = ''
for i in range_type(length):
x = (struct.unpack('B', message1[i:i+1])[0] ^
struct.unpack('B', message2[i:i+1])[0])
@@ -985,6 +985,7 @@ class Connection(object):
seq_id += 1
def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
self.client_flag |= CLIENT.CAPABILITIES
if self.server_version.startswith('5'):
self.client_flag |= CLIENT.MULTI_RESULTS
@@ -1000,7 +1001,7 @@ class Connection(object):
next_packet = 1
if self.ssl:
if self.ssl and self.server_capabilities & CLIENT.SSL:
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
next_packet += 1
@@ -1015,14 +1016,30 @@ class Connection(object):
ca_certs=self.ca)
self._rfile = _makefile(self.socket, 'rb')
data = data_init + self.user + b'\0' + \
_scramble(self.password.encode('latin1'), self.salt)
data = data_init + self.user + b'\0'
if self.db:
authresp = ''
if self.plugin_name == 'mysql_native_password':
authresp = _scramble(self.password.encode('latin1'), self.salt)
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp))
data += authresp
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
length = len(authresp)
data += struct.pack('B', length)
data += authresp
else:
data += authresp + int2byte(0)
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
if isinstance(self.db, text_type):
self.db = self.db.encode(self.encoding)
data += self.db + int2byte(0)
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
data += self.plugin_name.encode('latin1') + int2byte(0)
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
@@ -1095,11 +1112,31 @@ class Connection(object):
if len(data) >= i + salt_len:
# salt_len includes auth_plugin_data_part_1 and filler
self.salt += data[i:i+salt_len]
# TODO: AUTH PLUGIN NAME may appeare here.
i += salt_len
i+=1
# AUTH PLUGIN NAME may appear here.
if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
# Due to Bug#59453 the auth-plugin-name is missing the terminating
# NUL-char in versions prior to 5.5.10 and 5.6.2.
# ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
# didn't use version checks as mariadb is corrected and reports
# earlier than those two.
server_end = data.find(int2byte(0), i)
if server_end < 0:
# not found \0 and last field so take it all
self.plugin_name = data[i:].decode('latin1')
else:
self.plugin_name = data[i:server_end].decode('latin1')
else:
self.plugin_name = ''
def get_server_info(self):
return self.server_version
def get_plugin_name(self):
return self.plugin_name
Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError

View File

@@ -1,3 +1,4 @@
# https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
LONG_PASSWORD = 1
FOUND_ROWS = 1 << 1
LONG_FLAG = 1 << 2
@@ -15,5 +16,14 @@ TRANSACTIONS = 1 << 13
SECURE_CONNECTION = 1 << 15
MULTI_STATEMENTS = 1 << 16
MULTI_RESULTS = 1 << 17
PS_MULTI_RESULTS = 1 << 18
PLUGIN_AUTH = 1 << 19
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CAPABILITIES = (LONG_PASSWORD | LONG_FLAG | TRANSACTIONS |
PROTOCOL_41 | SECURE_CONNECTION)
PROTOCOL_41 | SECURE_CONNECTION | PLUGIN_AUTH |
PLUGIN_AUTH_LENENC_CLIENT_DATA)
# Not done yet
CONNECT_ATTRS = 1 << 20
HANDLE_EXPIRED_PASSWORDS = 1 << 22
SESSION_TRACK = 1 << 23
CLIENT_DEPRECATE_EOF = 1 << 24

View File

@@ -2,6 +2,8 @@ import datetime
import decimal
import pymysql
import time
import os
import copy
from pymysql.tests import base
@@ -74,6 +76,37 @@ class TestConnection(base.PyMySQLTestCase):
self.assertEqual(('foobar',), c.fetchone())
conn.close()
def test_plugin(self):
con = self.connections[0]
self.assertEqual('mysql_native_password',con.get_plugin_name())
# attempt a unix socket test which is included in some versions
# and doesn't require a client side handler
user = os.environ.get('USER')
if not user or self.databases[0]['host'] != 'localhost':
return
cur = con.cursor()
cur.execute("SHOW PLUGINS")
found = False
for r in cur:
if r == (u'unix_socket', u'ACTIVE', u'AUTHENTICATION', u'auth_socket.so', u'GPL'):
found = True
break
# needs plugin. lets install it.
if not found:
cur.execute("install soname 'auth_socket'")
current_db = self.databases[0]['db']
cur.execute("GRANT ALL ON %s TO %s@localhost IDENTIFIED VIA unix_socket" % ( current_db, user))
db = copy.copy(self.databases[0])
del db['user']
c = pymysql.connect(user=user, **db)
if not found:
cur.execute("uninstall soname 'auth_socket'")
cur.execute("DROP USER %s@localhost" % user)
# A custom type and function to escape it
class Foo(object):

View File

@@ -17,3 +17,18 @@ def join_bytes(bs):
for b in bs[1:]:
rv += b
return rv
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def lenenc_int(i):
if (i < 0):
raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
elif (i < 0xfb):
return int2byte(i)
elif (i < (1 << 16)):
return b'\xfc' + struct.pack('<H', i)
elif (i < (1 << 24)):
return b'\xfd' + struct.pack('<I', i)[:3]
elif (i < (1 << 64)):
return b'\xfe' + struct.pack('<Q', i)
else:
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))