Extended tpool's autowrap functionality to use the name of the attribute, which is a lot cleaner than type-checking (which is still supported). Added psycopg tests to db_pool tests.

This commit is contained in:
Ryan Williams
2010-02-06 20:27:49 -08:00
parent 363653c9c5
commit ebff4177d2
4 changed files with 201 additions and 86 deletions

View File

@@ -57,7 +57,7 @@ class BaseConnectionPool(Pool):
possible expiration.
If this function is called when a timer is already scheduled, it does
nothing.
nothing.
If max_age or max_idle is 0, _schedule_expiration likewise does nothing.
"""
@@ -251,13 +251,8 @@ class TpooledConnectionPool(BaseConnectionPool):
timeout = api.exc_after(connect_timeout, ConnectTimeout())
try:
from eventlet import tpool
try:
# *FIX: this is a huge hack that will probably only work for MySQLdb
autowrap = (db_module.cursors.DictCursor,)
except:
autowrap = ()
conn = tpool.execute(db_module.connect, *args, **kw)
return tpool.Proxy(conn, autowrap=autowrap)
return tpool.Proxy(conn, autowrap_names=('cursor',))
finally:
timeout.cancel()

View File

@@ -132,28 +132,42 @@ def proxy_call(autowrap, f, *args, **kwargs):
class Proxy(object):
"""
A simple proxy-wrapper of any object, in order to forward every method
invocation onto a thread in the native-thread pool. A key restriction is
that the object's methods cannot use Eventlet primitives without great care,
since the Eventlet dispatcher runs on a different native thread.
Construct the Proxy with the instance that you want proxied. The optional
parameter *autowrap* is used when methods are called on the proxied object.
If a method on the proxied object returns something whose type is in
*autowrap*, then that object gets a Proxy wrapped around it, too. An
example use case for this is ensuring that DB-API connection objects
return cursor objects that are also Proxy-wrapped.
a simple proxy-wrapper of any object that comes with a
methods-only interface, in order to forward every method
invocation onto a thread in the native-thread pool. A key
restriction is that the object's methods should not switch
greenlets or use Eventlet primitives, since they are in a
different thread from the main hub, and therefore might behave
unexpectedly. This is for running native-threaded code
only.
It's common to want to have some of the attributes or return
values also wrapped in Proxy objects (for example, database
connection objects produce cursor objects which also should be
wrapped in Proxy objects to remain nonblocking). *autowrap*, if
supplied, is a collection of types; if an attribute or return
value matches one of those types (via isinstance), it will be
wrapped in a Proxy. *autowrap_names* is a collection
of strings, which represent the names of attributes that should be
wrapped in Proxy objects when accessed.
"""
def __init__(self, obj,autowrap=()):
def __init__(self, obj,autowrap=(), autowrap_names=()):
self._obj = obj
self._autowrap = autowrap
self._autowrap_names = autowrap_names
def __getattr__(self,attr_name):
f = getattr(self._obj,attr_name)
if not callable(f):
if (isinstance(f, self._autowrap) or
attr_name in self._autowrap_names):
return Proxy(f, self._autowrap)
return f
def doit(*args, **kwargs):
return proxy_call(self._autowrap, f, *args, **kwargs)
result = proxy_call(self._autowrap, f, *args, **kwargs)
if attr_name in self._autowrap_names and not isinstance(result, Proxy):
return Proxy(result)
return result
return doit
# the following are a buncha methods that the python interpeter

View File

@@ -5,6 +5,7 @@ from unittest import TestCase, main
from eventlet import api
from eventlet import event
from eventlet import db_pool
import os
class DBTester(object):
__test__ = False # so that nose doesn't try to execute this directly
@@ -16,7 +17,7 @@ class DBTester(object):
cursor.execute("""CREATE TABLE gargleblatz
(
a INTEGER
) ENGINE = InnoDB;""")
);""")
connection.commit()
cursor.close()
@@ -25,7 +26,7 @@ class DBTester(object):
self.connection.close()
self.drop_db()
def set_up_test_table(self, connection = None):
def set_up_dummy_table(self, connection = None):
close_connection = False
if connection is None:
close_connection = True
@@ -35,18 +36,7 @@ class DBTester(object):
connection = self.connection
cursor = connection.cursor()
cursor.execute("""CREATE TEMPORARY TABLE test_table
(
row_id INTEGER PRIMARY KEY AUTO_INCREMENT,
value_int INTEGER,
value_float FLOAT,
value_string VARCHAR(200),
value_uuid CHAR(36),
value_binary BLOB,
value_binary_string VARCHAR(200) BINARY,
value_enum ENUM('Y','N'),
created TIMESTAMP
) ENGINE = InnoDB;""")
cursor.execute(self.dummy_table_sql)
connection.commit()
cursor.close()
if close_connection:
@@ -57,21 +47,21 @@ class Mock(object):
pass
class TestDBConnectionPool(DBTester):
class DBConnectionPool(DBTester):
__test__ = False # so that nose doesn't try to execute this directly
def setUp(self):
super(TestDBConnectionPool, self).setUp()
super(DBConnectionPool, self).setUp()
self.pool = self.create_pool()
self.connection = self.pool.get()
def tearDown(self):
if self.connection:
self.pool.put(self.connection)
super(TestDBConnectionPool, self).tearDown()
self.pool.clear()
super(DBConnectionPool, self).tearDown()
def assert_cursor_works(self, cursor):
# TODO: this is pretty mysql-specific
cursor.execute("show full processlist")
cursor.execute("select 1")
rows = cursor.fetchall()
self.assert_(rows)
@@ -142,7 +132,7 @@ class TestDBConnectionPool(DBTester):
def test_returns_immediately(self):
self.pool = self.create_pool()
conn = self.pool.get()
self.set_up_test_table(conn)
self.set_up_dummy_table(conn)
self.fill_up_table(conn)
curs = conn.cursor()
results = []
@@ -163,7 +153,7 @@ class TestDBConnectionPool(DBTester):
def test_connection_is_clean_after_put(self):
self.pool = self.create_pool()
conn = self.pool.get()
self.set_up_test_table(conn)
self.set_up_dummy_table(conn)
curs = conn.cursor()
for i in range(10):
curs.execute('insert into test_table (value_int) values (%s)' % i)
@@ -175,9 +165,9 @@ class TestDBConnectionPool(DBTester):
for i in range(10):
curs2.execute('insert into test_table (value_int) values (%s)' % i)
conn2.commit()
rows = curs2.execute("select * from test_table")
curs2.execute("select * from test_table")
# we should have only inserted them once
self.assertEqual(10, rows)
self.assertEqual(10, curs2.rowcount)
def test_visibility_from_other_connections(self):
self.pool = self.create_pool(3)
@@ -186,26 +176,28 @@ class TestDBConnectionPool(DBTester):
curs = conn.cursor()
try:
curs2 = conn2.cursor()
rows2 = curs2.execute("insert into gargleblatz (a) values (%s)" % (314159))
self.assertEqual(rows2, 1)
curs2.execute("insert into gargleblatz (a) values (%s)" % (314159))
self.assertEqual(curs2.rowcount, 1)
conn2.commit()
selection_query = "select * from gargleblatz"
rows2 = curs2.execute(selection_query)
self.assertEqual(rows2, 1)
curs2.execute(selection_query)
self.assertEqual(curs2.rowcount, 1)
del curs2
del conn2
self.pool.put(conn2)
# create a new connection, it should see the addition
conn3 = self.pool.get()
curs3 = conn3.cursor()
rows3 = curs3.execute(selection_query)
self.assertEqual(rows3, 1)
curs3.execute(selection_query)
self.assertEqual(curs3.rowcount, 1)
# now, does the already-open connection see it?
rows = curs.execute(selection_query)
self.assertEqual(rows, 1)
curs.execute(selection_query)
self.assertEqual(curs.rowcount, 1)
self.pool.put(conn3)
finally:
# clean up my litter
curs.execute("delete from gargleblatz where a=314159")
conn.commit()
self.pool.put(conn)
@skipped
def test_two_simultaneous_connections(self):
@@ -213,11 +205,11 @@ class TestDBConnectionPool(DBTester):
# way to do this
self.pool = self.create_pool(2)
conn = self.pool.get()
self.set_up_test_table(conn)
self.set_up_dummy_table(conn)
self.fill_up_table(conn)
curs = conn.cursor()
conn2 = self.pool.get()
self.set_up_test_table(conn2)
self.set_up_dummy_table(conn2)
self.fill_up_table(conn2)
curs2 = conn2.cursor()
results = []
@@ -359,12 +351,6 @@ class TestDBConnectionPool(DBTester):
conn2.close() # should not be added to the free items
self.assertEquals(len(self.pool.free_items), 0)
def test_connection_timeout(self):
# use a nonexistent ip address -- this one is reserved by IANA
self._auth['host'] = '192.0.2.1'
pool = self.create_pool()
self.assertRaises(db_pool.ConnectTimeout, pool.get)
def test_waiters_get_woken(self):
# verify that when there's someone waiting on an empty pool
# and someone puts an immediately-closed connection back in
@@ -425,7 +411,7 @@ class RaisingDBModule(object):
raise RuntimeError()
class TestTpoolConnectionPool(TestDBConnectionPool):
class TpoolConnectionPool(DBConnectionPool):
__test__ = False # so that nose doesn't try to execute this directly
def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout=0.5, module=None):
if module is None:
@@ -440,16 +426,16 @@ class TestTpoolConnectionPool(TestDBConnectionPool):
def setUp(self):
from eventlet import tpool
tpool.QUIET = True
super(TestTpoolConnectionPool, self).setUp()
super(TpoolConnectionPool, self).setUp()
def tearDown(self):
from eventlet import tpool
tpool.QUIET = False
tpool.killall()
super(TestTpoolConnectionPool, self).tearDown()
super(TpoolConnectionPool, self).tearDown()
class TestRawConnectionPool(TestDBConnectionPool):
class RawConnectionPool(DBConnectionPool):
__test__ = False # so that nose doesn't try to execute this directly
def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout= 0.5, module=None):
if module is None:
@@ -460,45 +446,69 @@ class TestRawConnectionPool(TestDBConnectionPool):
connect_timeout=connect_timeout,
**self._auth)
def test_connection_timeout(self):
pass # not gonna work for raw connections because they're blocking
def get_auth():
try:
import simplejson
import os.path
auth_utf8 = simplejson.load(open(os.path.join(os.path.dirname(__file__), 'auth.json')))
# have to convert unicode objects to str objects because mysqldb is dum
return dict([(str(k), str(v))
for k, v in auth_utf8.items()])
except (IOError, ImportError), e:
return {'host': 'localhost','user': 'root','passwd': '','db': 'persist0'}
"""Looks in the local directory and in the user's home directory for a file named ".test_dbauth",
which contains a json map of parameters to the connect function.
"""
files = [os.path.join(os.path.dirname(__file__), '.test_dbauth'),
os.path.join(os.path.expanduser('~'), '.test_dbauth')]
for f in files:
try:
import simplejson
auth_utf8 = simplejson.load(open(f))
# have to convert unicode objects to str objects because mysqldb is dum
# using a doubly-nested list comprehension because we know that the structure
# of the structure is a two-level dict
return dict([(str(modname), dict([(str(k), str(v))
for k, v in connectargs.items()]))
for modname, connectargs in auth_utf8.items()])
except (IOError, ImportError), e:
pass
return {'MySQLdb':{'host': 'localhost','user': 'root','passwd': '','db': 'persist0'},
'psycopg2':{'database':'test', 'user':'test'}}
def mysql_requirement(_f):
try:
import MySQLdb
try:
MySQLdb.connect(**get_auth())
auth = get_auth()['MySQLdb'].copy()
auth.pop('db')
MySQLdb.connect(**auth)
return True
except MySQLdb.OperationalError:
print "Skipping mysql tests, error when connecting"
import traceback
traceback.print_exc()
return False
except ImportError:
print "Skipping mysql tests, MySQLdb not importable"
return False
class TestMysqlConnectionPool(object):
__test__ = True
class MysqlConnectionPool(object):
dummy_table_sql = """CREATE TEMPORARY TABLE test_table
(
row_id INTEGER PRIMARY KEY AUTO_INCREMENT,
value_int INTEGER,
value_float FLOAT,
value_string VARCHAR(200),
value_uuid CHAR(36),
value_binary BLOB,
value_binary_string VARCHAR(200) BINARY,
value_enum ENUM('Y','N'),
created TIMESTAMP
) ENGINE=InnoDB;"""
@skip_unless(mysql_requirement)
def setUp(self):
import MySQLdb
self._dbmodule = MySQLdb
self._auth = get_auth()
super(TestMysqlConnectionPool, self).setUp()
self._auth = get_auth()['MySQLdb']
super(MysqlConnectionPool, self).setUp()
def tearDown(self):
pass
super(MysqlConnectionPool, self).tearDown()
def create_db(self):
auth = self._auth.copy()
@@ -519,13 +529,80 @@ class TestMysqlConnectionPool(object):
del db
class Test01MysqlTpool(TestMysqlConnectionPool, TestTpoolConnectionPool, TestCase):
pass
class Test01MysqlTpool(MysqlConnectionPool, TpoolConnectionPool, TestCase):
__test__ = True
class Test02MysqlRaw(TestMysqlConnectionPool, TestRawConnectionPool, TestCase):
pass
class Test02MysqlRaw(MysqlConnectionPool, RawConnectionPool, TestCase):
__test__ = True
def postgres_requirement(_f):
try:
import psycopg2
try:
auth = get_auth()['psycopg2'].copy()
auth.pop('database')
psycopg2.connect(**auth)
return True
except psycopg2.OperationalError:
print "Skipping postgres tests, error when connecting"
return False
except ImportError:
print "Skipping postgres tests, psycopg2 not importable"
return False
class Psycopg2ConnectionPool(object):
dummy_table_sql = """CREATE TEMPORARY TABLE test_table
(
row_id SERIAL PRIMARY KEY,
value_int INTEGER,
value_float FLOAT,
value_string VARCHAR(200),
value_uuid CHAR(36),
value_binary BYTEA,
value_binary_string BYTEA,
created TIMESTAMP
);"""
@skip_unless(postgres_requirement)
def setUp(self):
import psycopg2
self._dbmodule = psycopg2
self._auth = get_auth()['psycopg2']
super(Psycopg2ConnectionPool, self).setUp()
def tearDown(self):
super(Psycopg2ConnectionPool, self).tearDown()
def create_db(self):
try:
self.drop_db()
except Exception:
pass
auth = self._auth.copy()
dbname = auth.pop('database')
conn = self._dbmodule.connect(**auth)
conn.set_isolation_level(0)
db = conn.cursor()
db.execute("create database "+dbname)
db.close()
del db
def drop_db(self):
auth = self._auth.copy()
dbname = auth.pop('database')
conn = self._dbmodule.connect(**auth)
conn.set_isolation_level(0)
db = conn.cursor()
db.execute("drop database "+self._auth['database'])
db.close()
del db
class Test01Psycopg2Tpool(Psycopg2ConnectionPool, TpoolConnectionPool, TestCase):
__test__ = True
class Test02Psycopg2Raw(Psycopg2ConnectionPool, RawConnectionPool, TestCase):
__test__ = True
if __name__ == '__main__':
main()

View File

@@ -25,6 +25,7 @@ import eventlet
one = 1
two = 2
three = 3
none = None
def noop():
pass
@@ -164,6 +165,34 @@ class TestTpool(LimitedTestCase):
tpool.killall()
tpool.setup()
@skip_with_pyevent
def test_autowrap(self):
x = tpool.Proxy({'a':1, 'b':2}, autowrap=(int,))
self.assert_(isinstance(x.get('a'), tpool.Proxy))
self.assert_(not isinstance(x.items(), tpool.Proxy))
# attributes as well as callables
from tests import tpool_test
x = tpool.Proxy(tpool_test, autowrap=(int,))
self.assert_(isinstance(x.one, tpool.Proxy))
self.assert_(not isinstance(x.none, tpool.Proxy))
@skip_with_pyevent
def test_autowrap_names(self):
x = tpool.Proxy({'a':1, 'b':2}, autowrap_names=('get',))
self.assert_(isinstance(x.get('a'), tpool.Proxy))
self.assert_(not isinstance(x.items(), tpool.Proxy))
from tests import tpool_test
x = tpool.Proxy(tpool_test, autowrap_names=('one',))
self.assert_(isinstance(x.one, tpool.Proxy))
self.assert_(not isinstance(x.two, tpool.Proxy))
@skip_with_pyevent
def test_autowrap_both(self):
from tests import tpool_test
x = tpool.Proxy(tpool_test, autowrap=(int,), autowrap_names=('one',))
self.assert_(isinstance(x.one, tpool.Proxy))
# violating the abstraction to check that we didn't double-wrap
self.assert_(not isinstance(x._obj, tpool.Proxy))
class TpoolLongTests(LimitedTestCase):
TEST_TIMEOUT=60