196 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			196 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import datetime
 | |
| import decimal
 | |
| import pymysql
 | |
| import time
 | |
| from pymysql.tests import base
 | |
| 
 | |
| 
 | |
class TestConnection(base.PyMySQLTestCase):
    """Integration tests for ``pymysql.connect`` and ``Connection`` behavior.

    Connections a test opens itself (as opposed to ``self.connections``) are
    closed explicitly or via ``addCleanup`` so that a failing assertion does
    not leak server connections or sockets.
    """

    def test_utf8mb4(self):
        """This test requires MySQL >= 5.5"""
        arg = self.databases[0].copy()
        arg['charset'] = 'utf8mb4'
        conn = pymysql.connect(**arg)
        self.addCleanup(conn.close)

    def test_largedata(self):
        """Large query and response (>=16MB)"""
        cur = self.connections[0].cursor()
        cur.execute("SELECT @@max_allowed_packet")
        if cur.fetchone()[0] < 16*1024*1024 + 10:
            # Report the test as skipped instead of letting it pass silently.
            self.skipTest("Set max_allowed_packet to bigger than 17MB")
        t = 'a' * (16*1024*1024)
        cur.execute("SELECT '" + t + "'")
        self.assertEqual(cur.fetchone()[0], t)

    def test_autocommit(self):
        """``get_autocommit`` follows both ``SET AUTOCOMMIT`` and ``autocommit()``."""
        con = self.connections[0]
        self.assertFalse(con.get_autocommit())

        cur = con.cursor()
        cur.execute("SET AUTOCOMMIT=1")
        self.assertTrue(con.get_autocommit())

        con.autocommit(False)
        self.assertFalse(con.get_autocommit())
        cur.execute("SELECT @@AUTOCOMMIT")
        self.assertEqual(cur.fetchone()[0], 0)

    def test_select_db(self):
        """``select_db`` switches the connection's current database."""
        con = self.connections[0]
        current_db = self.databases[0]['db']
        other_db = self.databases[1]['db']

        cur = con.cursor()
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], current_db)

        con.select_db(other_db)
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], other_db)

    def test_connection_gone_away(self):
        """A query after the server-side ``wait_timeout`` expires raises
        ``OperationalError``.

        http://dev.mysql.com/doc/refman/5.0/en/gone-away.html
        http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error
        """
        con = self.connections[0]
        cur = con.cursor()
        cur.execute("SET wait_timeout=1")
        time.sleep(2)
        with self.assertRaises(pymysql.OperationalError) as cm:
            cur.execute("SELECT 1+1")
        # The error may occur while reading rather than writing, because the
        # request still fits in the socket buffer. So either 2006
        # (server gone) or 2013 (lost connection) is accepted.
        self.assertIn(cm.exception.args[0], (2006, 2013))

    def test_init_command(self):
        """A multi-statement ``init_command`` leaves the connection usable, and
        ``ping(reconnect=False)`` fails once the connection is closed."""
        conn = pymysql.connect(
            init_command='SELECT "bar"; SELECT "baz"',
            **self.databases[0]
        )
        try:
            c = conn.cursor()
            c.execute('select "foobar";')
            self.assertEqual(('foobar',), c.fetchone())
        finally:
            # Close even if the assertion above fails; the ping check below
            # needs a closed connection anyway.
            conn.close()
        with self.assertRaises(pymysql.err.Error):
            conn.ping(reconnect=False)

    def test_read_default_group(self):
        """Connecting with ``read_default_group='client'`` yields an open connection."""
        conn = pymysql.connect(
            read_default_group='client',
            **self.databases[0]
        )
        self.addCleanup(conn.close)
        self.assertTrue(conn.open)

    def test_context(self):
        """``with connection as cursor`` discards the insert when the body raises
        and keeps it when the body exits normally."""
        aborted = pymysql.connect(**self.databases[0])
        try:
            with self.assertRaises(ValueError):
                with aborted as cur:
                    cur.execute('create table test ( a int )')
                    aborted.begin()
                    cur.execute('insert into test values ((1))')
                    # Abort before any commit. The insert must not persist,
                    # but the table itself must (the SELECT below relies on it).
                    raise ValueError('pseudo abort')
        finally:
            aborted.close()

        c = pymysql.connect(**self.databases[0])
        self.addCleanup(c.close)
        with c as cur:
            cur.execute('select count(*) from test')
            self.assertEqual(0, cur.fetchone()[0])
            cur.execute('insert into test values ((1))')
        with c as cur:
            cur.execute('select count(*) from test')
            self.assertEqual(1, cur.fetchone()[0])
            cur.execute('drop table test')

    def test_set_charset(self):
        """``set_charset`` can be called on a fresh connection."""
        c = pymysql.connect(**self.databases[0])
        self.addCleanup(c.close)
        c.set_charset('utf8')
        # TODO validate setting here

    def test_defer_connect(self):
        """With ``defer_connect=True`` the connection stays closed until
        ``connect(sock)`` is handed an already-connected socket."""
        import socket
        for db in self.databases:
            d = db.copy()
            if 'unix_socket' in d:
                # Only create an AF_UNIX socket when one is configured; this
                # avoids leaking it (and failing where AF_UNIX is missing).
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                self.addCleanup(sock.close)
                sock.connect(d['unix_socket'])
            else:
                sock = socket.create_connection(
                    (d.get('host', 'localhost'), d.get('port', 3306)))
                self.addCleanup(sock.close)
            # The socket is supplied directly, so drop the addressing options.
            for k in ('unix_socket', 'host', 'port'):
                d.pop(k, None)

            c = pymysql.connect(defer_connect=True, **d)
            self.assertFalse(c.open)
            c.connect(sock)
            c.close()
 | |
| 
 | |
| 
 | |
| # A custom type and function to escape it
 | |
class Foo(object):
    """Minimal custom type used to exercise user-supplied escape mappings."""

    value = 'bar'
 | |
| 
 | |
| 
 | |
def escape_foo(x, d):
    """Encode *x* by returning its ``value`` attribute unchanged.

    *d* is accepted to match the ``(value, mapping)`` encoder signature
    and is ignored.
    """
    raw = x.value
    return raw
 | |
| 
 | |
| 
 | |
class TestEscape(base.PyMySQLTestCase):
    """Tests for ``Connection.escape`` with default and custom encoder mappings."""

    def test_escape_string(self):
        """Quote escaping switches to doubling under ``NO_BACKSLASH_ESCAPES``."""
        con = self.connections[0]
        cur = con.cursor()

        self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
        # added NO_AUTO_CREATE_USER as not including it in 5.7 generates warnings
        # NOTE(review): NO_AUTO_CREATE_USER was removed in MySQL 8.0, where this
        # SET is expected to be rejected. Confirm against the target server.
        cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'")
        self.assertEqual(con.escape("foo'bar"), "'foo''bar'")

    def test_escape_builtin_encoders(self):
        """The connection's own encoders format a ``datetime`` as a quoted literal."""
        con = self.connections[0]

        val = datetime.datetime(2012, 3, 4, 5, 6)
        self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")

    def test_escape_custom_object(self):
        """A custom type is encoded by the function registered for it."""
        con = self.connections[0]

        mapping = {Foo: escape_foo}
        self.assertEqual(con.escape(Foo(), mapping), "bar")

    def test_escape_fallback_encoder(self):
        """A subclass with no mapping entry is encoded by the encoder
        registered for its text base type."""
        con = self.connections[0]

        class Custom(str):
            pass

        mapping = {pymysql.text_type: pymysql.escape_string}
        self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")

    def test_escape_no_default(self):
        """With an empty mapping, escaping an ``int`` raises ``TypeError``."""
        con = self.connections[0]

        self.assertRaises(TypeError, con.escape, 42, {})

    def test_escape_dict_value(self):
        """Escaping a dict encodes each value and returns a dict."""
        con = self.connections[0]

        mapping = con.encoders.copy()
        mapping[Foo] = escape_foo
        self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})

    def test_escape_list_item(self):
        """Escaping a list encodes each item into a parenthesized string."""
        con = self.connections[0]

        mapping = con.encoders.copy()
        mapping[Foo] = escape_foo
        self.assertEqual(con.escape([Foo()], mapping), "(bar)")
 | 
