Extended tpool's autowrap functionality to use the name of the attribute, which is a lot cleaner than type-checking (which is still supported). Added psycopg tests to db_pool tests.

This commit is contained in:
Ryan Williams
2010-02-06 20:27:49 -08:00
parent 363653c9c5
commit ebff4177d2
4 changed files with 201 additions and 86 deletions

View File

@@ -57,7 +57,7 @@ class BaseConnectionPool(Pool):
possible expiration. possible expiration.
If this function is called when a timer is already scheduled, it does If this function is called when a timer is already scheduled, it does
nothing. nothing.
If max_age or max_idle is 0, _schedule_expiration likewise does nothing. If max_age or max_idle is 0, _schedule_expiration likewise does nothing.
""" """
@@ -251,13 +251,8 @@ class TpooledConnectionPool(BaseConnectionPool):
timeout = api.exc_after(connect_timeout, ConnectTimeout()) timeout = api.exc_after(connect_timeout, ConnectTimeout())
try: try:
from eventlet import tpool from eventlet import tpool
try:
# *FIX: this is a huge hack that will probably only work for MySQLdb
autowrap = (db_module.cursors.DictCursor,)
except:
autowrap = ()
conn = tpool.execute(db_module.connect, *args, **kw) conn = tpool.execute(db_module.connect, *args, **kw)
return tpool.Proxy(conn, autowrap=autowrap) return tpool.Proxy(conn, autowrap_names=('cursor',))
finally: finally:
timeout.cancel() timeout.cancel()

View File

@@ -132,28 +132,42 @@ def proxy_call(autowrap, f, *args, **kwargs):
class Proxy(object): class Proxy(object):
""" """
A simple proxy-wrapper of any object, in order to forward every method a simple proxy-wrapper of any object that comes with a
invocation onto a thread in the native-thread pool. A key restriction is methods-only interface, in order to forward every method
that the object's methods cannot use Eventlet primitives without great care, invocation onto a thread in the native-thread pool. A key
since the Eventlet dispatcher runs on a different native thread. restriction is that the object's methods should not switch
greenlets or use Eventlet primitives, since they are in a
Construct the Proxy with the instance that you want proxied. The optional different thread from the main hub, and therefore might behave
parameter *autowrap* is used when methods are called on the proxied object. unexpectedly. This is for running native-threaded code
If a method on the proxied object returns something whose type is in only.
*autowrap*, then that object gets a Proxy wrapped around it, too. An
example use case for this is ensuring that DB-API connection objects It's common to want to have some of the attributes or return
return cursor objects that are also Proxy-wrapped. values also wrapped in Proxy objects (for example, database
connection objects produce cursor objects which also should be
wrapped in Proxy objects to remain nonblocking). *autowrap*, if
supplied, is a collection of types; if an attribute or return
value matches one of those types (via isinstance), it will be
wrapped in a Proxy. *autowrap_names* is a collection
of strings, which represent the names of attributes that should be
wrapped in Proxy objects when accessed.
""" """
def __init__(self, obj,autowrap=()): def __init__(self, obj,autowrap=(), autowrap_names=()):
self._obj = obj self._obj = obj
self._autowrap = autowrap self._autowrap = autowrap
self._autowrap_names = autowrap_names
def __getattr__(self,attr_name): def __getattr__(self,attr_name):
f = getattr(self._obj,attr_name) f = getattr(self._obj,attr_name)
if not callable(f): if not callable(f):
if (isinstance(f, self._autowrap) or
attr_name in self._autowrap_names):
return Proxy(f, self._autowrap)
return f return f
def doit(*args, **kwargs): def doit(*args, **kwargs):
return proxy_call(self._autowrap, f, *args, **kwargs) result = proxy_call(self._autowrap, f, *args, **kwargs)
if attr_name in self._autowrap_names and not isinstance(result, Proxy):
return Proxy(result)
return result
return doit return doit
# the following are a buncha methods that the python interpeter # the following are a buncha methods that the python interpeter

View File

@@ -5,6 +5,7 @@ from unittest import TestCase, main
from eventlet import api from eventlet import api
from eventlet import event from eventlet import event
from eventlet import db_pool from eventlet import db_pool
import os
class DBTester(object): class DBTester(object):
__test__ = False # so that nose doesn't try to execute this directly __test__ = False # so that nose doesn't try to execute this directly
@@ -16,7 +17,7 @@ class DBTester(object):
cursor.execute("""CREATE TABLE gargleblatz cursor.execute("""CREATE TABLE gargleblatz
( (
a INTEGER a INTEGER
) ENGINE = InnoDB;""") );""")
connection.commit() connection.commit()
cursor.close() cursor.close()
@@ -25,7 +26,7 @@ class DBTester(object):
self.connection.close() self.connection.close()
self.drop_db() self.drop_db()
def set_up_test_table(self, connection = None): def set_up_dummy_table(self, connection = None):
close_connection = False close_connection = False
if connection is None: if connection is None:
close_connection = True close_connection = True
@@ -35,18 +36,7 @@ class DBTester(object):
connection = self.connection connection = self.connection
cursor = connection.cursor() cursor = connection.cursor()
cursor.execute("""CREATE TEMPORARY TABLE test_table cursor.execute(self.dummy_table_sql)
(
row_id INTEGER PRIMARY KEY AUTO_INCREMENT,
value_int INTEGER,
value_float FLOAT,
value_string VARCHAR(200),
value_uuid CHAR(36),
value_binary BLOB,
value_binary_string VARCHAR(200) BINARY,
value_enum ENUM('Y','N'),
created TIMESTAMP
) ENGINE = InnoDB;""")
connection.commit() connection.commit()
cursor.close() cursor.close()
if close_connection: if close_connection:
@@ -57,21 +47,21 @@ class Mock(object):
pass pass
class TestDBConnectionPool(DBTester): class DBConnectionPool(DBTester):
__test__ = False # so that nose doesn't try to execute this directly __test__ = False # so that nose doesn't try to execute this directly
def setUp(self): def setUp(self):
super(TestDBConnectionPool, self).setUp() super(DBConnectionPool, self).setUp()
self.pool = self.create_pool() self.pool = self.create_pool()
self.connection = self.pool.get() self.connection = self.pool.get()
def tearDown(self): def tearDown(self):
if self.connection: if self.connection:
self.pool.put(self.connection) self.pool.put(self.connection)
super(TestDBConnectionPool, self).tearDown() self.pool.clear()
super(DBConnectionPool, self).tearDown()
def assert_cursor_works(self, cursor): def assert_cursor_works(self, cursor):
# TODO: this is pretty mysql-specific cursor.execute("select 1")
cursor.execute("show full processlist")
rows = cursor.fetchall() rows = cursor.fetchall()
self.assert_(rows) self.assert_(rows)
@@ -142,7 +132,7 @@ class TestDBConnectionPool(DBTester):
def test_returns_immediately(self): def test_returns_immediately(self):
self.pool = self.create_pool() self.pool = self.create_pool()
conn = self.pool.get() conn = self.pool.get()
self.set_up_test_table(conn) self.set_up_dummy_table(conn)
self.fill_up_table(conn) self.fill_up_table(conn)
curs = conn.cursor() curs = conn.cursor()
results = [] results = []
@@ -163,7 +153,7 @@ class TestDBConnectionPool(DBTester):
def test_connection_is_clean_after_put(self): def test_connection_is_clean_after_put(self):
self.pool = self.create_pool() self.pool = self.create_pool()
conn = self.pool.get() conn = self.pool.get()
self.set_up_test_table(conn) self.set_up_dummy_table(conn)
curs = conn.cursor() curs = conn.cursor()
for i in range(10): for i in range(10):
curs.execute('insert into test_table (value_int) values (%s)' % i) curs.execute('insert into test_table (value_int) values (%s)' % i)
@@ -175,9 +165,9 @@ class TestDBConnectionPool(DBTester):
for i in range(10): for i in range(10):
curs2.execute('insert into test_table (value_int) values (%s)' % i) curs2.execute('insert into test_table (value_int) values (%s)' % i)
conn2.commit() conn2.commit()
rows = curs2.execute("select * from test_table") curs2.execute("select * from test_table")
# we should have only inserted them once # we should have only inserted them once
self.assertEqual(10, rows) self.assertEqual(10, curs2.rowcount)
def test_visibility_from_other_connections(self): def test_visibility_from_other_connections(self):
self.pool = self.create_pool(3) self.pool = self.create_pool(3)
@@ -186,26 +176,28 @@ class TestDBConnectionPool(DBTester):
curs = conn.cursor() curs = conn.cursor()
try: try:
curs2 = conn2.cursor() curs2 = conn2.cursor()
rows2 = curs2.execute("insert into gargleblatz (a) values (%s)" % (314159)) curs2.execute("insert into gargleblatz (a) values (%s)" % (314159))
self.assertEqual(rows2, 1) self.assertEqual(curs2.rowcount, 1)
conn2.commit() conn2.commit()
selection_query = "select * from gargleblatz" selection_query = "select * from gargleblatz"
rows2 = curs2.execute(selection_query) curs2.execute(selection_query)
self.assertEqual(rows2, 1) self.assertEqual(curs2.rowcount, 1)
del curs2 del curs2
del conn2 self.pool.put(conn2)
# create a new connection, it should see the addition # create a new connection, it should see the addition
conn3 = self.pool.get() conn3 = self.pool.get()
curs3 = conn3.cursor() curs3 = conn3.cursor()
rows3 = curs3.execute(selection_query) curs3.execute(selection_query)
self.assertEqual(rows3, 1) self.assertEqual(curs3.rowcount, 1)
# now, does the already-open connection see it? # now, does the already-open connection see it?
rows = curs.execute(selection_query) curs.execute(selection_query)
self.assertEqual(rows, 1) self.assertEqual(curs.rowcount, 1)
self.pool.put(conn3)
finally: finally:
# clean up my litter # clean up my litter
curs.execute("delete from gargleblatz where a=314159") curs.execute("delete from gargleblatz where a=314159")
conn.commit() conn.commit()
self.pool.put(conn)
@skipped @skipped
def test_two_simultaneous_connections(self): def test_two_simultaneous_connections(self):
@@ -213,11 +205,11 @@ class TestDBConnectionPool(DBTester):
# way to do this # way to do this
self.pool = self.create_pool(2) self.pool = self.create_pool(2)
conn = self.pool.get() conn = self.pool.get()
self.set_up_test_table(conn) self.set_up_dummy_table(conn)
self.fill_up_table(conn) self.fill_up_table(conn)
curs = conn.cursor() curs = conn.cursor()
conn2 = self.pool.get() conn2 = self.pool.get()
self.set_up_test_table(conn2) self.set_up_dummy_table(conn2)
self.fill_up_table(conn2) self.fill_up_table(conn2)
curs2 = conn2.cursor() curs2 = conn2.cursor()
results = [] results = []
@@ -359,12 +351,6 @@ class TestDBConnectionPool(DBTester):
conn2.close() # should not be added to the free items conn2.close() # should not be added to the free items
self.assertEquals(len(self.pool.free_items), 0) self.assertEquals(len(self.pool.free_items), 0)
def test_connection_timeout(self):
# use a nonexistent ip address -- this one is reserved by IANA
self._auth['host'] = '192.0.2.1'
pool = self.create_pool()
self.assertRaises(db_pool.ConnectTimeout, pool.get)
def test_waiters_get_woken(self): def test_waiters_get_woken(self):
# verify that when there's someone waiting on an empty pool # verify that when there's someone waiting on an empty pool
# and someone puts an immediately-closed connection back in # and someone puts an immediately-closed connection back in
@@ -425,7 +411,7 @@ class RaisingDBModule(object):
raise RuntimeError() raise RuntimeError()
class TestTpoolConnectionPool(TestDBConnectionPool): class TpoolConnectionPool(DBConnectionPool):
__test__ = False # so that nose doesn't try to execute this directly __test__ = False # so that nose doesn't try to execute this directly
def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout=0.5, module=None): def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout=0.5, module=None):
if module is None: if module is None:
@@ -440,16 +426,16 @@ class TestTpoolConnectionPool(TestDBConnectionPool):
def setUp(self): def setUp(self):
from eventlet import tpool from eventlet import tpool
tpool.QUIET = True tpool.QUIET = True
super(TestTpoolConnectionPool, self).setUp() super(TpoolConnectionPool, self).setUp()
def tearDown(self): def tearDown(self):
from eventlet import tpool from eventlet import tpool
tpool.QUIET = False tpool.QUIET = False
tpool.killall() tpool.killall()
super(TestTpoolConnectionPool, self).tearDown() super(TpoolConnectionPool, self).tearDown()
class TestRawConnectionPool(TestDBConnectionPool): class RawConnectionPool(DBConnectionPool):
__test__ = False # so that nose doesn't try to execute this directly __test__ = False # so that nose doesn't try to execute this directly
def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout= 0.5, module=None): def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout= 0.5, module=None):
if module is None: if module is None:
@@ -460,45 +446,69 @@ class TestRawConnectionPool(TestDBConnectionPool):
connect_timeout=connect_timeout, connect_timeout=connect_timeout,
**self._auth) **self._auth)
def test_connection_timeout(self):
pass # not gonna work for raw connections because they're blocking
def get_auth(): def get_auth():
try: """Looks in the local directory and in the user's home directory for a file named ".test_dbauth",
import simplejson which contains a json map of parameters to the connect function.
import os.path """
auth_utf8 = simplejson.load(open(os.path.join(os.path.dirname(__file__), 'auth.json'))) files = [os.path.join(os.path.dirname(__file__), '.test_dbauth'),
# have to convert unicode objects to str objects because mysqldb is dum os.path.join(os.path.expanduser('~'), '.test_dbauth')]
return dict([(str(k), str(v)) for f in files:
for k, v in auth_utf8.items()]) try:
except (IOError, ImportError), e: import simplejson
return {'host': 'localhost','user': 'root','passwd': '','db': 'persist0'} auth_utf8 = simplejson.load(open(f))
# have to convert unicode objects to str objects because mysqldb is dum
# using a doubly-nested list comprehension because we know that the structure
# of the structure is a two-level dict
return dict([(str(modname), dict([(str(k), str(v))
for k, v in connectargs.items()]))
for modname, connectargs in auth_utf8.items()])
except (IOError, ImportError), e:
pass
return {'MySQLdb':{'host': 'localhost','user': 'root','passwd': '','db': 'persist0'},
'psycopg2':{'database':'test', 'user':'test'}}
def mysql_requirement(_f): def mysql_requirement(_f):
try: try:
import MySQLdb import MySQLdb
try: try:
MySQLdb.connect(**get_auth()) auth = get_auth()['MySQLdb'].copy()
auth.pop('db')
MySQLdb.connect(**auth)
return True return True
except MySQLdb.OperationalError: except MySQLdb.OperationalError:
print "Skipping mysql tests, error when connecting"
import traceback
traceback.print_exc()
return False return False
except ImportError: except ImportError:
print "Skipping mysql tests, MySQLdb not importable"
return False return False
class TestMysqlConnectionPool(object): class MysqlConnectionPool(object):
__test__ = True dummy_table_sql = """CREATE TEMPORARY TABLE test_table
(
row_id INTEGER PRIMARY KEY AUTO_INCREMENT,
value_int INTEGER,
value_float FLOAT,
value_string VARCHAR(200),
value_uuid CHAR(36),
value_binary BLOB,
value_binary_string VARCHAR(200) BINARY,
value_enum ENUM('Y','N'),
created TIMESTAMP
) ENGINE=InnoDB;"""
@skip_unless(mysql_requirement) @skip_unless(mysql_requirement)
def setUp(self): def setUp(self):
import MySQLdb import MySQLdb
self._dbmodule = MySQLdb self._dbmodule = MySQLdb
self._auth = get_auth() self._auth = get_auth()['MySQLdb']
super(TestMysqlConnectionPool, self).setUp() super(MysqlConnectionPool, self).setUp()
def tearDown(self): def tearDown(self):
pass super(MysqlConnectionPool, self).tearDown()
def create_db(self): def create_db(self):
auth = self._auth.copy() auth = self._auth.copy()
@@ -519,13 +529,80 @@ class TestMysqlConnectionPool(object):
del db del db
class Test01MysqlTpool(TestMysqlConnectionPool, TestTpoolConnectionPool, TestCase): class Test01MysqlTpool(MysqlConnectionPool, TpoolConnectionPool, TestCase):
pass __test__ = True
class Test02MysqlRaw(TestMysqlConnectionPool, TestRawConnectionPool, TestCase): class Test02MysqlRaw(MysqlConnectionPool, RawConnectionPool, TestCase):
pass __test__ = True
def postgres_requirement(_f):
try:
import psycopg2
try:
auth = get_auth()['psycopg2'].copy()
auth.pop('database')
psycopg2.connect(**auth)
return True
except psycopg2.OperationalError:
print "Skipping postgres tests, error when connecting"
return False
except ImportError:
print "Skipping postgres tests, psycopg2 not importable"
return False
class Psycopg2ConnectionPool(object):
dummy_table_sql = """CREATE TEMPORARY TABLE test_table
(
row_id SERIAL PRIMARY KEY,
value_int INTEGER,
value_float FLOAT,
value_string VARCHAR(200),
value_uuid CHAR(36),
value_binary BYTEA,
value_binary_string BYTEA,
created TIMESTAMP
);"""
@skip_unless(postgres_requirement)
def setUp(self):
import psycopg2
self._dbmodule = psycopg2
self._auth = get_auth()['psycopg2']
super(Psycopg2ConnectionPool, self).setUp()
def tearDown(self):
super(Psycopg2ConnectionPool, self).tearDown()
def create_db(self):
try:
self.drop_db()
except Exception:
pass
auth = self._auth.copy()
dbname = auth.pop('database')
conn = self._dbmodule.connect(**auth)
conn.set_isolation_level(0)
db = conn.cursor()
db.execute("create database "+dbname)
db.close()
del db
def drop_db(self):
auth = self._auth.copy()
dbname = auth.pop('database')
conn = self._dbmodule.connect(**auth)
conn.set_isolation_level(0)
db = conn.cursor()
db.execute("drop database "+self._auth['database'])
db.close()
del db
class Test01Psycopg2Tpool(Psycopg2ConnectionPool, TpoolConnectionPool, TestCase):
__test__ = True
class Test02Psycopg2Raw(Psycopg2ConnectionPool, RawConnectionPool, TestCase):
__test__ = True
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@@ -25,6 +25,7 @@ import eventlet
one = 1 one = 1
two = 2 two = 2
three = 3 three = 3
none = None
def noop(): def noop():
pass pass
@@ -164,6 +165,34 @@ class TestTpool(LimitedTestCase):
tpool.killall() tpool.killall()
tpool.setup() tpool.setup()
@skip_with_pyevent
def test_autowrap(self):
x = tpool.Proxy({'a':1, 'b':2}, autowrap=(int,))
self.assert_(isinstance(x.get('a'), tpool.Proxy))
self.assert_(not isinstance(x.items(), tpool.Proxy))
# attributes as well as callables
from tests import tpool_test
x = tpool.Proxy(tpool_test, autowrap=(int,))
self.assert_(isinstance(x.one, tpool.Proxy))
self.assert_(not isinstance(x.none, tpool.Proxy))
@skip_with_pyevent
def test_autowrap_names(self):
x = tpool.Proxy({'a':1, 'b':2}, autowrap_names=('get',))
self.assert_(isinstance(x.get('a'), tpool.Proxy))
self.assert_(not isinstance(x.items(), tpool.Proxy))
from tests import tpool_test
x = tpool.Proxy(tpool_test, autowrap_names=('one',))
self.assert_(isinstance(x.one, tpool.Proxy))
self.assert_(not isinstance(x.two, tpool.Proxy))
@skip_with_pyevent
def test_autowrap_both(self):
from tests import tpool_test
x = tpool.Proxy(tpool_test, autowrap=(int,), autowrap_names=('one',))
self.assert_(isinstance(x.one, tpool.Proxy))
# violating the abstraction to check that we didn't double-wrap
self.assert_(not isinstance(x._obj, tpool.Proxy))
class TpoolLongTests(LimitedTestCase): class TpoolLongTests(LimitedTestCase):
TEST_TIMEOUT=60 TEST_TIMEOUT=60