From ebff4177d2f6ac8663f0c281f4ab18ce80b5e309 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sat, 6 Feb 2010 20:27:49 -0800 Subject: [PATCH] Extended tpool's autowrap functionality to use the name of the attribute, which is a lot cleaner than type-checking (which is still supported). Added psycopg tests to db_pool tests. --- eventlet/db_pool.py | 9 +- eventlet/tpool.py | 40 +++++--- tests/db_pool_test.py | 209 +++++++++++++++++++++++++++++------------- tests/tpool_test.py | 29 ++++++ 4 files changed, 201 insertions(+), 86 deletions(-) diff --git a/eventlet/db_pool.py b/eventlet/db_pool.py index 6be0a1c..27c2241 100644 --- a/eventlet/db_pool.py +++ b/eventlet/db_pool.py @@ -57,7 +57,7 @@ class BaseConnectionPool(Pool): possible expiration. If this function is called when a timer is already scheduled, it does - nothing. + nothing. If max_age or max_idle is 0, _schedule_expiration likewise does nothing. """ @@ -251,13 +251,8 @@ class TpooledConnectionPool(BaseConnectionPool): timeout = api.exc_after(connect_timeout, ConnectTimeout()) try: from eventlet import tpool - try: - # *FIX: this is a huge hack that will probably only work for MySQLdb - autowrap = (db_module.cursors.DictCursor,) - except: - autowrap = () conn = tpool.execute(db_module.connect, *args, **kw) - return tpool.Proxy(conn, autowrap=autowrap) + return tpool.Proxy(conn, autowrap_names=('cursor',)) finally: timeout.cancel() diff --git a/eventlet/tpool.py b/eventlet/tpool.py index 8b82b19..bb694db 100644 --- a/eventlet/tpool.py +++ b/eventlet/tpool.py @@ -132,28 +132,42 @@ def proxy_call(autowrap, f, *args, **kwargs): class Proxy(object): """ - A simple proxy-wrapper of any object, in order to forward every method - invocation onto a thread in the native-thread pool. A key restriction is - that the object's methods cannot use Eventlet primitives without great care, - since the Eventlet dispatcher runs on a different native thread. - - Construct the Proxy with the instance that you want proxied. The optional - parameter *autowrap* is used when methods are called on the proxied object. - If a method on the proxied object returns something whose type is in - *autowrap*, then that object gets a Proxy wrapped around it, too. An - example use case for this is ensuring that DB-API connection objects - return cursor objects that are also Proxy-wrapped. + a simple proxy-wrapper of any object that comes with a + methods-only interface, in order to forward every method + invocation onto a thread in the native-thread pool. A key + restriction is that the object's methods should not switch + greenlets or use Eventlet primitives, since they are in a + different thread from the main hub, and therefore might behave + unexpectedly. This is for running native-threaded code + only. + + It's common to want to have some of the attributes or return + values also wrapped in Proxy objects (for example, database + connection objects produce cursor objects which also should be + wrapped in Proxy objects to remain nonblocking). *autowrap*, if + supplied, is a collection of types; if an attribute or return + value matches one of those types (via isinstance), it will be + wrapped in a Proxy. *autowrap_names* is a collection + of strings, which represent the names of attributes that should be + wrapped in Proxy objects when accessed. """ - def __init__(self, obj,autowrap=()): + def __init__(self, obj,autowrap=(), autowrap_names=()): self._obj = obj self._autowrap = autowrap + self._autowrap_names = autowrap_names def __getattr__(self,attr_name): f = getattr(self._obj,attr_name) if not callable(f): + if (isinstance(f, self._autowrap) or + attr_name in self._autowrap_names): + return Proxy(f, self._autowrap) return f def doit(*args, **kwargs): - return proxy_call(self._autowrap, f, *args, **kwargs) + result = proxy_call(self._autowrap, f, *args, **kwargs) + if attr_name in self._autowrap_names and not isinstance(result, Proxy): + return Proxy(result) + return result return doit # the following are a buncha methods that the python interpeter diff --git a/tests/db_pool_test.py b/tests/db_pool_test.py index 15be9c0..c70ab8e 100644 --- a/tests/db_pool_test.py +++ b/tests/db_pool_test.py @@ -5,6 +5,7 @@ from unittest import TestCase, main from eventlet import api from eventlet import event from eventlet import db_pool +import os class DBTester(object): __test__ = False # so that nose doesn't try to execute this directly @@ -16,7 +17,7 @@ class DBTester(object): cursor.execute("""CREATE TABLE gargleblatz ( a INTEGER - ) ENGINE = InnoDB;""") + );""") connection.commit() cursor.close() @@ -25,7 +26,7 @@ class DBTester(object): self.connection.close() self.drop_db() - def set_up_test_table(self, connection = None): + def set_up_dummy_table(self, connection = None): close_connection = False if connection is None: close_connection = True @@ -35,18 +36,7 @@ class DBTester(object): connection = self.connection cursor = connection.cursor() - cursor.execute("""CREATE TEMPORARY TABLE test_table - ( - row_id INTEGER PRIMARY KEY AUTO_INCREMENT, - value_int INTEGER, - value_float FLOAT, - value_string VARCHAR(200), - value_uuid CHAR(36), - value_binary BLOB, - value_binary_string VARCHAR(200) BINARY, - value_enum ENUM('Y','N'), - created TIMESTAMP - ) ENGINE = InnoDB;""") + cursor.execute(self.dummy_table_sql) connection.commit() cursor.close() if close_connection: @@ -57,21 +47,21 @@ class Mock(object): pass -class TestDBConnectionPool(DBTester): +class DBConnectionPool(DBTester): __test__ = False # so that nose doesn't try to execute this directly def setUp(self): - super(TestDBConnectionPool, self).setUp() + super(DBConnectionPool, self).setUp() self.pool = self.create_pool() self.connection = self.pool.get() def tearDown(self): if self.connection: self.pool.put(self.connection) - super(TestDBConnectionPool, self).tearDown() + self.pool.clear() + super(DBConnectionPool, self).tearDown() def assert_cursor_works(self, cursor): - # TODO: this is pretty mysql-specific - cursor.execute("show full processlist") + cursor.execute("select 1") rows = cursor.fetchall() self.assert_(rows) @@ -142,7 +132,7 @@ class TestDBConnectionPool(DBTester): def test_returns_immediately(self): self.pool = self.create_pool() conn = self.pool.get() - self.set_up_test_table(conn) + self.set_up_dummy_table(conn) self.fill_up_table(conn) curs = conn.cursor() results = [] @@ -163,7 +153,7 @@ class TestDBConnectionPool(DBTester): def test_connection_is_clean_after_put(self): self.pool = self.create_pool() conn = self.pool.get() - self.set_up_test_table(conn) + self.set_up_dummy_table(conn) curs = conn.cursor() for i in range(10): curs.execute('insert into test_table (value_int) values (%s)' % i) @@ -175,9 +165,9 @@ class TestDBConnectionPool(DBTester): for i in range(10): curs2.execute('insert into test_table (value_int) values (%s)' % i) conn2.commit() - rows = curs2.execute("select * from test_table") + curs2.execute("select * from test_table") # we should have only inserted them once - self.assertEqual(10, rows) + self.assertEqual(10, curs2.rowcount) def test_visibility_from_other_connections(self): self.pool = self.create_pool(3) @@ -186,26 +176,28 @@ class TestDBConnectionPool(DBTester): curs = conn.cursor() try: curs2 = conn2.cursor() - rows2 = curs2.execute("insert into gargleblatz (a) values (%s)" % (314159)) - self.assertEqual(rows2, 1) + curs2.execute("insert into gargleblatz (a) values (%s)" % (314159)) + self.assertEqual(curs2.rowcount, 1) conn2.commit() selection_query = "select * from gargleblatz" - rows2 = curs2.execute(selection_query) - self.assertEqual(rows2, 1) + curs2.execute(selection_query) + self.assertEqual(curs2.rowcount, 1) del curs2 - del conn2 + self.pool.put(conn2) # create a new connection, it should see the addition conn3 = self.pool.get() curs3 = conn3.cursor() - rows3 = curs3.execute(selection_query) - self.assertEqual(rows3, 1) + curs3.execute(selection_query) + self.assertEqual(curs3.rowcount, 1) # now, does the already-open connection see it? - rows = curs.execute(selection_query) - self.assertEqual(rows, 1) + curs.execute(selection_query) + self.assertEqual(curs.rowcount, 1) + self.pool.put(conn3) finally: # clean up my litter curs.execute("delete from gargleblatz where a=314159") conn.commit() + self.pool.put(conn) @skipped def test_two_simultaneous_connections(self): @@ -213,11 +205,11 @@ class TestDBConnectionPool(DBTester): # way to do this self.pool = self.create_pool(2) conn = self.pool.get() - self.set_up_test_table(conn) + self.set_up_dummy_table(conn) self.fill_up_table(conn) curs = conn.cursor() conn2 = self.pool.get() - self.set_up_test_table(conn2) + self.set_up_dummy_table(conn2) self.fill_up_table(conn2) curs2 = conn2.cursor() results = [] @@ -359,12 +351,6 @@ class TestDBConnectionPool(DBTester): conn2.close() # should not be added to the free items self.assertEquals(len(self.pool.free_items), 0) - def test_connection_timeout(self): - # use a nonexistent ip address -- this one is reserved by IANA - self._auth['host'] = '192.0.2.1' - pool = self.create_pool() - self.assertRaises(db_pool.ConnectTimeout, pool.get) - def test_waiters_get_woken(self): # verify that when there's someone waiting on an empty pool # and someone puts an immediately-closed connection back in @@ -425,7 +411,7 @@ class RaisingDBModule(object): raise RuntimeError() -class TestTpoolConnectionPool(TestDBConnectionPool): +class TpoolConnectionPool(DBConnectionPool): __test__ = False # so that nose doesn't try to execute this directly def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout=0.5, module=None): if module is None: @@ -440,16 +426,16 @@ class TestTpoolConnectionPool(TestDBConnectionPool): def setUp(self): from eventlet import tpool tpool.QUIET = True - super(TestTpoolConnectionPool, self).setUp() + super(TpoolConnectionPool, self).setUp() def tearDown(self): from eventlet import tpool tpool.QUIET = False tpool.killall() - super(TestTpoolConnectionPool, self).tearDown() + super(TpoolConnectionPool, self).tearDown() -class TestRawConnectionPool(TestDBConnectionPool): +class RawConnectionPool(DBConnectionPool): __test__ = False # so that nose doesn't try to execute this directly def create_pool(self, max_size = 1, max_idle = 10, max_age = 10, connect_timeout= 0.5, module=None): if module is None: @@ -460,45 +446,69 @@ class TestRawConnectionPool(TestDBConnectionPool): connect_timeout=connect_timeout, **self._auth) - def test_connection_timeout(self): - pass # not gonna work for raw connections because they're blocking - def get_auth(): - try: - import simplejson - import os.path - auth_utf8 = simplejson.load(open(os.path.join(os.path.dirname(__file__), 'auth.json'))) - # have to convert unicode objects to str objects because mysqldb is dum - return dict([(str(k), str(v)) - for k, v in auth_utf8.items()]) - except (IOError, ImportError), e: - return {'host': 'localhost','user': 'root','passwd': '','db': 'persist0'} + """Looks in the local directory and in the user's home directory for a file named ".test_dbauth", + which contains a json map of parameters to the connect function. + """ + files = [os.path.join(os.path.dirname(__file__), '.test_dbauth'), + os.path.join(os.path.expanduser('~'), '.test_dbauth')] + for f in files: + try: + import simplejson + auth_utf8 = simplejson.load(open(f)) + # have to convert unicode objects to str objects because mysqldb is dum + # using a doubly-nested list comprehension because we know that the structure + # of the structure is a two-level dict + return dict([(str(modname), dict([(str(k), str(v)) + for k, v in connectargs.items()])) + for modname, connectargs in auth_utf8.items()]) + except (IOError, ImportError), e: + pass + return {'MySQLdb':{'host': 'localhost','user': 'root','passwd': '','db': 'persist0'}, + 'psycopg2':{'database':'test', 'user':'test'}} def mysql_requirement(_f): try: import MySQLdb try: - MySQLdb.connect(**get_auth()) + auth = get_auth()['MySQLdb'].copy() + auth.pop('db') + MySQLdb.connect(**auth) return True except MySQLdb.OperationalError: + print "Skipping mysql tests, error when connecting" + import traceback + traceback.print_exc() return False except ImportError: + print "Skipping mysql tests, MySQLdb not importable" return False -class TestMysqlConnectionPool(object): - __test__ = True +class MysqlConnectionPool(object): + dummy_table_sql = """CREATE TEMPORARY TABLE test_table + ( + row_id INTEGER PRIMARY KEY AUTO_INCREMENT, + value_int INTEGER, + value_float FLOAT, + value_string VARCHAR(200), + value_uuid CHAR(36), + value_binary BLOB, + value_binary_string VARCHAR(200) BINARY, + value_enum ENUM('Y','N'), + created TIMESTAMP + ) ENGINE=InnoDB;""" @skip_unless(mysql_requirement) def setUp(self): import MySQLdb self._dbmodule = MySQLdb - self._auth = get_auth() - super(TestMysqlConnectionPool, self).setUp() + self._auth = get_auth()['MySQLdb'] + super(MysqlConnectionPool, self).setUp() def tearDown(self): - pass + super(MysqlConnectionPool, self).tearDown() def create_db(self): auth = self._auth.copy() @@ -519,13 +529,80 @@ class TestMysqlConnectionPool(object): del db -class Test01MysqlTpool(TestMysqlConnectionPool, TestTpoolConnectionPool, TestCase): - pass +class Test01MysqlTpool(MysqlConnectionPool, TpoolConnectionPool, TestCase): + __test__ = True -class Test02MysqlRaw(TestMysqlConnectionPool, TestRawConnectionPool, TestCase): - pass +class Test02MysqlRaw(MysqlConnectionPool, RawConnectionPool, TestCase): + __test__ = True + +def postgres_requirement(_f): + try: + import psycopg2 + try: + auth = get_auth()['psycopg2'].copy() + auth.pop('database') + psycopg2.connect(**auth) + return True + except psycopg2.OperationalError: + print "Skipping postgres tests, error when connecting" + return False + except ImportError: + print "Skipping postgres tests, psycopg2 not importable" + return False +class Psycopg2ConnectionPool(object): + dummy_table_sql = """CREATE TEMPORARY TABLE test_table + ( + row_id SERIAL PRIMARY KEY, + value_int INTEGER, + value_float FLOAT, + value_string VARCHAR(200), + value_uuid CHAR(36), + value_binary BYTEA, + value_binary_string BYTEA, + created TIMESTAMP + );""" + + @skip_unless(postgres_requirement) + def setUp(self): + import psycopg2 + self._dbmodule = psycopg2 + self._auth = get_auth()['psycopg2'] + super(Psycopg2ConnectionPool, self).setUp() + + def tearDown(self): + super(Psycopg2ConnectionPool, self).tearDown() + + def create_db(self): + try: + self.drop_db() + except Exception: + pass + auth = self._auth.copy() + dbname = auth.pop('database') + conn = self._dbmodule.connect(**auth) + conn.set_isolation_level(0) + db = conn.cursor() + db.execute("create database "+dbname) + db.close() + del db + + def drop_db(self): + auth = self._auth.copy() + dbname = auth.pop('database') + conn = self._dbmodule.connect(**auth) + conn.set_isolation_level(0) + db = conn.cursor() + db.execute("drop database "+self._auth['database']) + db.close() + del db + +class Test01Psycopg2Tpool(Psycopg2ConnectionPool, TpoolConnectionPool, TestCase): + __test__ = True + +class Test02Psycopg2Raw(Psycopg2ConnectionPool, RawConnectionPool, TestCase): + __test__ = True if __name__ == '__main__': main() diff --git a/tests/tpool_test.py b/tests/tpool_test.py index d0f1a01..4696bed 100644 --- a/tests/tpool_test.py +++ b/tests/tpool_test.py @@ -25,6 +25,7 @@ import eventlet one = 1 two = 2 three = 3 +none = None def noop(): pass @@ -164,6 +165,34 @@ class TestTpool(LimitedTestCase): tpool.killall() tpool.setup() + @skip_with_pyevent + def test_autowrap(self): + x = tpool.Proxy({'a':1, 'b':2}, autowrap=(int,)) + self.assert_(isinstance(x.get('a'), tpool.Proxy)) + self.assert_(not isinstance(x.items(), tpool.Proxy)) + # attributes as well as callables + from tests import tpool_test + x = tpool.Proxy(tpool_test, autowrap=(int,)) + self.assert_(isinstance(x.one, tpool.Proxy)) + self.assert_(not isinstance(x.none, tpool.Proxy)) + + @skip_with_pyevent + def test_autowrap_names(self): + x = tpool.Proxy({'a':1, 'b':2}, autowrap_names=('get',)) + self.assert_(isinstance(x.get('a'), tpool.Proxy)) + self.assert_(not isinstance(x.items(), tpool.Proxy)) + from tests import tpool_test + x = tpool.Proxy(tpool_test, autowrap_names=('one',)) + self.assert_(isinstance(x.one, tpool.Proxy)) + self.assert_(not isinstance(x.two, tpool.Proxy)) + + @skip_with_pyevent + def test_autowrap_both(self): + from tests import tpool_test + x = tpool.Proxy(tpool_test, autowrap=(int,), autowrap_names=('one',)) + self.assert_(isinstance(x.one, tpool.Proxy)) + # violating the abstraction to check that we didn't double-wrap + self.assert_(not isinstance(x._obj, tpool.Proxy)) class TpoolLongTests(LimitedTestCase): TEST_TIMEOUT=60