diff --git a/eventlet/db_pool.py b/eventlet/db_pool.py index f8b437c..5769412 100644 --- a/eventlet/db_pool.py +++ b/eventlet/db_pool.py @@ -33,12 +33,15 @@ class DatabaseConnector(object): """\ @brief This is an object which will maintain a collection of database connection pools keyed on host,databasename""" - def __init__(self, module, credentials, min_size = 0, max_size = 4, *args, **kwargs): + def __init__(self, module, credentials, min_size = 0, max_size = 4, conn_pool=None, *args, **kwargs): """\ @brief constructor @param min_size the minimum size of a child pool. @param max_size the maximum size of a child pool.""" assert(module) + self._conn_pool_class = conn_pool + if self._conn_pool_class is None: + self._conn_pool_class = ConnectionPool self._module = module self._min_size = min_size self._max_size = max_size @@ -60,33 +63,35 @@ connection pools keyed on host,databasename""" new_kwargs['db'] = dbname new_kwargs['host'] = host new_kwargs.update(self.credentials_for(host)) - dbpool = ConnectionPool(self._module, self._min_size, self._max_size, *self._args, **new_kwargs) + dbpool = self._conn_pool_class(self._module, self._min_size, self._max_size, *self._args, **new_kwargs) self._databases[key] = dbpool return self._databases[key] - -class ConnectionPool(Pool): - """A pool which gives out saranwrapped database connections from a pool - """ - def __init__(self, module, min_size = 0, max_size = 4, *args, **kwargs): - assert(module) - self._module = module +class BaseConnectionPool(Pool): + # *TODO: we need to expire and close connections if they've been + # idle for a while, so that system-wide connection count doesn't + # monotonically increase forever + def __init__(self, db_module, min_size = 0, max_size = 4, *args, **kwargs): + assert(db_module) + self._db_module = db_module self._args = args self._kwargs = kwargs - super(ConnectionPool, self).__init__(min_size, max_size) + super(BaseConnectionPool, self).__init__(min_size, max_size) - def create(self): - return saranwrap.wrap(self._module).connect(*self._args, **self._kwargs) + def get(self): + # wrap the connection for easier use + conn = super(BaseConnectionPool, self).get() + return PooledConnectionWrapper(conn, self) def put(self, conn): # rollback any uncommitted changes, so that the next client - # has a clean slate. This also pokes the process to see if + # has a clean slate. This also pokes the connection to see if # it's dead or None try: conn.rollback() except AttributeError, e: - # this means it's already been destroyed + # this means it's already been destroyed, so we don't need to print anything conn = None except: # we don't care what the exception was, we just know the @@ -104,14 +109,32 @@ class ConnectionPool(Pool): conn = None if conn is not None: - super(ConnectionPool, self).put(conn) + super(BaseConnectionPool, self).put(conn) else: self.current_size -= 1 + - def get(self): - # wrap the connection for easier use - conn = super(ConnectionPool, self).get() - return PooledConnectionWrapper(conn, self) +class SaranwrappedConnectionPool(BaseConnectionPool): + """A pool which gives out saranwrapped database connections from a pool + """ + def create(self): + return saranwrap.wrap(self._db_module).connect(*self._args, **self._kwargs) + +class TpooledConnectionPool(BaseConnectionPool): + """A pool which gives out tpool.Proxy-based database connections from a pool. + """ + def create(self): + from eventlet import tpool + try: + # *FIX: this is a huge hack that will probably only work for MySQLdb + autowrap = (self._db_module.cursors.DictCursor,) + except: + autowrap = () + return tpool.Proxy(self._db_module.connect(*self._args, **self._kwargs), + autowrap=autowrap) + +# default connection pool is the tpool one +ConnectionPool = TpooledConnectionPool class GenericConnectionWrapper(object): @@ -188,4 +211,3 @@ class PooledConnectionWrapper(GenericConnectionWrapper): def __del__(self): self.close() - diff --git a/eventlet/db_pool_test.py b/eventlet/db_pool_test.py index 628beaa..499b723 100644 --- a/eventlet/db_pool_test.py +++ b/eventlet/db_pool_test.py @@ -79,9 +79,6 @@ class TestDBConnectionPool(DBTester): self.pool.put(self.connection) super(TestDBConnectionPool, self).tearDown() - def create_pool(self, max_items = 1): - return db_pool.ConnectionPool(self._dbmodule, 0, max_items, **self._auth) - def assert_cursor_works(self, cursor): cursor.execute("show full processlist") rows = cursor.fetchall() @@ -253,7 +250,17 @@ class TestDBConnectionPool(DBTester): results.sort() self.assertEqual([1, 2], results) -class TestMysqlConnectionPool(TestDBConnectionPool, unittest.TestCase): + +class TestTpoolConnectionPool(TestDBConnectionPool): + def create_pool(self, max_items = 1): + return db_pool.TpooledConnectionPool(self._dbmodule, 0, max_items, **self._auth) + + +class TestSaranwrapConnectionPool(TestDBConnectionPool): + def create_pool(self, max_items = 1): + return db_pool.SaranwrappedConnectionPool(self._dbmodule, 0, max_items, **self._auth) + +class TestMysqlConnectionPool(object): def setUp(self): import MySQLdb self._dbmodule = MySQLdb @@ -286,6 +293,12 @@ class TestMysqlConnectionPool(TestDBConnectionPool, unittest.TestCase): db.close() del db +class TestMysqlTpool(TestMysqlConnectionPool, TestTpoolConnectionPool, unittest.TestCase): + pass + +class TestMysqlSaranwrap(TestMysqlConnectionPool, TestSaranwrapConnectionPool, unittest.TestCase): + pass + if __name__ == '__main__': unittest.main()