87 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			87 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import gc
 | 
						|
import json
 | 
						|
import os
 | 
						|
import re
 | 
						|
import warnings
 | 
						|
 | 
						|
import unittest2
 | 
						|
 | 
						|
import pymysql
 | 
						|
from .._compat import CPYTHON
 | 
						|
 | 
						|
 | 
						|
class PyMySQLTestCase(unittest2.TestCase):
    """Base test case that opens one connection per configured database.

    Subclasses get ``self.connections`` (one ``pymysql`` connection per entry
    in ``databases``) plus helpers for creating/dropping tables and forcing
    garbage collection.
    """

    # You can specify your test environment creating a file named
    #  "databases.json" or editing the `databases` variable below.
    fname = os.path.join(os.path.dirname(__file__), "databases.json")
    if os.path.exists(fname):
        with open(fname) as f:
            # expected to be a list of keyword-argument dicts for
            # pymysql.connect, mirroring the default below
            databases = json.load(f)
    else:
        databases = [
            {"host": "localhost", "user": "root",
             "passwd": "", "db": "test_pymysql", "use_unicode": True,
             'local_infile': True},
            {"host": "localhost", "user": "root", "passwd": "",
             "db": "test_pymysql2"}]

    def mysql_server_is(self, conn, version_tuple):
        """Return True if the given connection is on the version given or
        greater.

        The server version string is parsed as ``major.minor[.patch]``
        (any suffix such as ``-log`` is ignored); a missing patch number
        counts as 0.

        e.g.::

            if self.mysql_server_is(conn, (5, 6, 4)):
                # do something for MySQL 5.6.4 and above

        :raises ValueError: if the server version string cannot be parsed.
        """
        server_version = conn.get_server_info()
        # Patch component is optional so strings like "8.0" still parse.
        match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', server_version)
        if match is None:
            raise ValueError(
                "Cannot parse MySQL server version %r" % (server_version,))
        server_version_tuple = tuple(
            (int(dig) if dig is not None else 0)
            for dig in match.group(1, 2, 3)
        )
        return server_version_tuple >= version_tuple

    def setUp(self):
        """Open one connection for each entry in ``databases``."""
        self.connections = []
        # Register the cleanup *before* connecting: unittest runs cleanups
        # even when setUp raises, so connections opened before a failing
        # connect are still closed instead of leaking.
        self.addCleanup(self._teardown_connections)
        for params in self.databases:
            self.connections.append(pymysql.connect(**params))

    def _teardown_connections(self):
        """Close every connection opened by setUp."""
        for connection in self.connections:
            connection.close()

    def safe_create_table(self, connection, tablename, ddl, cleanup=True):
        """create a table.

        Ensures any existing version of that table is first dropped.

        Also adds a cleanup rule to drop the table after the test
        completes (unless ``cleanup`` is False).
        """
        self.drop_table(connection, tablename)
        cursor = connection.cursor()
        try:
            cursor.execute(ddl)
        finally:
            # close the cursor even if the DDL fails
            cursor.close()
        if cleanup:
            self.addCleanup(self.drop_table, connection, tablename)

    def drop_table(self, connection, tablename):
        """Drop ``tablename`` if it exists, suppressing Python warnings."""
        cursor = connection.cursor()
        try:
            with warnings.catch_warnings():
                # "drop table if exists" on a missing table may surface
                # as a warning; it is expected here, so silence it.
                warnings.simplefilter("ignore")
                cursor.execute("drop table if exists `%s`" % (tablename,))
        finally:
            cursor.close()

    def safe_gc_collect(self):
        """Ensure cycles are collected via gc.

        Runs additional times on non-CPython platforms.

        """
        gc.collect()
        if not CPYTHON:
            gc.collect()
 |