Extended tpool's autowrap functionality to use the name of the attribute, which is a lot cleaner than type-checking (which is still supported). Added psycopg tests to db_pool tests.
This commit is contained in:
		@@ -57,7 +57,7 @@ class BaseConnectionPool(Pool):
 | 
			
		||||
        possible expiration.
 | 
			
		||||
 | 
			
		||||
        If this function is called when a timer is already scheduled, it does
 | 
			
		||||
        nothing.
 | 
			
		||||
       nothing.
 | 
			
		||||
 | 
			
		||||
        If max_age or max_idle is 0, _schedule_expiration likewise does nothing.
 | 
			
		||||
        """
 | 
			
		||||
@@ -251,13 +251,8 @@ class TpooledConnectionPool(BaseConnectionPool):
 | 
			
		||||
        timeout = api.exc_after(connect_timeout, ConnectTimeout())
 | 
			
		||||
        try:
 | 
			
		||||
            from eventlet import tpool
 | 
			
		||||
            try:
 | 
			
		||||
                # *FIX: this is a huge hack that will probably only work for MySQLdb
 | 
			
		||||
                autowrap = (db_module.cursors.DictCursor,)
 | 
			
		||||
            except:
 | 
			
		||||
                autowrap = ()
 | 
			
		||||
            conn = tpool.execute(db_module.connect, *args, **kw)
 | 
			
		||||
            return tpool.Proxy(conn, autowrap=autowrap)
 | 
			
		||||
            return tpool.Proxy(conn, autowrap_names=('cursor',))
 | 
			
		||||
        finally:
 | 
			
		||||
            timeout.cancel()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -132,28 +132,42 @@ def proxy_call(autowrap, f, *args, **kwargs):
 | 
			
		||||
 | 
			
		||||
class Proxy(object):
 | 
			
		||||
    """
 | 
			
		||||
    A simple proxy-wrapper of any object, in order to forward every method
 | 
			
		||||
    invocation onto a thread in the native-thread pool.  A key restriction is
 | 
			
		||||
    that the object's methods cannot use Eventlet primitives without great care,
 | 
			
		||||
    since the Eventlet dispatcher runs on a different native thread.
 | 
			
		||||
    
 | 
			
		||||
    Construct the Proxy with the instance that you want proxied.  The optional 
 | 
			
		||||
    parameter *autowrap* is used when methods are called on the proxied object.  
 | 
			
		||||
    If a method on the proxied object returns something whose type is in 
 | 
			
		||||
    *autowrap*, then that object gets a Proxy wrapped around it, too.  An 
 | 
			
		||||
    example use case for this is ensuring that DB-API connection objects 
 | 
			
		||||
    return cursor objects that are also Proxy-wrapped.
 | 
			
		||||
    a simple proxy-wrapper of any object that comes with a
 | 
			
		||||
    methods-only interface, in order to forward every method
 | 
			
		||||
    invocation onto a thread in the native-thread pool.  A key
 | 
			
		||||
    restriction is that the object's methods should not switch
 | 
			
		||||
    greenlets or use Eventlet primitives, since they are in a
 | 
			
		||||
    different thread from the main hub, and therefore might behave
 | 
			
		||||
    unexpectedly.  This is for running native-threaded code
 | 
			
		||||
    only.
 | 
			
		||||
 | 
			
		||||
    It's common to want to have some of the attributes or return
 | 
			
		||||
    values also wrapped in Proxy objects (for example, database
 | 
			
		||||
    connection objects produce cursor objects which also should be
 | 
			
		||||
    wrapped in Proxy objects to remain nonblocking).  *autowrap*, if
 | 
			
		||||
    supplied, is a collection of types; if an attribute or return
 | 
			
		||||
    value matches one of those types (via isinstance), it will be
 | 
			
		||||
    wrapped in a Proxy.  *autowrap_names* is a collection
 | 
			
		||||
    of strings, which represent the names of attributes that should be
 | 
			
		||||
    wrapped in Proxy objects when accessed.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, obj,autowrap=()):
 | 
			
		||||
    def __init__(self, obj,autowrap=(), autowrap_names=()):
 | 
			
		||||
        self._obj = obj
 | 
			
		||||
        self._autowrap = autowrap
 | 
			
		||||
        self._autowrap_names = autowrap_names
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self,attr_name):
 | 
			
		||||
        f = getattr(self._obj,attr_name)
 | 
			
		||||
        if not callable(f):
 | 
			
		||||
            if (isinstance(f, self._autowrap) or
 | 
			
		||||
                attr_name in self._autowrap_names):
 | 
			
		||||
                return Proxy(f, self._autowrap)
 | 
			
		||||
            return f
 | 
			
		||||
        def doit(*args, **kwargs):
 | 
			
		||||
            return proxy_call(self._autowrap, f, *args, **kwargs)
 | 
			
		||||
            result = proxy_call(self._autowrap, f, *args, **kwargs)
 | 
			
		||||
            if attr_name in self._autowrap_names and not isinstance(result, Proxy):
 | 
			
		||||
                return Proxy(result)
 | 
			
		||||
            return result
 | 
			
		||||
        return doit
 | 
			
		||||
 | 
			
		||||
    # the following are a buncha methods that the python interpeter
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ from unittest import TestCase, main
 | 
			
		||||
from eventlet import api
 | 
			
		||||
from eventlet import event
 | 
			
		||||
from eventlet import db_pool
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
class DBTester(object):
 | 
			
		||||
    __test__ = False  # so that nose doesn't try to execute this directly
 | 
			
		||||
@@ -16,7 +17,7 @@ class DBTester(object):
 | 
			
		||||
        cursor.execute("""CREATE  TABLE gargleblatz 
 | 
			
		||||
        (
 | 
			
		||||
        a INTEGER
 | 
			
		||||
        ) ENGINE = InnoDB;""")
 | 
			
		||||
        );""")
 | 
			
		||||
        connection.commit()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        
 | 
			
		||||
@@ -25,7 +26,7 @@ class DBTester(object):
 | 
			
		||||
            self.connection.close()
 | 
			
		||||
        self.drop_db()
 | 
			
		||||
 | 
			
		||||
    def set_up_test_table(self, connection = None):
 | 
			
		||||
    def set_up_dummy_table(self, connection = None):
 | 
			
		||||
        close_connection = False
 | 
			
		||||
        if connection is None:
 | 
			
		||||
            close_connection = True
 | 
			
		||||
@@ -35,18 +36,7 @@ class DBTester(object):
 | 
			
		||||
                connection = self.connection
 | 
			
		||||
 | 
			
		||||
        cursor = connection.cursor()
 | 
			
		||||
        cursor.execute("""CREATE TEMPORARY TABLE test_table 
 | 
			
		||||
        (
 | 
			
		||||
        row_id INTEGER PRIMARY KEY AUTO_INCREMENT, 
 | 
			
		||||
        value_int INTEGER, 
 | 
			
		||||
        value_float FLOAT, 
 | 
			
		||||
        value_string VARCHAR(200), 
 | 
			
		||||
        value_uuid CHAR(36), 
 | 
			
		||||
        value_binary BLOB, 
 | 
			
		||||
        value_binary_string VARCHAR(200) BINARY, 
 | 
			
		||||
        value_enum ENUM('Y','N'), 
 | 
			
		||||
        created TIMESTAMP 
 | 
			
		||||
        ) ENGINE = InnoDB;""")
 | 
			
		||||
        cursor.execute(self.dummy_table_sql)
 | 
			
		||||
        connection.commit()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        if close_connection:
 | 
			
		||||
@@ -57,21 +47,21 @@ class Mock(object):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDBConnectionPool(DBTester):
 | 
			
		||||
class DBConnectionPool(DBTester):
 | 
			
		||||
    __test__ = False  # so that nose doesn't try to execute this directly
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super(TestDBConnectionPool, self).setUp()
 | 
			
		||||
        super(DBConnectionPool, self).setUp()
 | 
			
		||||
        self.pool = self.create_pool()
 | 
			
		||||
        self.connection = self.pool.get()
 | 
			
		||||
    
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        if self.connection:
 | 
			
		||||
            self.pool.put(self.connection)
 | 
			
		||||
        super(TestDBConnectionPool, self).tearDown()
 | 
			
		||||
        self.pool.clear()
 | 
			
		||||
        super(DBConnectionPool, self).tearDown()
 | 
			
		||||
 | 
			
		||||
    def assert_cursor_works(self, cursor):
 | 
			
		||||
        # TODO: this is pretty mysql-specific
 | 
			
		||||
        cursor.execute("show full processlist")
 | 
			
		||||
        cursor.execute("select 1")
 | 
			
		||||
        rows = cursor.fetchall()
 | 
			
		||||
        self.assert_(rows)
 | 
			
		||||
 | 
			
		||||
@@ -142,7 +132,7 @@ class TestDBConnectionPool(DBTester):
 | 
			
		||||
    def test_returns_immediately(self):
 | 
			
		||||
        self.pool = self.create_pool()
 | 
			
		||||
        conn = self.pool.get()
 | 
			
		||||
        self.set_up_test_table(conn)
 | 
			
		||||
        self.set_up_dummy_table(conn)
 | 
			
		||||
        self.fill_up_table(conn)
 | 
			
		||||
        curs = conn.cursor()
 | 
			
		||||
        results = []
 | 
			
		||||
@@ -163,7 +153,7 @@ class TestDBConnectionPool(DBTester):
 | 
			
		||||
    def test_connection_is_clean_after_put(self):
 | 
			
		||||
        self.pool = self.create_pool()
 | 
			
		||||
        conn = self.pool.get()
 | 
			
		||||
        self.set_up_test_table(conn)
 | 
			
		||||
        self.set_up_dummy_table(conn)
 | 
			
		||||
        curs = conn.cursor()
 | 
			
		||||
        for i in range(10):
 | 
			
		||||
            curs.execute('insert into test_table (value_int) values (%s)' % i)
 | 
			
		||||
@@ -175,9 +165,9 @@ class TestDBConnectionPool(DBTester):
 | 
			
		||||
        for i in range(10):
 | 
			
		||||
            curs2.execute('insert into test_table (value_int) values (%s)' % i)
 | 
			
		||||
        conn2.commit()
 | 
			
		||||
        rows = curs2.execute("select * from test_table")
 | 
			
		||||
        curs2.execute("select * from test_table")
 | 
			
		||||
        # we should have only inserted them once
 | 
			
		||||
        self.assertEqual(10, rows)
 | 
			
		||||
        self.assertEqual(10, curs2.rowcount)
 | 
			
		||||
 | 
			
		||||
    def test_visibility_from_other_connections(self):
 | 
			
		||||
        self.pool = self.create_pool(3)
 | 
			
		||||
@@ -186,26 +176,28 @@ class TestDBConnectionPool(DBTester):
 | 
			
		||||
        curs = conn.cursor()
 | 
			
		||||
        try:
 | 
			
		||||
            curs2 = conn2.cursor()
 | 
			
		||||
            rows2 = curs2.execute("insert into gargleblatz (a) values (%s)" % (314159))
 | 
			
		||||
            self.assertEqual(rows2, 1)
 | 
			
		||||
            curs2.execute("insert into gargleblatz (a) values (%s)" % (314159))
 | 
			
		||||
            self.assertEqual(curs2.rowcount, 1)
 | 
			
		||||
            conn2.commit()
 | 
			
		||||
            selection_query = "select * from gargleblatz"
 | 
			
		||||
            rows2 = curs2.execute(selection_query)
 | 
			
		||||
            self.assertEqual(rows2, 1)
 | 
			
		||||
            curs2.execute(selection_query)
 | 
			
		||||
            self.assertEqual(curs2.rowcount, 1)
 | 
			
		||||
            del curs2
 | 
			
		||||
            del conn2
 | 
			
		||||
            self.pool.put(conn2)
 | 
			
		||||
            # create a new connection, it should see the addition
 | 
			
		||||
            conn3 = self.pool.get()
 | 
			
		||||
            curs3 = conn3.cursor()
 | 
			
		||||
            rows3 = curs3.execute(selection_query)
 | 
			
		||||
            self.assertEqual(rows3, 1)
 | 
			
		||||
            curs3.execute(selection_query)
 | 
			
		||||
            self.assertEqual(curs3.rowcount, 1)
 | 
			
		||||
            # now, does the already-open connection see it?
 | 
			
		||||
            rows = curs.execute(selection_query)
 | 
			
		||||
            self.assertEqual(rows, 1)
 | 
			
		||||
            curs.execute(selection_query)
 | 
			
		||||
            self.assertEqual(curs.rowcount, 1)
 | 
			
		||||
            self.pool.put(conn3)
 | 
			
		||||
        finally:
 | 
			
		||||
            # clean up my litter
 | 
			
		||||
            curs.execute("delete from gargleblatz where a=314159")
 | 
			
		||||
            conn.commit()
 | 
			
		||||
            self.pool.put(conn)
 | 
			
		||||
        
 | 
			
		||||
    @skipped
 | 
			
		||||
    def test_two_simultaneous_connections(self):
 | 
			
		||||
@@ -213,11 +205,11 @@ class TestDBConnectionPool(DBTester):
 | 
			
		||||
        # way to do this
 | 
			
		||||
        self.pool = self.create_pool(2)
 | 
			
		||||
        conn = self.pool.get()
 | 
			
		||||
        self.set_up_test_table(conn)
 | 
			
		||||
        self.set_up_dummy_table(conn)
 | 
			
		||||
        self.fill_up_table(conn)
 | 
			
		||||
        curs = conn.cursor()
 | 
			
		||||
        conn2 = self.pool.get()
 | 
			
		||||
        self.set_up_test_table(conn2)
 | 
			
		||||
        self.set_up_dummy_table(conn2)
 | 
			
		||||
        self.fill_up_table(conn2)
 | 
			
		||||
        curs2 = conn2.cursor()
 | 
			
		||||
        results = []
 | 
			
		||||
@@ -359,12 +351,6 @@ class TestDBConnectionPool(DBTester):
 | 
			
		||||
        conn2.close()  # should not be added to the free items
 | 
			
		||||
        self.assertEquals(len(self.pool.free_items), 0)
 | 
			
		||||
 | 
			
		||||
    def test_connection_timeout(self):
 | 
			
		||||
        # use a nonexistent ip address -- this one is reserved by IANA
 | 
			
		||||
        self._auth['host'] = '192.0.2.1'
 | 
			
		||||
        pool = self.create_pool()
 | 
			
		||||
        self.assertRaises(db_pool.ConnectTimeout, pool.get)
 | 
			
		||||
 | 
			
		||||
    def test_waiters_get_woken(self):
 | 
			
		||||
        # verify that when there's someone waiting on an empty pool
 | 
			
		||||
        # and someone puts an immediately-closed connection back in
 | 
			
		||||
@@ -425,7 +411,7 @@ class RaisingDBModule(object):
 | 
			
		||||
        raise RuntimeError()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
class TestTpoolConnectionPool(TestDBConnectionPool):
 | 
			
		||||
class TpoolConnectionPool(DBConnectionPool):
 | 
			
		||||
    __test__ = False  # so that nose doesn't try to execute this directly
 | 
			
		||||
    def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout=0.5, module=None):
 | 
			
		||||
        if module is None:
 | 
			
		||||
@@ -440,16 +426,16 @@ class TestTpoolConnectionPool(TestDBConnectionPool):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        from eventlet import tpool
 | 
			
		||||
        tpool.QUIET = True
 | 
			
		||||
        super(TestTpoolConnectionPool, self).setUp()
 | 
			
		||||
        super(TpoolConnectionPool, self).setUp()
 | 
			
		||||
        
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        from eventlet import tpool
 | 
			
		||||
        tpool.QUIET = False
 | 
			
		||||
        tpool.killall()
 | 
			
		||||
        super(TestTpoolConnectionPool, self).tearDown()
 | 
			
		||||
        super(TpoolConnectionPool, self).tearDown()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestRawConnectionPool(TestDBConnectionPool):
 | 
			
		||||
class RawConnectionPool(DBConnectionPool):
 | 
			
		||||
    __test__ = False  # so that nose doesn't try to execute this directly
 | 
			
		||||
    def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout= 0.5, module=None):
 | 
			
		||||
        if module is None:
 | 
			
		||||
@@ -460,45 +446,69 @@ class TestRawConnectionPool(TestDBConnectionPool):
 | 
			
		||||
            connect_timeout=connect_timeout,
 | 
			
		||||
            **self._auth)
 | 
			
		||||
 | 
			
		||||
    def test_connection_timeout(self):
 | 
			
		||||
        pass # not gonna work for raw connections because they're blocking
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_auth():
 | 
			
		||||
    try:
 | 
			
		||||
        import simplejson
 | 
			
		||||
        import os.path
 | 
			
		||||
        auth_utf8 = simplejson.load(open(os.path.join(os.path.dirname(__file__), 'auth.json')))
 | 
			
		||||
        # have to convert unicode objects to str objects because mysqldb is dum
 | 
			
		||||
        return dict([(str(k), str(v))
 | 
			
		||||
                     for k, v in auth_utf8.items()])
 | 
			
		||||
    except (IOError, ImportError), e:
 | 
			
		||||
        return {'host': 'localhost','user': 'root','passwd': '','db': 'persist0'}
 | 
			
		||||
    """Looks in the local directory and in the user's home directory for a file named ".test_dbauth",
 | 
			
		||||
    which contains a json map of parameters to the connect function.
 | 
			
		||||
    """
 | 
			
		||||
    files = [os.path.join(os.path.dirname(__file__), '.test_dbauth'),
 | 
			
		||||
             os.path.join(os.path.expanduser('~'), '.test_dbauth')]
 | 
			
		||||
    for f in files:
 | 
			
		||||
        try:
 | 
			
		||||
            import simplejson
 | 
			
		||||
            auth_utf8 = simplejson.load(open(f))
 | 
			
		||||
            # have to convert unicode objects to str objects because mysqldb is dum
 | 
			
		||||
            # using a doubly-nested list comprehension because we know that the structure
 | 
			
		||||
            # of the structure is a two-level dict
 | 
			
		||||
            return dict([(str(modname), dict([(str(k), str(v))
 | 
			
		||||
                                       for k, v in connectargs.items()]))
 | 
			
		||||
                         for modname, connectargs in auth_utf8.items()])
 | 
			
		||||
        except (IOError, ImportError), e:
 | 
			
		||||
            pass
 | 
			
		||||
    return {'MySQLdb':{'host': 'localhost','user': 'root','passwd': '','db': 'persist0'},
 | 
			
		||||
            'psycopg2':{'database':'test', 'user':'test'}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mysql_requirement(_f):
 | 
			
		||||
    try:
 | 
			
		||||
        import MySQLdb
 | 
			
		||||
        try:
 | 
			
		||||
            MySQLdb.connect(**get_auth())
 | 
			
		||||
            auth = get_auth()['MySQLdb'].copy()
 | 
			
		||||
            auth.pop('db')
 | 
			
		||||
            MySQLdb.connect(**auth)
 | 
			
		||||
            return True
 | 
			
		||||
        except MySQLdb.OperationalError:
 | 
			
		||||
            print "Skipping mysql tests, error when connecting"
 | 
			
		||||
            import traceback
 | 
			
		||||
            traceback.print_exc()
 | 
			
		||||
            return False
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        print "Skipping mysql tests, MySQLdb not importable"
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
class TestMysqlConnectionPool(object):
 | 
			
		||||
    __test__ = True
 | 
			
		||||
class MysqlConnectionPool(object):
 | 
			
		||||
    dummy_table_sql = """CREATE TEMPORARY TABLE test_table 
 | 
			
		||||
        (
 | 
			
		||||
        row_id INTEGER PRIMARY KEY AUTO_INCREMENT,
 | 
			
		||||
        value_int INTEGER, 
 | 
			
		||||
        value_float FLOAT, 
 | 
			
		||||
        value_string VARCHAR(200), 
 | 
			
		||||
        value_uuid CHAR(36), 
 | 
			
		||||
        value_binary BLOB, 
 | 
			
		||||
        value_binary_string VARCHAR(200) BINARY, 
 | 
			
		||||
        value_enum ENUM('Y','N'), 
 | 
			
		||||
        created TIMESTAMP 
 | 
			
		||||
        ) ENGINE=InnoDB;"""
 | 
			
		||||
 | 
			
		||||
    @skip_unless(mysql_requirement)
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        import MySQLdb
 | 
			
		||||
        self._dbmodule = MySQLdb
 | 
			
		||||
        self._auth = get_auth()
 | 
			
		||||
        super(TestMysqlConnectionPool, self).setUp()
 | 
			
		||||
        self._auth = get_auth()['MySQLdb']
 | 
			
		||||
        super(MysqlConnectionPool, self).setUp()
 | 
			
		||||
        
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        pass
 | 
			
		||||
        super(MysqlConnectionPool, self).tearDown()
 | 
			
		||||
 | 
			
		||||
    def create_db(self):
 | 
			
		||||
        auth = self._auth.copy()
 | 
			
		||||
@@ -519,13 +529,80 @@ class TestMysqlConnectionPool(object):
 | 
			
		||||
        del db
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Test01MysqlTpool(TestMysqlConnectionPool, TestTpoolConnectionPool, TestCase):
 | 
			
		||||
    pass
 | 
			
		||||
class Test01MysqlTpool(MysqlConnectionPool, TpoolConnectionPool, TestCase):
 | 
			
		||||
    __test__ = True
 | 
			
		||||
 | 
			
		||||
class Test02MysqlRaw(TestMysqlConnectionPool, TestRawConnectionPool, TestCase):
 | 
			
		||||
    pass
 | 
			
		||||
class Test02MysqlRaw(MysqlConnectionPool, RawConnectionPool, TestCase):
 | 
			
		||||
    __test__ = True
 | 
			
		||||
 | 
			
		||||
def postgres_requirement(_f):
 | 
			
		||||
    try:
 | 
			
		||||
        import psycopg2
 | 
			
		||||
        try:
 | 
			
		||||
            auth = get_auth()['psycopg2'].copy()
 | 
			
		||||
            auth.pop('database')
 | 
			
		||||
            psycopg2.connect(**auth)
 | 
			
		||||
            return True
 | 
			
		||||
        except psycopg2.OperationalError:
 | 
			
		||||
            print "Skipping postgres tests, error when connecting"
 | 
			
		||||
            return False
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        print "Skipping postgres tests, psycopg2 not importable"
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Psycopg2ConnectionPool(object):
 | 
			
		||||
    dummy_table_sql = """CREATE TEMPORARY TABLE test_table 
 | 
			
		||||
        (
 | 
			
		||||
        row_id SERIAL PRIMARY KEY,
 | 
			
		||||
        value_int INTEGER, 
 | 
			
		||||
        value_float FLOAT, 
 | 
			
		||||
        value_string VARCHAR(200), 
 | 
			
		||||
        value_uuid CHAR(36), 
 | 
			
		||||
        value_binary BYTEA, 
 | 
			
		||||
        value_binary_string BYTEA,
 | 
			
		||||
        created TIMESTAMP 
 | 
			
		||||
        );"""
 | 
			
		||||
 | 
			
		||||
    @skip_unless(postgres_requirement)
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        import psycopg2
 | 
			
		||||
        self._dbmodule = psycopg2
 | 
			
		||||
        self._auth = get_auth()['psycopg2']
 | 
			
		||||
        super(Psycopg2ConnectionPool, self).setUp()
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        super(Psycopg2ConnectionPool, self).tearDown()
 | 
			
		||||
 | 
			
		||||
    def create_db(self):
 | 
			
		||||
        try:
 | 
			
		||||
            self.drop_db()
 | 
			
		||||
        except Exception:
 | 
			
		||||
            pass
 | 
			
		||||
        auth = self._auth.copy()
 | 
			
		||||
        dbname = auth.pop('database')
 | 
			
		||||
        conn = self._dbmodule.connect(**auth)
 | 
			
		||||
        conn.set_isolation_level(0)
 | 
			
		||||
        db = conn.cursor()
 | 
			
		||||
        db.execute("create database "+dbname)
 | 
			
		||||
        db.close()
 | 
			
		||||
        del db
 | 
			
		||||
 | 
			
		||||
    def drop_db(self):
 | 
			
		||||
        auth = self._auth.copy()
 | 
			
		||||
        dbname = auth.pop('database')
 | 
			
		||||
        conn = self._dbmodule.connect(**auth)
 | 
			
		||||
        conn.set_isolation_level(0)
 | 
			
		||||
        db = conn.cursor()
 | 
			
		||||
        db.execute("drop database "+self._auth['database'])
 | 
			
		||||
        db.close()
 | 
			
		||||
        del db
 | 
			
		||||
 | 
			
		||||
class Test01Psycopg2Tpool(Psycopg2ConnectionPool, TpoolConnectionPool, TestCase):
 | 
			
		||||
    __test__ = True
 | 
			
		||||
 | 
			
		||||
class Test02Psycopg2Raw(Psycopg2ConnectionPool, RawConnectionPool, TestCase):
 | 
			
		||||
    __test__ = True
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
 
 | 
			
		||||
@@ -25,6 +25,7 @@ import eventlet
 | 
			
		||||
one = 1
 | 
			
		||||
two = 2
 | 
			
		||||
three = 3
 | 
			
		||||
none = None
 | 
			
		||||
 | 
			
		||||
def noop():
 | 
			
		||||
    pass
 | 
			
		||||
@@ -164,6 +165,34 @@ class TestTpool(LimitedTestCase):
 | 
			
		||||
        tpool.killall()
 | 
			
		||||
        tpool.setup()
 | 
			
		||||
 | 
			
		||||
    @skip_with_pyevent
 | 
			
		||||
    def test_autowrap(self):
 | 
			
		||||
        x = tpool.Proxy({'a':1, 'b':2}, autowrap=(int,))
 | 
			
		||||
        self.assert_(isinstance(x.get('a'), tpool.Proxy))
 | 
			
		||||
        self.assert_(not isinstance(x.items(), tpool.Proxy))
 | 
			
		||||
        # attributes as well as callables
 | 
			
		||||
        from tests import tpool_test
 | 
			
		||||
        x = tpool.Proxy(tpool_test, autowrap=(int,))
 | 
			
		||||
        self.assert_(isinstance(x.one, tpool.Proxy))
 | 
			
		||||
        self.assert_(not isinstance(x.none, tpool.Proxy))
 | 
			
		||||
 | 
			
		||||
    @skip_with_pyevent
 | 
			
		||||
    def test_autowrap_names(self):
 | 
			
		||||
        x = tpool.Proxy({'a':1, 'b':2}, autowrap_names=('get',))
 | 
			
		||||
        self.assert_(isinstance(x.get('a'), tpool.Proxy))
 | 
			
		||||
        self.assert_(not isinstance(x.items(), tpool.Proxy))
 | 
			
		||||
        from tests import tpool_test
 | 
			
		||||
        x = tpool.Proxy(tpool_test, autowrap_names=('one',))
 | 
			
		||||
        self.assert_(isinstance(x.one, tpool.Proxy))
 | 
			
		||||
        self.assert_(not isinstance(x.two, tpool.Proxy))
 | 
			
		||||
 | 
			
		||||
    @skip_with_pyevent
 | 
			
		||||
    def test_autowrap_both(self):
 | 
			
		||||
        from tests import tpool_test
 | 
			
		||||
        x = tpool.Proxy(tpool_test, autowrap=(int,), autowrap_names=('one',))
 | 
			
		||||
        self.assert_(isinstance(x.one, tpool.Proxy))
 | 
			
		||||
        # violating the abstraction to check that we didn't double-wrap
 | 
			
		||||
        self.assert_(not isinstance(x._obj, tpool.Proxy))
 | 
			
		||||
 | 
			
		||||
class TpoolLongTests(LimitedTestCase):
 | 
			
		||||
    TEST_TIMEOUT=60
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user