Files
deb-python-pymysql/pymysql/tests/test_connection.py
Daniel Black 2d1da09f07 Add basic connection plugin support
This extracts a plugin name from the server and exposes this as
connection.get_plugin_name().

The client side here adds capabilities PLUGIN_AUTH and
PLUGIN_AUTH_LENENC_CLIENT_DATA.

Because of this, the HandshakeResponse has been altered to use
the server_capabilities a bit more strictly.

The immediate upshot is that plugins like unix_socket are immediately
supported where they previously would have returned an error, because
without PLUGIN_AUTH being advertised on the client side, the server only
accepts mysql_native_password and mysql_old_password.

To support other plugins some more work is needed to implement these.
2015-06-28 15:43:57 +10:00

174 lines
5.3 KiB
Python

import datetime
import decimal
import pymysql
import time
import os
import copy
from pymysql.tests import base
class TestConnection(base.PyMySQLTestCase):
    """Connection-level integration tests.

    The tests use ``self.connections`` and ``self.databases``, which are
    not defined in this class and are therefore inherited from the base
    test case.
    """

    def test_utf8mb4(self):
        """This test requires MySQL >= 5.5"""
        arg = self.databases[0].copy()
        arg['charset'] = 'utf8mb4'
        conn = pymysql.connect(**arg)
        # Connecting successfully is the whole test; close the extra
        # connection so it does not leak past the test.
        conn.close()

    def test_largedata(self):
        """Large query and response (>=16MB)"""
        cur = self.connections[0].cursor()
        cur.execute("SELECT @@max_allowed_packet")
        if cur.fetchone()[0] < 16*1024*1024 + 10:
            # Report the test as skipped instead of silently passing when
            # the server cannot accept a packet of the required size.
            self.skipTest("Set max_allowed_packet to bigger than 17MB")
        t = 'a' * (16*1024*1024)
        cur.execute("SELECT '" + t + "'")
        # assertTrue rather than assertEqual: a failure must not try to
        # render a diff of two 16MB strings, and unlike a bare ``assert``
        # the check is not stripped under ``python -O``.
        self.assertTrue(cur.fetchone()[0] == t)

    def test_autocommit(self):
        """Autocommit state follows both SQL and ``autocommit()`` changes."""
        con = self.connections[0]
        self.assertFalse(con.get_autocommit())

        cur = con.cursor()
        cur.execute("SET AUTOCOMMIT=1")
        self.assertTrue(con.get_autocommit())

        con.autocommit(False)
        self.assertFalse(con.get_autocommit())
        cur.execute("SELECT @@AUTOCOMMIT")
        self.assertEqual(cur.fetchone()[0], 0)

    def test_select_db(self):
        """``select_db`` switches the database reported by the server."""
        con = self.connections[0]
        current_db = self.databases[0]['db']
        other_db = self.databases[1]['db']

        cur = con.cursor()
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], current_db)

        con.select_db(other_db)
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], other_db)

    def test_connection_gone_away(self):
        """
        http://dev.mysql.com/doc/refman/5.0/en/gone-away.html
        http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error
        """
        con = self.connections[0]
        cur = con.cursor()
        cur.execute("SET wait_timeout=1")
        # Sleep past the one second wait_timeout set above.
        time.sleep(2)
        with self.assertRaises(pymysql.OperationalError) as cm:
            cur.execute("SELECT 1+1")
        # The error occurs while reading, not writing, because of the
        # socket buffer, so CR_SERVER_LOST (2013) is accepted alongside
        # CR_SERVER_GONE_ERROR (2006).
        self.assertIn(cm.exception.args[0], (2006, 2013))

    def test_init_command(self):
        """A connection opened with a multi-statement ``init_command`` is usable."""
        conn = pymysql.connect(
            init_command='SELECT "bar"; SELECT "baz"',
            **self.databases[0]
        )
        try:
            c = conn.cursor()
            c.execute('select "foobar";')
            self.assertEqual(('foobar',), c.fetchone())
        finally:
            # Close the extra connection even when the assertion fails.
            conn.close()

    def test_plugin(self):
        """The default plugin name is exposed; unix_socket login works if possible."""
        con = self.connections[0]
        self.assertEqual('mysql_native_password', con.get_plugin_name())

        # attempt a unix socket test which is included in some versions
        # and doesn't require a client side handler
        user = os.environ.get('USER')
        if not user or self.databases[0]['host'] != 'localhost':
            # The mandatory part above has already been verified, so this
            # is a plain return rather than a skip.
            return

        cur = con.cursor()
        cur.execute("SHOW PLUGINS")
        unix_socket_row = (u'unix_socket', u'ACTIVE', u'AUTHENTICATION', u'auth_socket.so', u'GPL')
        found = any(r == unix_socket_row for r in cur)

        # needs plugin. lets install it.
        if not found:
            cur.execute("install soname 'auth_socket'")
        # Tear everything down in reverse order of setup, even when the
        # grant or the connection attempt fails.
        try:
            current_db = self.databases[0]['db']
            cur.execute(
                "GRANT ALL ON %s TO %s@localhost IDENTIFIED VIA unix_socket"
                % (current_db, user))
            try:
                db = copy.copy(self.databases[0])
                # Replace the configured user with the OS user; tolerate a
                # configuration that has no explicit 'user' entry.
                db.pop('user', None)
                c = pymysql.connect(user=user, **db)
                c.close()
            finally:
                cur.execute("DROP USER %s@localhost" % user)
        finally:
            if not found:
                cur.execute("uninstall soname 'auth_socket'")
# A custom type and function to escape it
class Foo(object):
    """Minimal custom type for exercising a user-supplied escape function."""

    # The text an escape function is expected to produce for an instance.
    value = "bar"
def escape_foo(x, d):
    """Return the ``value`` attribute of *x*; the mapping *d* is not used."""
    escaped = x.value
    return escaped
class TestEscape(base.PyMySQLTestCase):
    """Tests for ``escape`` on a connection, with default and custom mappings."""

    def test_escape_string(self):
        """Quotes are backslash-escaped, or doubled under NO_BACKSLASH_ESCAPES."""
        con = self.connections[0]
        cur = con.cursor()

        self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
        cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
        self.assertEqual(con.escape("foo'bar"), "'foo''bar'")

    def test_escape_builtin_encoders(self):
        """A datetime is escaped through the connection's own encoders."""
        con = self.connections[0]

        val = datetime.datetime(2012, 3, 4, 5, 6)
        self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")

    def test_escape_custom_object(self):
        """A custom type is escaped by the function mapped to it."""
        con = self.connections[0]

        mapping = {Foo: escape_foo}
        self.assertEqual(con.escape(Foo(), mapping), "bar")

    def test_escape_fallback_encoder(self):
        """A ``str`` subclass is escaped via the ``pymysql.text_type`` entry."""
        con = self.connections[0]

        class Custom(str):
            pass

        mapping = {pymysql.text_type: pymysql.escape_string}
        self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")

    def test_escape_no_default(self):
        """Escaping with an empty mapping raises ``TypeError``."""
        con = self.connections[0]
        self.assertRaises(TypeError, con.escape, 42, {})

    def test_escape_dict_value(self):
        """Values of a dict are escaped individually; the result is a dict."""
        con = self.connections[0]

        mapping = con.encoders.copy()
        mapping[Foo] = escape_foo
        self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})

    def test_escape_list_item(self):
        """Items of a list are escaped and wrapped in parentheses."""
        con = self.connections[0]

        mapping = con.encoders.copy()
        mapping[Foo] = escape_foo
        self.assertEqual(con.escape([Foo()], mapping), "(bar)")