Merge pull request #405 from methane/feature/refactor

refactoring auth plugin support
This commit is contained in:
INADA Naoki
2016-01-04 21:13:24 +09:00
4 changed files with 79 additions and 80 deletions

View File

@@ -1,8 +1,9 @@
sudo: false sudo: false
language: python language: python
python: "3.4" python: "3.5"
cache: cache:
pip: true directories:
- $HOME/.cache/pip
env: env:
matrix: matrix:
@@ -10,6 +11,7 @@ env:
- TOX_ENV=py27 - TOX_ENV=py27
- TOX_ENV=py33 - TOX_ENV=py33
- TOX_ENV=py34 - TOX_ENV=py34
- TOX_ENV=py35
- TOX_ENV=pypy - TOX_ENV=pypy
- TOX_ENV=pypy3 - TOX_ENV=pypy3
@@ -36,7 +38,7 @@ matrix:
sudo: required sudo: required
- env: - env:
- TOX_ENV=py34 - TOX_ENV=py34
- DB=5.6.26 - DB=5.6.28
addons: addons:
apt: apt:
packages: packages:

View File

@@ -525,6 +525,7 @@ class Connection(object):
""" """
socket = None socket = None
_auth_plugin_name = ''
def __init__(self, host=None, user=None, password="", def __init__(self, host=None, user=None, password="",
database=None, port=3306, unix_socket=None, database=None, port=3306, unix_socket=None,
@@ -535,7 +536,7 @@ class Connection(object):
compress=None, named_pipe=None, no_delay=None, compress=None, named_pipe=None, no_delay=None,
autocommit=False, db=None, passwd=None, local_infile=False, autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False, max_allowed_packet=16*1024*1024, defer_connect=False,
plugin_map={}): auth_plugin_map={}):
""" """
Establish a connection to the MySQL database. Accepts several Establish a connection to the MySQL database. Accepts several
arguments: arguments:
@@ -571,12 +572,11 @@ class Connection(object):
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB) max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
defer_connect: Don't explicitly connect on contruction - wait for connect call. defer_connect: Don't explicitly connect on contruction - wait for connect call.
(default: False) (default: False)
plugin_map: Map of plugin names to a class that processes that plugin. The class auth_plugin_map: A dict of plugin names to a class that processes that plugin.
will take the Connection object as the argument to the constructor. The class The class will take the Connection object as the argument to the constructor.
needs an authenticate method taking an authentication packet as an argument. The class needs an authenticate method taking an authentication packet as
For the dialog plugin, a prompt(echo, prompt) method can be used (if no an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
authenticate method) for returning a string from the user. (if no authenticate method) for returning a string from the user. (experimental)
db: Alias for database. (for compatibility to MySQLdb) db: Alias for database. (for compatibility to MySQLdb)
passwd: Alias for password. (for compatibility to MySQLdb) passwd: Alias for password. (for compatibility to MySQLdb)
""" """
@@ -672,7 +672,7 @@ class Connection(object):
self.sql_mode = sql_mode self.sql_mode = sql_mode
self.init_command = init_command self.init_command = init_command
self.max_allowed_packet = max_allowed_packet self.max_allowed_packet = max_allowed_packet
self.plugin_map = plugin_map self._auth_plugin_map = auth_plugin_map
if defer_connect: if defer_connect:
self.socket = None self.socket = None
else: else:
@@ -726,7 +726,6 @@ class Connection(object):
def autocommit(self, value): def autocommit(self, value):
self.autocommit_mode = bool(value) self.autocommit_mode = bool(value)
current = self.get_autocommit() current = self.get_autocommit()
self.next_packet = 1
if value != current: if value != current:
self._send_autocommit_mode() self._send_autocommit_mode()
@@ -816,7 +815,6 @@ class Connection(object):
"You may not close previous cursor.") "You may not close previous cursor.")
# if DEBUG: # if DEBUG:
# print("DEBUG: sending query:", sql) # print("DEBUG: sending query:", sql)
self.next_packet = 1
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON): if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
if PY2: if PY2:
sql = sql.encode(self.encoding) sql = sql.encode(self.encoding)
@@ -824,7 +822,6 @@ class Connection(object):
sql = sql.encode(self.encoding, 'surrogateescape') sql = sql.encode(self.encoding, 'surrogateescape')
self._execute_command(COMMAND.COM_QUERY, sql) self._execute_command(COMMAND.COM_QUERY, sql)
self._affected_rows = self._read_query_result(unbuffered=unbuffered) self._affected_rows = self._read_query_result(unbuffered=unbuffered)
self.next_packet = 1
return self._affected_rows return self._affected_rows
def next_result(self, unbuffered=False): def next_result(self, unbuffered=False):
@@ -892,7 +889,7 @@ class Connection(object):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket = sock self.socket = sock
self._rfile = _makefile(sock, 'rb') self._rfile = _makefile(sock, 'rb')
self.next_packet = 0 self._next_seq_id = 0
self._get_server_information() self._get_server_information()
self._request_authentication() self._request_authentication()
@@ -933,16 +930,16 @@ class Connection(object):
# So just reraise it. # So just reraise it.
raise raise
def write_packet(self, data): def write_packet(self, payload):
"""Writes an entire "mysql packet" in its entirety to the network """Writes an entire "mysql packet" in its entirety to the network
addings its length and sequence number. Intended for use by plugins addings its length and sequence number.
only.
""" """
data = pack_int24(len(data)) + int2byte(self.next_packet) + data # Internal note: when you build packet manualy and calls _write_bytes()
# directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
if DEBUG: dump_packet(data) if DEBUG: dump_packet(data)
self._write_bytes(data) self._write_bytes(data)
self.next_packet = (self.next_packet + 1) % 256 self._next_seq_id = (self._next_seq_id + 1) % 256
def _read_packet(self, packet_type=MysqlPacket): def _read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network """Read an entire "mysql packet" in its entirety from the network
@@ -952,8 +949,14 @@ class Connection(object):
while True: while True:
packet_header = self._read_bytes(4) packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header) if DEBUG: dump_packet(packet_header)
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header) btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
bytes_to_read = btrl + (btrh << 16) bytes_to_read = btrl + (btrh << 16)
if packet_number != self._next_seq_id:
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
self._next_seq_id = (self._next_seq_id + 1) % 256
recv_data = self._read_bytes(bytes_to_read) recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data) if DEBUG: dump_packet(recv_data)
buff += recv_data buff += recv_data
@@ -962,13 +965,7 @@ class Connection(object):
continue continue
if bytes_to_read < MAX_PACKET_LEN: if bytes_to_read < MAX_PACKET_LEN:
break break
if packet_number != self.next_packet:
pass
#TODO: check sequence id
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
# (packet_number, self.next_packet))
self.next_packet = (packet_number + 1) % 256
packet = packet_type(buff, self.encoding) packet = packet_type(buff, self.encoding)
packet.check_error() packet.check_error()
return packet return packet
@@ -1027,33 +1024,32 @@ class Connection(object):
if self._result is not None and self._result.unbuffered_active: if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete") warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query() self._result._finish_unbuffered_query()
self._result = None
if isinstance(sql, text_type): if isinstance(sql, text_type):
sql = sql.encode(self.encoding) sql = sql.encode(self.encoding)
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command # +1 is for command
chunk_size = min(self.max_allowed_packet, len(sql) + 1)
# tiny optimization: build first packet manually instead of
# calling self..write_packet()
prelude = struct.pack('<iB', chunk_size, command) prelude = struct.pack('<iB', chunk_size, command)
self._write_bytes(prelude + sql[:chunk_size-1]) packet = prelude + sql[:chunk_size-1]
if DEBUG: dump_packet(prelude + sql) self._write_bytes(packet)
if DEBUG: dump_packet(packet)
self._next_seq_id = 1
self.next_packet = 1
if chunk_size < self.max_allowed_packet: if chunk_size < self.max_allowed_packet:
return return
seq_id = 1
sql = sql[chunk_size-1:] sql = sql[chunk_size-1:]
while True: while True:
chunk_size = min(self.max_allowed_packet, len(sql)) chunk_size = min(self.max_allowed_packet, len(sql))
prelude = struct.pack('<i', chunk_size)[:3] self.write_packet(sql[:chunk_size])
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
self._write_bytes(data)
if DEBUG: dump_packet(data)
sql = sql[chunk_size:] sql = sql[chunk_size:]
if not sql and chunk_size < self.max_allowed_packet: if not sql and chunk_size < self.max_allowed_packet:
break break
seq_id += 1
self.next_packet = seq_id%256
def _request_authentication(self): def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
@@ -1078,17 +1074,14 @@ class Connection(object):
data = data_init + self.user + b'\0' data = data_init + self.user + b'\0'
authresp = '' authresp = b''
if self.plugin_name == 'mysql_native_password': if self._auth_plugin_name == 'mysql_native_password':
authresp = _scramble(self.password.encode('latin1'), self.salt) authresp = _scramble(self.password.encode('latin1'), self.salt)
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp)) data += lenenc_int(len(authresp)) + authresp
data += authresp
elif self.server_capabilities & CLIENT.SECURE_CONNECTION: elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
length = len(authresp) data += struct.pack('B', len(authresp)) + authresp
data += struct.pack('B', length)
data += authresp
else: # pragma: no cover - not testing against servers without secure auth (>=5.0) else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
data += authresp + b'\0' data += authresp + b'\0'
@@ -1098,15 +1091,16 @@ class Connection(object):
data += self.db + b'\0' data += self.db + b'\0'
if self.server_capabilities & CLIENT.PLUGIN_AUTH: if self.server_capabilities & CLIENT.PLUGIN_AUTH:
data += self.plugin_name.encode('latin1') + b'\0' name = self._auth_plugin_name
if isinstance(name, text_type):
name = name.encode('ascii')
data += name + b'\0'
self.write_packet(data) self.write_packet(data)
auth_packet = self._read_packet() auth_packet = self._read_packet()
# if authentication method isn't accepted the first byte # if authentication method isn't accepted the first byte
# will have the octet 254 # will have the octet 254
if auth_packet.is_auth_switch_request(): if auth_packet.is_auth_switch_request():
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
auth_packet.read_uint8() # 0xfe packet identifier auth_packet.read_uint8() # 0xfe packet identifier
@@ -1119,8 +1113,13 @@ class Connection(object):
self.write_packet(data) self.write_packet(data)
auth_packet = self._read_packet() auth_packet = self._read_packet()
#TODO: ok packet or error packet?
def _process_auth(self, plugin_name, auth_packet): def _process_auth(self, plugin_name, auth_packet):
plugin_class = self.plugin_map.get(plugin_name) plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class:
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
if plugin_class: if plugin_class:
try: try:
handler = plugin_class(self) handler = plugin_class(self)
@@ -1246,18 +1245,13 @@ class Connection(object):
server_end = data.find(b'\0', i) server_end = data.find(b'\0', i)
if server_end < 0: # pragma: no cover - very specific upstream bug if server_end < 0: # pragma: no cover - very specific upstream bug
# not found \0 and last field so take it all # not found \0 and last field so take it all
self.plugin_name = data[i:].decode('latin1') self._auth_plugin_name = data[i:].decode('latin1')
else: else:
self.plugin_name = data[i:server_end].decode('latin1') self._auth_plugin_name = data[i:server_end].decode('latin1')
else: # pragma: no cover - not testing against any plugin uncapable servers
self.plugin_name = ''
def get_server_info(self): def get_server_info(self):
return self.server_version return self.server_version
def get_plugin_name(self):
return self.plugin_name
Warning = err.Warning Warning = err.Warning
Error = err.Error Error = err.Error
InterfaceError = err.InterfaceError InterfaceError = err.InterfaceError
@@ -1331,7 +1325,11 @@ class MySQLResult(object):
def _read_load_local_packet(self, first_packet): def _read_load_local_packet(self, first_packet):
load_packet = LoadLocalPacketWrapper(first_packet) load_packet = LoadLocalPacketWrapper(first_packet)
sender = LoadLocalFile(load_packet.filename, self.connection) sender = LoadLocalFile(load_packet.filename, self.connection)
try:
sender.send_data() sender.send_data()
except:
self.connection._read_packet() # skip ok packet
raise
ok_packet = self.connection._read_packet() ok_packet = self.connection._read_packet()
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
@@ -1448,27 +1446,20 @@ class LoadLocalFile(object):
"""Send data packets from the local file to the server""" """Send data packets from the local file to the server"""
if not self.connection.socket: if not self.connection.socket:
raise err.InterfaceError("(0, '')") raise err.InterfaceError("(0, '')")
conn = self.connection
# sequence id is 2 as we already sent a query packet
seq_id = 2
try: try:
with open(self.filename, 'rb') as open_file: with open(self.filename, 'rb') as open_file:
chunk_size = self.connection.max_allowed_packet chunk_size = conn.max_allowed_packet
packet = b"" packet = b""
while True: while True:
chunk = open_file.read(chunk_size) chunk = open_file.read(chunk_size)
if not chunk: if not chunk:
break break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id) conn.write_packet(chunk)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id = (seq_id + 1) % 256
except IOError: except IOError:
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename)) raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally: finally:
# send the empty packet to signify we are done sending data # send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id) conn.write_packet(b'')
self.connection._write_bytes(packet)
self.next_packet = (seq_id + 1) % 256

View File

@@ -40,6 +40,7 @@ class TempUser:
if self._created: if self._created:
self._c.execute("DROP USER %s" % self._user) self._c.execute("DROP USER %s" % self._user)
class TestAuthentication(base.PyMySQLTestCase): class TestAuthentication(base.PyMySQLTestCase):
socket_auth = False socket_auth = False
@@ -95,7 +96,7 @@ class TestAuthentication(base.PyMySQLTestCase):
def test_plugin(self): def test_plugin(self):
# Bit of an assumption that the current user is a native password # Bit of an assumption that the current user is a native password
self.assertEqual('mysql_native_password', self.connections[0].get_plugin_name()) self.assertEqual('mysql_native_password', self.connections[0]._auth_plugin_name)
@unittest2.skipUnless(socket_auth, "connection to unix_socket required") @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
@unittest2.skipIf(socket_found, "socket plugin already installed") @unittest2.skipIf(socket_found, "socket plugin already installed")
@@ -198,7 +199,7 @@ class TestAuthentication(base.PyMySQLTestCase):
self.databases[0]['db'], 'two_questions', 'notverysecret') as u: self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_2q', **self.db) pymysql.connect(user='pymysql_2q', **self.db)
pymysql.connect(user='pymysql_2q', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
@unittest2.skipUnless(socket_auth, "connection to unix_socket required") @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
@unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed") @unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
@@ -225,21 +226,21 @@ class TestAuthentication(base.PyMySQLTestCase):
TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all
with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost', with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u: self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': object}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'} TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
TestAuthentication.Dialog.m = {b'Password, please:': None} TestAuthentication.Dialog.m = {b'Password, please:': None}
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
@unittest2.skipUnless(socket_auth, "connection to unix_socket required") @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
@unittest2.skipIf(pam_found, "pam plugin already installed") @unittest2.skipIf(pam_found, "pam plugin already installed")
@@ -285,12 +286,16 @@ class TestAuthentication(base.PyMySQLTestCase):
c = pymysql.connect(user=TestAuthentication.osuser, **db) c = pymysql.connect(user=TestAuthentication.osuser, **db)
db['password'] = 'very bad guess at password' db['password'] = 'very bad guess at password'
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db) pymysql.connect(user=TestAuthentication.osuser,
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
**self.db)
except pymysql.OperationalError as e: except pymysql.OperationalError as e:
self.assertEqual(1045, e.args[0]) self.assertEqual(1045, e.args[0])
# we had 'bad guess at password' work with pam. Well at least we get a permission denied here # we had 'bad guess at password' work with pam. Well at least we get a permission denied here
with self.assertRaises(pymysql.err.OperationalError): with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db) pymysql.connect(user=TestAuthentication.osuser,
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
**self.db)
if grants: if grants:
# recreate the user # recreate the user
cur.execute(grants) cur.execute(grants)

View File

@@ -61,6 +61,7 @@ class TestLoadLocal(base.PyMySQLTestCase):
self.assertTrue("Incorrect integer value" in str(w[-1].message)) self.assertTrue("Incorrect integer value" in str(w[-1].message))
finally: finally:
c.execute("DROP TABLE test_load_local") c.execute("DROP TABLE test_load_local")
c.close()
if __name__ == "__main__": if __name__ == "__main__":