Merge pull request #405 from methane/feature/refactor
refactoring auth plugin support
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
sudo: false
|
||||
language: python
|
||||
python: "3.4"
|
||||
python: "3.5"
|
||||
cache:
|
||||
pip: true
|
||||
directories:
|
||||
- $HOME/.cache/pip
|
||||
|
||||
env:
|
||||
matrix:
|
||||
@@ -10,6 +11,7 @@ env:
|
||||
- TOX_ENV=py27
|
||||
- TOX_ENV=py33
|
||||
- TOX_ENV=py34
|
||||
- TOX_ENV=py35
|
||||
- TOX_ENV=pypy
|
||||
- TOX_ENV=pypy3
|
||||
|
||||
@@ -36,7 +38,7 @@ matrix:
|
||||
sudo: required
|
||||
- env:
|
||||
- TOX_ENV=py34
|
||||
- DB=5.6.26
|
||||
- DB=5.6.28
|
||||
addons:
|
||||
apt:
|
||||
packages:
|
||||
|
||||
@@ -525,6 +525,7 @@ class Connection(object):
|
||||
"""
|
||||
|
||||
socket = None
|
||||
_auth_plugin_name = ''
|
||||
|
||||
def __init__(self, host=None, user=None, password="",
|
||||
database=None, port=3306, unix_socket=None,
|
||||
@@ -535,7 +536,7 @@ class Connection(object):
|
||||
compress=None, named_pipe=None, no_delay=None,
|
||||
autocommit=False, db=None, passwd=None, local_infile=False,
|
||||
max_allowed_packet=16*1024*1024, defer_connect=False,
|
||||
plugin_map={}):
|
||||
auth_plugin_map={}):
|
||||
"""
|
||||
Establish a connection to the MySQL database. Accepts several
|
||||
arguments:
|
||||
@@ -571,12 +572,11 @@ class Connection(object):
|
||||
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
|
||||
defer_connect: Don't explicitly connect on contruction - wait for connect call.
|
||||
(default: False)
|
||||
plugin_map: Map of plugin names to a class that processes that plugin. The class
|
||||
will take the Connection object as the argument to the constructor. The class
|
||||
needs an authenticate method taking an authentication packet as an argument.
|
||||
For the dialog plugin, a prompt(echo, prompt) method can be used (if no
|
||||
authenticate method) for returning a string from the user.
|
||||
|
||||
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
|
||||
The class will take the Connection object as the argument to the constructor.
|
||||
The class needs an authenticate method taking an authentication packet as
|
||||
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
|
||||
(if no authenticate method) for returning a string from the user. (experimental)
|
||||
db: Alias for database. (for compatibility to MySQLdb)
|
||||
passwd: Alias for password. (for compatibility to MySQLdb)
|
||||
"""
|
||||
@@ -672,7 +672,7 @@ class Connection(object):
|
||||
self.sql_mode = sql_mode
|
||||
self.init_command = init_command
|
||||
self.max_allowed_packet = max_allowed_packet
|
||||
self.plugin_map = plugin_map
|
||||
self._auth_plugin_map = auth_plugin_map
|
||||
if defer_connect:
|
||||
self.socket = None
|
||||
else:
|
||||
@@ -726,7 +726,6 @@ class Connection(object):
|
||||
def autocommit(self, value):
|
||||
self.autocommit_mode = bool(value)
|
||||
current = self.get_autocommit()
|
||||
self.next_packet = 1
|
||||
if value != current:
|
||||
self._send_autocommit_mode()
|
||||
|
||||
@@ -816,7 +815,6 @@ class Connection(object):
|
||||
"You may not close previous cursor.")
|
||||
# if DEBUG:
|
||||
# print("DEBUG: sending query:", sql)
|
||||
self.next_packet = 1
|
||||
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
|
||||
if PY2:
|
||||
sql = sql.encode(self.encoding)
|
||||
@@ -824,7 +822,6 @@ class Connection(object):
|
||||
sql = sql.encode(self.encoding, 'surrogateescape')
|
||||
self._execute_command(COMMAND.COM_QUERY, sql)
|
||||
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
|
||||
self.next_packet = 1
|
||||
return self._affected_rows
|
||||
|
||||
def next_result(self, unbuffered=False):
|
||||
@@ -892,7 +889,7 @@ class Connection(object):
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
self.socket = sock
|
||||
self._rfile = _makefile(sock, 'rb')
|
||||
self.next_packet = 0
|
||||
self._next_seq_id = 0
|
||||
|
||||
self._get_server_information()
|
||||
self._request_authentication()
|
||||
@@ -933,16 +930,16 @@ class Connection(object):
|
||||
# So just reraise it.
|
||||
raise
|
||||
|
||||
def write_packet(self, data):
|
||||
def write_packet(self, payload):
|
||||
"""Writes an entire "mysql packet" in its entirety to the network
|
||||
addings its length and sequence number. Intended for use by plugins
|
||||
only.
|
||||
addings its length and sequence number.
|
||||
"""
|
||||
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
|
||||
# Internal note: when you build packet manualy and calls _write_bytes()
|
||||
# directly, you should set self._next_seq_id properly.
|
||||
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
|
||||
if DEBUG: dump_packet(data)
|
||||
|
||||
self._write_bytes(data)
|
||||
self.next_packet = (self.next_packet + 1) % 256
|
||||
self._next_seq_id = (self._next_seq_id + 1) % 256
|
||||
|
||||
def _read_packet(self, packet_type=MysqlPacket):
|
||||
"""Read an entire "mysql packet" in its entirety from the network
|
||||
@@ -952,8 +949,14 @@ class Connection(object):
|
||||
while True:
|
||||
packet_header = self._read_bytes(4)
|
||||
if DEBUG: dump_packet(packet_header)
|
||||
|
||||
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
|
||||
bytes_to_read = btrl + (btrh << 16)
|
||||
if packet_number != self._next_seq_id:
|
||||
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
|
||||
(packet_number, self._next_seq_id))
|
||||
self._next_seq_id = (self._next_seq_id + 1) % 256
|
||||
|
||||
recv_data = self._read_bytes(bytes_to_read)
|
||||
if DEBUG: dump_packet(recv_data)
|
||||
buff += recv_data
|
||||
@@ -962,13 +965,7 @@ class Connection(object):
|
||||
continue
|
||||
if bytes_to_read < MAX_PACKET_LEN:
|
||||
break
|
||||
if packet_number != self.next_packet:
|
||||
pass
|
||||
#TODO: check sequence id
|
||||
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
|
||||
# (packet_number, self.next_packet))
|
||||
|
||||
self.next_packet = (packet_number + 1) % 256
|
||||
packet = packet_type(buff, self.encoding)
|
||||
packet.check_error()
|
||||
return packet
|
||||
@@ -1027,33 +1024,32 @@ class Connection(object):
|
||||
if self._result is not None and self._result.unbuffered_active:
|
||||
warnings.warn("Previous unbuffered result was left incomplete")
|
||||
self._result._finish_unbuffered_query()
|
||||
self._result = None
|
||||
|
||||
if isinstance(sql, text_type):
|
||||
sql = sql.encode(self.encoding)
|
||||
|
||||
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
|
||||
# +1 is for command
|
||||
chunk_size = min(self.max_allowed_packet, len(sql) + 1)
|
||||
|
||||
# tiny optimization: build first packet manually instead of
|
||||
# calling self..write_packet()
|
||||
prelude = struct.pack('<iB', chunk_size, command)
|
||||
self._write_bytes(prelude + sql[:chunk_size-1])
|
||||
if DEBUG: dump_packet(prelude + sql)
|
||||
packet = prelude + sql[:chunk_size-1]
|
||||
self._write_bytes(packet)
|
||||
if DEBUG: dump_packet(packet)
|
||||
self._next_seq_id = 1
|
||||
|
||||
self.next_packet = 1
|
||||
if chunk_size < self.max_allowed_packet:
|
||||
return
|
||||
|
||||
seq_id = 1
|
||||
sql = sql[chunk_size-1:]
|
||||
while True:
|
||||
chunk_size = min(self.max_allowed_packet, len(sql))
|
||||
prelude = struct.pack('<i', chunk_size)[:3]
|
||||
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
|
||||
self._write_bytes(data)
|
||||
if DEBUG: dump_packet(data)
|
||||
self.write_packet(sql[:chunk_size])
|
||||
sql = sql[chunk_size:]
|
||||
if not sql and chunk_size < self.max_allowed_packet:
|
||||
break
|
||||
seq_id += 1
|
||||
self.next_packet = seq_id%256
|
||||
|
||||
def _request_authentication(self):
|
||||
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
@@ -1078,18 +1074,15 @@ class Connection(object):
|
||||
|
||||
data = data_init + self.user + b'\0'
|
||||
|
||||
authresp = ''
|
||||
if self.plugin_name == 'mysql_native_password':
|
||||
authresp = b''
|
||||
if self._auth_plugin_name == 'mysql_native_password':
|
||||
authresp = _scramble(self.password.encode('latin1'), self.salt)
|
||||
|
||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
|
||||
data += lenenc_int(len(authresp))
|
||||
data += authresp
|
||||
data += lenenc_int(len(authresp)) + authresp
|
||||
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
|
||||
length = len(authresp)
|
||||
data += struct.pack('B', length)
|
||||
data += authresp
|
||||
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
|
||||
data += struct.pack('B', len(authresp)) + authresp
|
||||
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
|
||||
data += authresp + b'\0'
|
||||
|
||||
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
|
||||
@@ -1098,15 +1091,16 @@ class Connection(object):
|
||||
data += self.db + b'\0'
|
||||
|
||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
|
||||
data += self.plugin_name.encode('latin1') + b'\0'
|
||||
name = self._auth_plugin_name
|
||||
if isinstance(name, text_type):
|
||||
name = name.encode('ascii')
|
||||
data += name + b'\0'
|
||||
|
||||
self.write_packet(data)
|
||||
|
||||
auth_packet = self._read_packet()
|
||||
|
||||
# if authentication method isn't accepted the first byte
|
||||
# will have the octet 254
|
||||
|
||||
if auth_packet.is_auth_switch_request():
|
||||
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
|
||||
auth_packet.read_uint8() # 0xfe packet identifier
|
||||
@@ -1119,8 +1113,13 @@ class Connection(object):
|
||||
self.write_packet(data)
|
||||
auth_packet = self._read_packet()
|
||||
|
||||
#TODO: ok packet or error packet?
|
||||
|
||||
|
||||
def _process_auth(self, plugin_name, auth_packet):
|
||||
plugin_class = self.plugin_map.get(plugin_name)
|
||||
plugin_class = self._auth_plugin_map.get(plugin_name)
|
||||
if not plugin_class:
|
||||
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
|
||||
if plugin_class:
|
||||
try:
|
||||
handler = plugin_class(self)
|
||||
@@ -1246,18 +1245,13 @@ class Connection(object):
|
||||
server_end = data.find(b'\0', i)
|
||||
if server_end < 0: # pragma: no cover - very specific upstream bug
|
||||
# not found \0 and last field so take it all
|
||||
self.plugin_name = data[i:].decode('latin1')
|
||||
self._auth_plugin_name = data[i:].decode('latin1')
|
||||
else:
|
||||
self.plugin_name = data[i:server_end].decode('latin1')
|
||||
else: # pragma: no cover - not testing against any plugin uncapable servers
|
||||
self.plugin_name = ''
|
||||
self._auth_plugin_name = data[i:server_end].decode('latin1')
|
||||
|
||||
def get_server_info(self):
|
||||
return self.server_version
|
||||
|
||||
def get_plugin_name(self):
|
||||
return self.plugin_name
|
||||
|
||||
Warning = err.Warning
|
||||
Error = err.Error
|
||||
InterfaceError = err.InterfaceError
|
||||
@@ -1331,7 +1325,11 @@ class MySQLResult(object):
|
||||
def _read_load_local_packet(self, first_packet):
|
||||
load_packet = LoadLocalPacketWrapper(first_packet)
|
||||
sender = LoadLocalFile(load_packet.filename, self.connection)
|
||||
sender.send_data()
|
||||
try:
|
||||
sender.send_data()
|
||||
except:
|
||||
self.connection._read_packet() # skip ok packet
|
||||
raise
|
||||
|
||||
ok_packet = self.connection._read_packet()
|
||||
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
|
||||
@@ -1448,27 +1446,20 @@ class LoadLocalFile(object):
|
||||
"""Send data packets from the local file to the server"""
|
||||
if not self.connection.socket:
|
||||
raise err.InterfaceError("(0, '')")
|
||||
conn = self.connection
|
||||
|
||||
# sequence id is 2 as we already sent a query packet
|
||||
seq_id = 2
|
||||
try:
|
||||
with open(self.filename, 'rb') as open_file:
|
||||
chunk_size = self.connection.max_allowed_packet
|
||||
chunk_size = conn.max_allowed_packet
|
||||
packet = b""
|
||||
|
||||
while True:
|
||||
chunk = open_file.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
|
||||
format_str = '!{0}s'.format(len(chunk))
|
||||
packet += struct.pack(format_str, chunk)
|
||||
self.connection._write_bytes(packet)
|
||||
seq_id = (seq_id + 1) % 256
|
||||
conn.write_packet(chunk)
|
||||
except IOError:
|
||||
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
||||
finally:
|
||||
# send the empty packet to signify we are done sending data
|
||||
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
||||
self.connection._write_bytes(packet)
|
||||
self.next_packet = (seq_id + 1) % 256
|
||||
conn.write_packet(b'')
|
||||
|
||||
@@ -40,6 +40,7 @@ class TempUser:
|
||||
if self._created:
|
||||
self._c.execute("DROP USER %s" % self._user)
|
||||
|
||||
|
||||
class TestAuthentication(base.PyMySQLTestCase):
|
||||
|
||||
socket_auth = False
|
||||
@@ -95,7 +96,7 @@ class TestAuthentication(base.PyMySQLTestCase):
|
||||
|
||||
def test_plugin(self):
|
||||
# Bit of an assumption that the current user is a native password
|
||||
self.assertEqual('mysql_native_password', self.connections[0].get_plugin_name())
|
||||
self.assertEqual('mysql_native_password', self.connections[0]._auth_plugin_name)
|
||||
|
||||
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
||||
@unittest2.skipIf(socket_found, "socket plugin already installed")
|
||||
@@ -198,7 +199,7 @@ class TestAuthentication(base.PyMySQLTestCase):
|
||||
self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user='pymysql_2q', **self.db)
|
||||
pymysql.connect(user='pymysql_2q', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
|
||||
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
||||
@unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
|
||||
@@ -225,21 +226,21 @@ class TestAuthentication(base.PyMySQLTestCase):
|
||||
TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all
|
||||
with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
|
||||
self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': object}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)
|
||||
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
|
||||
TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
TestAuthentication.Dialog.m = {b'Password, please:': None}
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||
|
||||
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
||||
@unittest2.skipIf(pam_found, "pam plugin already installed")
|
||||
@@ -285,12 +286,16 @@ class TestAuthentication(base.PyMySQLTestCase):
|
||||
c = pymysql.connect(user=TestAuthentication.osuser, **db)
|
||||
db['password'] = 'very bad guess at password'
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db)
|
||||
pymysql.connect(user=TestAuthentication.osuser,
|
||||
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
|
||||
**self.db)
|
||||
except pymysql.OperationalError as e:
|
||||
self.assertEqual(1045, e.args[0])
|
||||
# we had 'bad guess at password' work with pam. Well at least we get a permission denied here
|
||||
with self.assertRaises(pymysql.err.OperationalError):
|
||||
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db)
|
||||
pymysql.connect(user=TestAuthentication.osuser,
|
||||
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
|
||||
**self.db)
|
||||
if grants:
|
||||
# recreate the user
|
||||
cur.execute(grants)
|
||||
|
||||
@@ -61,6 +61,7 @@ class TestLoadLocal(base.PyMySQLTestCase):
|
||||
self.assertTrue("Incorrect integer value" in str(w[-1].message))
|
||||
finally:
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
c.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user