Merge pull request #405 from methane/feature/refactor

refactoring auth plugin support
This commit is contained in:
INADA Naoki
2016-01-04 21:13:24 +09:00
4 changed files with 79 additions and 80 deletions

View File

@@ -1,8 +1,9 @@
sudo: false
language: python
python: "3.4"
python: "3.5"
cache:
pip: true
directories:
- $HOME/.cache/pip
env:
matrix:
@@ -10,6 +11,7 @@ env:
- TOX_ENV=py27
- TOX_ENV=py33
- TOX_ENV=py34
- TOX_ENV=py35
- TOX_ENV=pypy
- TOX_ENV=pypy3
@@ -36,7 +38,7 @@ matrix:
sudo: required
- env:
- TOX_ENV=py34
- DB=5.6.26
- DB=5.6.28
addons:
apt:
packages:

View File

@@ -525,6 +525,7 @@ class Connection(object):
"""
socket = None
_auth_plugin_name = ''
def __init__(self, host=None, user=None, password="",
database=None, port=3306, unix_socket=None,
@@ -535,7 +536,7 @@ class Connection(object):
compress=None, named_pipe=None, no_delay=None,
autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False,
plugin_map={}):
auth_plugin_map={}):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
@@ -571,12 +572,11 @@ class Connection(object):
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
defer_connect: Don't explicitly connect on contruction - wait for connect call.
(default: False)
plugin_map: Map of plugin names to a class that processes that plugin. The class
will take the Connection object as the argument to the constructor. The class
needs an authenticate method taking an authentication packet as an argument.
For the dialog plugin, a prompt(echo, prompt) method can be used (if no
authenticate method) for returning a string from the user.
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
The class will take the Connection object as the argument to the constructor.
The class needs an authenticate method taking an authentication packet as
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
(if no authenticate method) for returning a string from the user. (experimental)
db: Alias for database. (for compatibility to MySQLdb)
passwd: Alias for password. (for compatibility to MySQLdb)
"""
@@ -672,7 +672,7 @@ class Connection(object):
self.sql_mode = sql_mode
self.init_command = init_command
self.max_allowed_packet = max_allowed_packet
self.plugin_map = plugin_map
self._auth_plugin_map = auth_plugin_map
if defer_connect:
self.socket = None
else:
@@ -726,7 +726,6 @@ class Connection(object):
def autocommit(self, value):
self.autocommit_mode = bool(value)
current = self.get_autocommit()
self.next_packet = 1
if value != current:
self._send_autocommit_mode()
@@ -816,7 +815,6 @@ class Connection(object):
"You may not close previous cursor.")
# if DEBUG:
# print("DEBUG: sending query:", sql)
self.next_packet = 1
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
if PY2:
sql = sql.encode(self.encoding)
@@ -824,7 +822,6 @@ class Connection(object):
sql = sql.encode(self.encoding, 'surrogateescape')
self._execute_command(COMMAND.COM_QUERY, sql)
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
self.next_packet = 1
return self._affected_rows
def next_result(self, unbuffered=False):
@@ -892,7 +889,7 @@ class Connection(object):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket = sock
self._rfile = _makefile(sock, 'rb')
self.next_packet = 0
self._next_seq_id = 0
self._get_server_information()
self._request_authentication()
@@ -933,16 +930,16 @@ class Connection(object):
# So just reraise it.
raise
def write_packet(self, data):
def write_packet(self, payload):
"""Writes an entire "mysql packet" in its entirety to the network
addings its length and sequence number. Intended for use by plugins
only.
addings its length and sequence number.
"""
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
# Internal note: when you build packet manualy and calls _write_bytes()
# directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
if DEBUG: dump_packet(data)
self._write_bytes(data)
self.next_packet = (self.next_packet + 1) % 256
self._next_seq_id = (self._next_seq_id + 1) % 256
def _read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
@@ -952,8 +949,14 @@ class Connection(object):
while True:
packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header)
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
bytes_to_read = btrl + (btrh << 16)
if packet_number != self._next_seq_id:
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
self._next_seq_id = (self._next_seq_id + 1) % 256
recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
@@ -962,13 +965,7 @@ class Connection(object):
continue
if bytes_to_read < MAX_PACKET_LEN:
break
if packet_number != self.next_packet:
pass
#TODO: check sequence id
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
# (packet_number, self.next_packet))
self.next_packet = (packet_number + 1) % 256
packet = packet_type(buff, self.encoding)
packet.check_error()
return packet
@@ -1027,33 +1024,32 @@ class Connection(object):
if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
self._result = None
if isinstance(sql, text_type):
sql = sql.encode(self.encoding)
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
# +1 is for command
chunk_size = min(self.max_allowed_packet, len(sql) + 1)
# tiny optimization: build first packet manually instead of
# calling self..write_packet()
prelude = struct.pack('<iB', chunk_size, command)
self._write_bytes(prelude + sql[:chunk_size-1])
if DEBUG: dump_packet(prelude + sql)
packet = prelude + sql[:chunk_size-1]
self._write_bytes(packet)
if DEBUG: dump_packet(packet)
self._next_seq_id = 1
self.next_packet = 1
if chunk_size < self.max_allowed_packet:
return
seq_id = 1
sql = sql[chunk_size-1:]
while True:
chunk_size = min(self.max_allowed_packet, len(sql))
prelude = struct.pack('<i', chunk_size)[:3]
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
self._write_bytes(data)
if DEBUG: dump_packet(data)
self.write_packet(sql[:chunk_size])
sql = sql[chunk_size:]
if not sql and chunk_size < self.max_allowed_packet:
break
seq_id += 1
self.next_packet = seq_id%256
def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
@@ -1078,18 +1074,15 @@ class Connection(object):
data = data_init + self.user + b'\0'
authresp = ''
if self.plugin_name == 'mysql_native_password':
authresp = b''
if self._auth_plugin_name == 'mysql_native_password':
authresp = _scramble(self.password.encode('latin1'), self.salt)
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp))
data += authresp
data += lenenc_int(len(authresp)) + authresp
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
length = len(authresp)
data += struct.pack('B', length)
data += authresp
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
data += struct.pack('B', len(authresp)) + authresp
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
data += authresp + b'\0'
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
@@ -1098,15 +1091,16 @@ class Connection(object):
data += self.db + b'\0'
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
data += self.plugin_name.encode('latin1') + b'\0'
name = self._auth_plugin_name
if isinstance(name, text_type):
name = name.encode('ascii')
data += name + b'\0'
self.write_packet(data)
auth_packet = self._read_packet()
# if authentication method isn't accepted the first byte
# will have the octet 254
if auth_packet.is_auth_switch_request():
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
auth_packet.read_uint8() # 0xfe packet identifier
@@ -1119,8 +1113,13 @@ class Connection(object):
self.write_packet(data)
auth_packet = self._read_packet()
#TODO: ok packet or error packet?
def _process_auth(self, plugin_name, auth_packet):
plugin_class = self.plugin_map.get(plugin_name)
plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class:
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
if plugin_class:
try:
handler = plugin_class(self)
@@ -1246,18 +1245,13 @@ class Connection(object):
server_end = data.find(b'\0', i)
if server_end < 0: # pragma: no cover - very specific upstream bug
# not found \0 and last field so take it all
self.plugin_name = data[i:].decode('latin1')
self._auth_plugin_name = data[i:].decode('latin1')
else:
self.plugin_name = data[i:server_end].decode('latin1')
else: # pragma: no cover - not testing against any plugin uncapable servers
self.plugin_name = ''
self._auth_plugin_name = data[i:server_end].decode('latin1')
def get_server_info(self):
return self.server_version
def get_plugin_name(self):
return self.plugin_name
Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError
@@ -1331,7 +1325,11 @@ class MySQLResult(object):
def _read_load_local_packet(self, first_packet):
load_packet = LoadLocalPacketWrapper(first_packet)
sender = LoadLocalFile(load_packet.filename, self.connection)
sender.send_data()
try:
sender.send_data()
except:
self.connection._read_packet() # skip ok packet
raise
ok_packet = self.connection._read_packet()
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
@@ -1448,27 +1446,20 @@ class LoadLocalFile(object):
"""Send data packets from the local file to the server"""
if not self.connection.socket:
raise err.InterfaceError("(0, '')")
conn = self.connection
# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'rb') as open_file:
chunk_size = self.connection.max_allowed_packet
chunk_size = conn.max_allowed_packet
packet = b""
while True:
chunk = open_file.read(chunk_size)
if not chunk:
break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id = (seq_id + 1) % 256
conn.write_packet(chunk)
except IOError:
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally:
# send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
self.next_packet = (seq_id + 1) % 256
conn.write_packet(b'')

View File

@@ -40,6 +40,7 @@ class TempUser:
if self._created:
self._c.execute("DROP USER %s" % self._user)
class TestAuthentication(base.PyMySQLTestCase):
socket_auth = False
@@ -95,7 +96,7 @@ class TestAuthentication(base.PyMySQLTestCase):
def test_plugin(self):
# Bit of an assumption that the current user is a native password
self.assertEqual('mysql_native_password', self.connections[0].get_plugin_name())
self.assertEqual('mysql_native_password', self.connections[0]._auth_plugin_name)
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
@unittest2.skipIf(socket_found, "socket plugin already installed")
@@ -198,7 +199,7 @@ class TestAuthentication(base.PyMySQLTestCase):
self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_2q', **self.db)
pymysql.connect(user='pymysql_2q', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
@unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
@@ -225,21 +226,21 @@ class TestAuthentication(base.PyMySQLTestCase):
TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all
with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': object}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
TestAuthentication.Dialog.m = {b'Password, please:': None}
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
@unittest2.skipIf(pam_found, "pam plugin already installed")
@@ -285,12 +286,16 @@ class TestAuthentication(base.PyMySQLTestCase):
c = pymysql.connect(user=TestAuthentication.osuser, **db)
db['password'] = 'very bad guess at password'
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db)
pymysql.connect(user=TestAuthentication.osuser,
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
**self.db)
except pymysql.OperationalError as e:
self.assertEqual(1045, e.args[0])
# we had 'bad guess at password' work with pam. Well at least we get a permission denied here
with self.assertRaises(pymysql.err.OperationalError):
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db)
pymysql.connect(user=TestAuthentication.osuser,
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
**self.db)
if grants:
# recreate the user
cur.execute(grants)

View File

@@ -61,6 +61,7 @@ class TestLoadLocal(base.PyMySQLTestCase):
self.assertTrue("Incorrect integer value" in str(w[-1].message))
finally:
c.execute("DROP TABLE test_load_local")
c.close()
if __name__ == "__main__":