Merge pull request #405 from methane/feature/refactor
refactoring auth plugin support
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
sudo: false
|
sudo: false
|
||||||
language: python
|
language: python
|
||||||
python: "3.4"
|
python: "3.5"
|
||||||
cache:
|
cache:
|
||||||
pip: true
|
directories:
|
||||||
|
- $HOME/.cache/pip
|
||||||
|
|
||||||
env:
|
env:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -10,6 +11,7 @@ env:
|
|||||||
- TOX_ENV=py27
|
- TOX_ENV=py27
|
||||||
- TOX_ENV=py33
|
- TOX_ENV=py33
|
||||||
- TOX_ENV=py34
|
- TOX_ENV=py34
|
||||||
|
- TOX_ENV=py35
|
||||||
- TOX_ENV=pypy
|
- TOX_ENV=pypy
|
||||||
- TOX_ENV=pypy3
|
- TOX_ENV=pypy3
|
||||||
|
|
||||||
@@ -36,7 +38,7 @@ matrix:
|
|||||||
sudo: required
|
sudo: required
|
||||||
- env:
|
- env:
|
||||||
- TOX_ENV=py34
|
- TOX_ENV=py34
|
||||||
- DB=5.6.26
|
- DB=5.6.28
|
||||||
addons:
|
addons:
|
||||||
apt:
|
apt:
|
||||||
packages:
|
packages:
|
||||||
|
|||||||
@@ -525,6 +525,7 @@ class Connection(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
socket = None
|
socket = None
|
||||||
|
_auth_plugin_name = ''
|
||||||
|
|
||||||
def __init__(self, host=None, user=None, password="",
|
def __init__(self, host=None, user=None, password="",
|
||||||
database=None, port=3306, unix_socket=None,
|
database=None, port=3306, unix_socket=None,
|
||||||
@@ -535,7 +536,7 @@ class Connection(object):
|
|||||||
compress=None, named_pipe=None, no_delay=None,
|
compress=None, named_pipe=None, no_delay=None,
|
||||||
autocommit=False, db=None, passwd=None, local_infile=False,
|
autocommit=False, db=None, passwd=None, local_infile=False,
|
||||||
max_allowed_packet=16*1024*1024, defer_connect=False,
|
max_allowed_packet=16*1024*1024, defer_connect=False,
|
||||||
plugin_map={}):
|
auth_plugin_map={}):
|
||||||
"""
|
"""
|
||||||
Establish a connection to the MySQL database. Accepts several
|
Establish a connection to the MySQL database. Accepts several
|
||||||
arguments:
|
arguments:
|
||||||
@@ -571,12 +572,11 @@ class Connection(object):
|
|||||||
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
|
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
|
||||||
defer_connect: Don't explicitly connect on contruction - wait for connect call.
|
defer_connect: Don't explicitly connect on contruction - wait for connect call.
|
||||||
(default: False)
|
(default: False)
|
||||||
plugin_map: Map of plugin names to a class that processes that plugin. The class
|
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
|
||||||
will take the Connection object as the argument to the constructor. The class
|
The class will take the Connection object as the argument to the constructor.
|
||||||
needs an authenticate method taking an authentication packet as an argument.
|
The class needs an authenticate method taking an authentication packet as
|
||||||
For the dialog plugin, a prompt(echo, prompt) method can be used (if no
|
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
|
||||||
authenticate method) for returning a string from the user.
|
(if no authenticate method) for returning a string from the user. (experimental)
|
||||||
|
|
||||||
db: Alias for database. (for compatibility to MySQLdb)
|
db: Alias for database. (for compatibility to MySQLdb)
|
||||||
passwd: Alias for password. (for compatibility to MySQLdb)
|
passwd: Alias for password. (for compatibility to MySQLdb)
|
||||||
"""
|
"""
|
||||||
@@ -672,7 +672,7 @@ class Connection(object):
|
|||||||
self.sql_mode = sql_mode
|
self.sql_mode = sql_mode
|
||||||
self.init_command = init_command
|
self.init_command = init_command
|
||||||
self.max_allowed_packet = max_allowed_packet
|
self.max_allowed_packet = max_allowed_packet
|
||||||
self.plugin_map = plugin_map
|
self._auth_plugin_map = auth_plugin_map
|
||||||
if defer_connect:
|
if defer_connect:
|
||||||
self.socket = None
|
self.socket = None
|
||||||
else:
|
else:
|
||||||
@@ -726,7 +726,6 @@ class Connection(object):
|
|||||||
def autocommit(self, value):
|
def autocommit(self, value):
|
||||||
self.autocommit_mode = bool(value)
|
self.autocommit_mode = bool(value)
|
||||||
current = self.get_autocommit()
|
current = self.get_autocommit()
|
||||||
self.next_packet = 1
|
|
||||||
if value != current:
|
if value != current:
|
||||||
self._send_autocommit_mode()
|
self._send_autocommit_mode()
|
||||||
|
|
||||||
@@ -816,7 +815,6 @@ class Connection(object):
|
|||||||
"You may not close previous cursor.")
|
"You may not close previous cursor.")
|
||||||
# if DEBUG:
|
# if DEBUG:
|
||||||
# print("DEBUG: sending query:", sql)
|
# print("DEBUG: sending query:", sql)
|
||||||
self.next_packet = 1
|
|
||||||
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
|
if isinstance(sql, text_type) and not (JYTHON or IRONPYTHON):
|
||||||
if PY2:
|
if PY2:
|
||||||
sql = sql.encode(self.encoding)
|
sql = sql.encode(self.encoding)
|
||||||
@@ -824,7 +822,6 @@ class Connection(object):
|
|||||||
sql = sql.encode(self.encoding, 'surrogateescape')
|
sql = sql.encode(self.encoding, 'surrogateescape')
|
||||||
self._execute_command(COMMAND.COM_QUERY, sql)
|
self._execute_command(COMMAND.COM_QUERY, sql)
|
||||||
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
|
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
|
||||||
self.next_packet = 1
|
|
||||||
return self._affected_rows
|
return self._affected_rows
|
||||||
|
|
||||||
def next_result(self, unbuffered=False):
|
def next_result(self, unbuffered=False):
|
||||||
@@ -892,7 +889,7 @@ class Connection(object):
|
|||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||||
self.socket = sock
|
self.socket = sock
|
||||||
self._rfile = _makefile(sock, 'rb')
|
self._rfile = _makefile(sock, 'rb')
|
||||||
self.next_packet = 0
|
self._next_seq_id = 0
|
||||||
|
|
||||||
self._get_server_information()
|
self._get_server_information()
|
||||||
self._request_authentication()
|
self._request_authentication()
|
||||||
@@ -933,16 +930,16 @@ class Connection(object):
|
|||||||
# So just reraise it.
|
# So just reraise it.
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def write_packet(self, data):
|
def write_packet(self, payload):
|
||||||
"""Writes an entire "mysql packet" in its entirety to the network
|
"""Writes an entire "mysql packet" in its entirety to the network
|
||||||
addings its length and sequence number. Intended for use by plugins
|
addings its length and sequence number.
|
||||||
only.
|
|
||||||
"""
|
"""
|
||||||
data = pack_int24(len(data)) + int2byte(self.next_packet) + data
|
# Internal note: when you build packet manualy and calls _write_bytes()
|
||||||
|
# directly, you should set self._next_seq_id properly.
|
||||||
|
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
|
||||||
if DEBUG: dump_packet(data)
|
if DEBUG: dump_packet(data)
|
||||||
|
|
||||||
self._write_bytes(data)
|
self._write_bytes(data)
|
||||||
self.next_packet = (self.next_packet + 1) % 256
|
self._next_seq_id = (self._next_seq_id + 1) % 256
|
||||||
|
|
||||||
def _read_packet(self, packet_type=MysqlPacket):
|
def _read_packet(self, packet_type=MysqlPacket):
|
||||||
"""Read an entire "mysql packet" in its entirety from the network
|
"""Read an entire "mysql packet" in its entirety from the network
|
||||||
@@ -952,8 +949,14 @@ class Connection(object):
|
|||||||
while True:
|
while True:
|
||||||
packet_header = self._read_bytes(4)
|
packet_header = self._read_bytes(4)
|
||||||
if DEBUG: dump_packet(packet_header)
|
if DEBUG: dump_packet(packet_header)
|
||||||
|
|
||||||
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
|
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
|
||||||
bytes_to_read = btrl + (btrh << 16)
|
bytes_to_read = btrl + (btrh << 16)
|
||||||
|
if packet_number != self._next_seq_id:
|
||||||
|
raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
|
||||||
|
(packet_number, self._next_seq_id))
|
||||||
|
self._next_seq_id = (self._next_seq_id + 1) % 256
|
||||||
|
|
||||||
recv_data = self._read_bytes(bytes_to_read)
|
recv_data = self._read_bytes(bytes_to_read)
|
||||||
if DEBUG: dump_packet(recv_data)
|
if DEBUG: dump_packet(recv_data)
|
||||||
buff += recv_data
|
buff += recv_data
|
||||||
@@ -962,13 +965,7 @@ class Connection(object):
|
|||||||
continue
|
continue
|
||||||
if bytes_to_read < MAX_PACKET_LEN:
|
if bytes_to_read < MAX_PACKET_LEN:
|
||||||
break
|
break
|
||||||
if packet_number != self.next_packet:
|
|
||||||
pass
|
|
||||||
#TODO: check sequence id
|
|
||||||
#raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
|
|
||||||
# (packet_number, self.next_packet))
|
|
||||||
|
|
||||||
self.next_packet = (packet_number + 1) % 256
|
|
||||||
packet = packet_type(buff, self.encoding)
|
packet = packet_type(buff, self.encoding)
|
||||||
packet.check_error()
|
packet.check_error()
|
||||||
return packet
|
return packet
|
||||||
@@ -1027,33 +1024,32 @@ class Connection(object):
|
|||||||
if self._result is not None and self._result.unbuffered_active:
|
if self._result is not None and self._result.unbuffered_active:
|
||||||
warnings.warn("Previous unbuffered result was left incomplete")
|
warnings.warn("Previous unbuffered result was left incomplete")
|
||||||
self._result._finish_unbuffered_query()
|
self._result._finish_unbuffered_query()
|
||||||
|
self._result = None
|
||||||
|
|
||||||
if isinstance(sql, text_type):
|
if isinstance(sql, text_type):
|
||||||
sql = sql.encode(self.encoding)
|
sql = sql.encode(self.encoding)
|
||||||
|
|
||||||
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
|
# +1 is for command
|
||||||
|
chunk_size = min(self.max_allowed_packet, len(sql) + 1)
|
||||||
|
|
||||||
|
# tiny optimization: build first packet manually instead of
|
||||||
|
# calling self..write_packet()
|
||||||
prelude = struct.pack('<iB', chunk_size, command)
|
prelude = struct.pack('<iB', chunk_size, command)
|
||||||
self._write_bytes(prelude + sql[:chunk_size-1])
|
packet = prelude + sql[:chunk_size-1]
|
||||||
if DEBUG: dump_packet(prelude + sql)
|
self._write_bytes(packet)
|
||||||
|
if DEBUG: dump_packet(packet)
|
||||||
|
self._next_seq_id = 1
|
||||||
|
|
||||||
self.next_packet = 1
|
|
||||||
if chunk_size < self.max_allowed_packet:
|
if chunk_size < self.max_allowed_packet:
|
||||||
return
|
return
|
||||||
|
|
||||||
seq_id = 1
|
|
||||||
sql = sql[chunk_size-1:]
|
sql = sql[chunk_size-1:]
|
||||||
while True:
|
while True:
|
||||||
chunk_size = min(self.max_allowed_packet, len(sql))
|
chunk_size = min(self.max_allowed_packet, len(sql))
|
||||||
prelude = struct.pack('<i', chunk_size)[:3]
|
self.write_packet(sql[:chunk_size])
|
||||||
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
|
|
||||||
self._write_bytes(data)
|
|
||||||
if DEBUG: dump_packet(data)
|
|
||||||
sql = sql[chunk_size:]
|
sql = sql[chunk_size:]
|
||||||
if not sql and chunk_size < self.max_allowed_packet:
|
if not sql and chunk_size < self.max_allowed_packet:
|
||||||
break
|
break
|
||||||
seq_id += 1
|
|
||||||
self.next_packet = seq_id%256
|
|
||||||
|
|
||||||
def _request_authentication(self):
|
def _request_authentication(self):
|
||||||
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||||
@@ -1078,17 +1074,14 @@ class Connection(object):
|
|||||||
|
|
||||||
data = data_init + self.user + b'\0'
|
data = data_init + self.user + b'\0'
|
||||||
|
|
||||||
authresp = ''
|
authresp = b''
|
||||||
if self.plugin_name == 'mysql_native_password':
|
if self._auth_plugin_name == 'mysql_native_password':
|
||||||
authresp = _scramble(self.password.encode('latin1'), self.salt)
|
authresp = _scramble(self.password.encode('latin1'), self.salt)
|
||||||
|
|
||||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
|
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
|
||||||
data += lenenc_int(len(authresp))
|
data += lenenc_int(len(authresp)) + authresp
|
||||||
data += authresp
|
|
||||||
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
|
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
|
||||||
length = len(authresp)
|
data += struct.pack('B', len(authresp)) + authresp
|
||||||
data += struct.pack('B', length)
|
|
||||||
data += authresp
|
|
||||||
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
|
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
|
||||||
data += authresp + b'\0'
|
data += authresp + b'\0'
|
||||||
|
|
||||||
@@ -1098,15 +1091,16 @@ class Connection(object):
|
|||||||
data += self.db + b'\0'
|
data += self.db + b'\0'
|
||||||
|
|
||||||
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
|
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
|
||||||
data += self.plugin_name.encode('latin1') + b'\0'
|
name = self._auth_plugin_name
|
||||||
|
if isinstance(name, text_type):
|
||||||
|
name = name.encode('ascii')
|
||||||
|
data += name + b'\0'
|
||||||
|
|
||||||
self.write_packet(data)
|
self.write_packet(data)
|
||||||
|
|
||||||
auth_packet = self._read_packet()
|
auth_packet = self._read_packet()
|
||||||
|
|
||||||
# if authentication method isn't accepted the first byte
|
# if authentication method isn't accepted the first byte
|
||||||
# will have the octet 254
|
# will have the octet 254
|
||||||
|
|
||||||
if auth_packet.is_auth_switch_request():
|
if auth_packet.is_auth_switch_request():
|
||||||
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
|
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
|
||||||
auth_packet.read_uint8() # 0xfe packet identifier
|
auth_packet.read_uint8() # 0xfe packet identifier
|
||||||
@@ -1119,8 +1113,13 @@ class Connection(object):
|
|||||||
self.write_packet(data)
|
self.write_packet(data)
|
||||||
auth_packet = self._read_packet()
|
auth_packet = self._read_packet()
|
||||||
|
|
||||||
|
#TODO: ok packet or error packet?
|
||||||
|
|
||||||
|
|
||||||
def _process_auth(self, plugin_name, auth_packet):
|
def _process_auth(self, plugin_name, auth_packet):
|
||||||
plugin_class = self.plugin_map.get(plugin_name)
|
plugin_class = self._auth_plugin_map.get(plugin_name)
|
||||||
|
if not plugin_class:
|
||||||
|
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
|
||||||
if plugin_class:
|
if plugin_class:
|
||||||
try:
|
try:
|
||||||
handler = plugin_class(self)
|
handler = plugin_class(self)
|
||||||
@@ -1246,18 +1245,13 @@ class Connection(object):
|
|||||||
server_end = data.find(b'\0', i)
|
server_end = data.find(b'\0', i)
|
||||||
if server_end < 0: # pragma: no cover - very specific upstream bug
|
if server_end < 0: # pragma: no cover - very specific upstream bug
|
||||||
# not found \0 and last field so take it all
|
# not found \0 and last field so take it all
|
||||||
self.plugin_name = data[i:].decode('latin1')
|
self._auth_plugin_name = data[i:].decode('latin1')
|
||||||
else:
|
else:
|
||||||
self.plugin_name = data[i:server_end].decode('latin1')
|
self._auth_plugin_name = data[i:server_end].decode('latin1')
|
||||||
else: # pragma: no cover - not testing against any plugin uncapable servers
|
|
||||||
self.plugin_name = ''
|
|
||||||
|
|
||||||
def get_server_info(self):
|
def get_server_info(self):
|
||||||
return self.server_version
|
return self.server_version
|
||||||
|
|
||||||
def get_plugin_name(self):
|
|
||||||
return self.plugin_name
|
|
||||||
|
|
||||||
Warning = err.Warning
|
Warning = err.Warning
|
||||||
Error = err.Error
|
Error = err.Error
|
||||||
InterfaceError = err.InterfaceError
|
InterfaceError = err.InterfaceError
|
||||||
@@ -1331,7 +1325,11 @@ class MySQLResult(object):
|
|||||||
def _read_load_local_packet(self, first_packet):
|
def _read_load_local_packet(self, first_packet):
|
||||||
load_packet = LoadLocalPacketWrapper(first_packet)
|
load_packet = LoadLocalPacketWrapper(first_packet)
|
||||||
sender = LoadLocalFile(load_packet.filename, self.connection)
|
sender = LoadLocalFile(load_packet.filename, self.connection)
|
||||||
|
try:
|
||||||
sender.send_data()
|
sender.send_data()
|
||||||
|
except:
|
||||||
|
self.connection._read_packet() # skip ok packet
|
||||||
|
raise
|
||||||
|
|
||||||
ok_packet = self.connection._read_packet()
|
ok_packet = self.connection._read_packet()
|
||||||
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
|
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
|
||||||
@@ -1448,27 +1446,20 @@ class LoadLocalFile(object):
|
|||||||
"""Send data packets from the local file to the server"""
|
"""Send data packets from the local file to the server"""
|
||||||
if not self.connection.socket:
|
if not self.connection.socket:
|
||||||
raise err.InterfaceError("(0, '')")
|
raise err.InterfaceError("(0, '')")
|
||||||
|
conn = self.connection
|
||||||
|
|
||||||
# sequence id is 2 as we already sent a query packet
|
|
||||||
seq_id = 2
|
|
||||||
try:
|
try:
|
||||||
with open(self.filename, 'rb') as open_file:
|
with open(self.filename, 'rb') as open_file:
|
||||||
chunk_size = self.connection.max_allowed_packet
|
chunk_size = conn.max_allowed_packet
|
||||||
packet = b""
|
packet = b""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
chunk = open_file.read(chunk_size)
|
chunk = open_file.read(chunk_size)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
|
conn.write_packet(chunk)
|
||||||
format_str = '!{0}s'.format(len(chunk))
|
|
||||||
packet += struct.pack(format_str, chunk)
|
|
||||||
self.connection._write_bytes(packet)
|
|
||||||
seq_id = (seq_id + 1) % 256
|
|
||||||
except IOError:
|
except IOError:
|
||||||
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
||||||
finally:
|
finally:
|
||||||
# send the empty packet to signify we are done sending data
|
# send the empty packet to signify we are done sending data
|
||||||
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
conn.write_packet(b'')
|
||||||
self.connection._write_bytes(packet)
|
|
||||||
self.next_packet = (seq_id + 1) % 256
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ class TempUser:
|
|||||||
if self._created:
|
if self._created:
|
||||||
self._c.execute("DROP USER %s" % self._user)
|
self._c.execute("DROP USER %s" % self._user)
|
||||||
|
|
||||||
|
|
||||||
class TestAuthentication(base.PyMySQLTestCase):
|
class TestAuthentication(base.PyMySQLTestCase):
|
||||||
|
|
||||||
socket_auth = False
|
socket_auth = False
|
||||||
@@ -95,7 +96,7 @@ class TestAuthentication(base.PyMySQLTestCase):
|
|||||||
|
|
||||||
def test_plugin(self):
|
def test_plugin(self):
|
||||||
# Bit of an assumption that the current user is a native password
|
# Bit of an assumption that the current user is a native password
|
||||||
self.assertEqual('mysql_native_password', self.connections[0].get_plugin_name())
|
self.assertEqual('mysql_native_password', self.connections[0]._auth_plugin_name)
|
||||||
|
|
||||||
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
||||||
@unittest2.skipIf(socket_found, "socket plugin already installed")
|
@unittest2.skipIf(socket_found, "socket plugin already installed")
|
||||||
@@ -198,7 +199,7 @@ class TestAuthentication(base.PyMySQLTestCase):
|
|||||||
self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
|
self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user='pymysql_2q', **self.db)
|
pymysql.connect(user='pymysql_2q', **self.db)
|
||||||
pymysql.connect(user='pymysql_2q', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||||
|
|
||||||
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
||||||
@unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
|
@unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
|
||||||
@@ -225,21 +226,21 @@ class TestAuthentication(base.PyMySQLTestCase):
|
|||||||
TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all
|
TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all
|
||||||
with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
|
with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
|
||||||
self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
|
self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': object}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)
|
||||||
|
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
|
||||||
TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
|
TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||||
TestAuthentication.Dialog.m = {b'Password, please:': None}
|
TestAuthentication.Dialog.m = {b'Password, please:': None}
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user='pymysql_3a', plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
|
||||||
|
|
||||||
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
@unittest2.skipUnless(socket_auth, "connection to unix_socket required")
|
||||||
@unittest2.skipIf(pam_found, "pam plugin already installed")
|
@unittest2.skipIf(pam_found, "pam plugin already installed")
|
||||||
@@ -285,12 +286,16 @@ class TestAuthentication(base.PyMySQLTestCase):
|
|||||||
c = pymysql.connect(user=TestAuthentication.osuser, **db)
|
c = pymysql.connect(user=TestAuthentication.osuser, **db)
|
||||||
db['password'] = 'very bad guess at password'
|
db['password'] = 'very bad guess at password'
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db)
|
pymysql.connect(user=TestAuthentication.osuser,
|
||||||
|
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
|
||||||
|
**self.db)
|
||||||
except pymysql.OperationalError as e:
|
except pymysql.OperationalError as e:
|
||||||
self.assertEqual(1045, e.args[0])
|
self.assertEqual(1045, e.args[0])
|
||||||
# we had 'bad guess at password' work with pam. Well at least we get a permission denied here
|
# we had 'bad guess at password' work with pam. Well at least we get a permission denied here
|
||||||
with self.assertRaises(pymysql.err.OperationalError):
|
with self.assertRaises(pymysql.err.OperationalError):
|
||||||
pymysql.connect(user=TestAuthentication.osuser, plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, **self.db)
|
pymysql.connect(user=TestAuthentication.osuser,
|
||||||
|
auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
|
||||||
|
**self.db)
|
||||||
if grants:
|
if grants:
|
||||||
# recreate the user
|
# recreate the user
|
||||||
cur.execute(grants)
|
cur.execute(grants)
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ class TestLoadLocal(base.PyMySQLTestCase):
|
|||||||
self.assertTrue("Incorrect integer value" in str(w[-1].message))
|
self.assertTrue("Incorrect integer value" in str(w[-1].message))
|
||||||
finally:
|
finally:
|
||||||
c.execute("DROP TABLE test_load_local")
|
c.execute("DROP TABLE test_load_local")
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user