diff --git a/eventlet/green/MySQLdb.py b/eventlet/green/MySQLdb.py new file mode 100644 index 0000000..9bcd038 --- /dev/null +++ b/eventlet/green/MySQLdb.py @@ -0,0 +1,32 @@ +__MySQLdb = __import__('MySQLdb') +globals().update(dict([(var, getattr(__MySQLdb, var)) + for var in dir(__MySQLdb) + if not var.startswith('__')])) + +__all__ = __MySQLdb.__all__ +__patched__ = ["connect", "Connect", 'Connection', 'connections'] + +from eventlet import tpool + +__orig_connections = __import__('MySQLdb.connections').connections + +def Connection(*args, **kw): + conn = tpool.execute(__orig_connections.Connection, *args, **kw) + return tpool.Proxy(conn, autowrap_names=('cursor',)) +connect = Connect = Connection + +# replicate the MySQLdb.connections module but with a tpooled Connection factory +class MySQLdbConnectionsModule(object): + pass +connections = MySQLdbConnectionsModule() +for var in dir(__orig_connections): + if not var.startswith('__'): + setattr(connections, var, getattr(__orig_connections, var)) +connections.Connection = Connection + +cursors = __import__('MySQLdb.cursors').cursors +converters = __import__('MySQLdb.converters').converters + +# TODO support instantiating cursors.FooCursor objects directly +# TODO though this is a low priority, it would be nice if we supported +# subclassing eventlet.green.MySQLdb.connections.Connection diff --git a/eventlet/patcher.py b/eventlet/patcher.py index 306ae41..8fe729f 100644 --- a/eventlet/patcher.py +++ b/eventlet/patcher.py @@ -64,7 +64,8 @@ def inject(module_name, new_globals, *additional_modules): _green_select_modules() + _green_socket_modules() + _green_thread_modules() + - _green_time_modules()) + _green_time_modules() + + _green_MySQLdb()) # after this we are gonna screw with sys.modules, so capture the # state of all the modules we're going to mess with, and lock @@ -205,8 +206,9 @@ def monkey_patch(**on): module if present; and thread, which patches thread, threading, and Queue. It's safe to call monkey_patch multiple times. - """ - accepted_args = set(('os', 'select', 'socket', 'thread', 'time', 'psycopg')) + """ + accepted_args = set(('os', 'select', 'socket', + 'thread', 'time', 'psycopg', 'MySQLdb')) default_on = on.pop("all",None) for k in on.iterkeys(): if k not in accepted_args: @@ -215,6 +217,9 @@ def monkey_patch(**on): if default_on is None: default_on = not (True in on.values()) for modname in accepted_args: + if modname == 'MySQLdb': + # MySQLdb is only on when explicitly patched for the moment + on.setdefault(modname, False) on.setdefault(modname, default_on) modules_to_patch = [] @@ -235,6 +240,9 @@ def monkey_patch(**on): if on['time'] and not already_patched.get('time'): modules_to_patch += _green_time_modules() already_patched['time'] = True + if on.get('MySQLdb') and not already_patched.get('MySQLdb'): + modules_to_patch += _green_MySQLdb() + already_patched['MySQLdb'] = True if on['psycopg'] and not already_patched.get('psycopg'): try: from eventlet.support import psycopg2_patcher @@ -316,6 +324,13 @@ def _green_time_modules(): from eventlet.green import time return [('time', time)] +def _green_MySQLdb(): + try: + from eventlet.green import MySQLdb + return [('MySQLdb', MySQLdb)] + except ImportError: + return [] + if __name__ == "__main__": import sys diff --git a/tests/__init__.py b/tests/__init__.py index f5225d3..f7c1d59 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -38,14 +38,17 @@ def skip_if(condition): should return True to skip the test. """ def skipped_wrapper(func): - if isinstance(condition, bool): - result = condition - else: - result = condition(func) - if result: - return skipped(func) - else: - return func + def wrapped(*a, **kw): + if isinstance(condition, bool): + result = condition + else: + result = condition(func) + if result: + return skipped(func)(*a, **kw) + else: + return func(*a, **kw) + wrapped.__name__ = func.__name__ + return wrapped return skipped_wrapper @@ -56,14 +59,17 @@ def skip_unless(condition): should return True if the condition is satisfied. """ def skipped_wrapper(func): - if isinstance(condition, bool): - result = condition - else: - result = condition(func) - if not result: - return skipped(func) - else: - return func + def wrapped(*a, **kw): + if isinstance(condition, bool): + result = condition + else: + result = condition(func) + if not result: + return skipped(func)(*a, **kw) + else: + return func(*a, **kw) + wrapped.__name__ = func.__name__ + return wrapped return skipped_wrapper @@ -81,6 +87,7 @@ def requires_twisted(func): def using_pyevent(_f): from eventlet.hubs import get_hub return 'pyevent' in type(get_hub()).__module__ + def skip_with_pyevent(func): """ Decorator that skips a test if we're using the pyevent hub.""" diff --git a/tests/mysqldb_test.py b/tests/mysqldb_test.py new file mode 100644 index 0000000..4709c4f --- /dev/null +++ b/tests/mysqldb_test.py @@ -0,0 +1,228 @@ +import os +import sys +import time +from tests import skipped, skip_unless, using_pyevent, get_database_auth, LimitedTestCase +import eventlet +from eventlet import event +try: + from eventlet.green import MySQLdb +except ImportError: + MySQLdb = False + +def mysql_requirement(_f): + """We want to skip tests if using pyevent, MySQLdb is not installed, or if + there is no database running on the localhost that the auth file grants + us access to. + + This errs on the side of skipping tests if everything is not right, but + it's better than a million tests failing when you don't care about mysql + support.""" + if using_pyevent(_f): + return False + if MySQLdb is False: + print "Skipping mysql tests, MySQLdb not importable" + return False + try: + auth = get_database_auth()['MySQLdb'].copy() + MySQLdb.connect(**auth) + return True + except MySQLdb.OperationalError: + print "Skipping mysql tests, error when connecting:" + traceback.print_exc() + return False + +class MySQLdbTester(LimitedTestCase): + def setUp(self): + self._auth = get_database_auth()['MySQLdb'] + self.create_db() + self.connection = None + self.connection = MySQLdb.connect(**self._auth) + cursor = self.connection.cursor() + cursor.execute("""CREATE TABLE gargleblatz + ( + a INTEGER + );""") + self.connection.commit() + cursor.close() + + def tearDown(self): + if self.connection: + self.connection.close() + self.drop_db() + + @skip_unless(mysql_requirement) + def create_db(self): + auth = self._auth.copy() + try: + self.drop_db() + except Exception: + pass + dbname = 'test_%d_%d' % (os.getpid(), time.time()*1000) + db = MySQLdb.connect(**auth).cursor() + db.execute("create database "+dbname) + db.close() + self._auth['db'] = dbname + del db + + def drop_db(self): + db = MySQLdb.connect(**self._auth).cursor() + db.execute("drop database "+self._auth['db']) + db.close() + del db + + def set_up_dummy_table(self, connection=None): + close_connection = False + if connection is None: + close_connection = True + if self.connection is None: + connection = MySQLdb.connect(**self._auth) + else: + connection = self.connection + + cursor = connection.cursor() + cursor.execute(self.dummy_table_sql) + connection.commit() + cursor.close() + if close_connection: + connection.close() + + dummy_table_sql = """CREATE TEMPORARY TABLE test_table + ( + row_id INTEGER PRIMARY KEY AUTO_INCREMENT, + value_int INTEGER, + value_float FLOAT, + value_string VARCHAR(200), + value_uuid CHAR(36), + value_binary BLOB, + value_binary_string VARCHAR(200) BINARY, + value_enum ENUM('Y','N'), + created TIMESTAMP + ) ENGINE=InnoDB;""" + + def assert_cursor_yields(self, curs): + counter = [0] + def tick(): + while True: + counter[0] += 1 + eventlet.sleep() + gt = eventlet.spawn(tick) + curs.execute("select 1") + rows = curs.fetchall() + self.assertEqual(rows, ((1L,),)) + self.assert_(counter[0] > 0, counter[0]) + gt.kill() + + def assert_cursor_works(self, cursor): + cursor.execute("select 1") + rows = cursor.fetchall() + self.assertEqual(rows, ((1L,),)) + self.assert_cursor_yields(cursor) + + def assert_connection_works(self, conn): + curs = conn.cursor() + self.assert_cursor_works(curs) + + def test_module_attributes(self): + import MySQLdb as orig + for key in dir(orig): + if key not in ('__author__', '__path__', '__revision__', + '__version__'): + self.assert_(hasattr(MySQLdb, key), "%s %s" % (key, getattr(orig, key))) + + def test_connecting(self): + self.assert_(self.connection is not None) + + def test_connecting_annoyingly(self): + self.assert_connection_works(MySQLdb.Connect(**self._auth)) + self.assert_connection_works(MySQLdb.Connection(**self._auth)) + self.assert_connection_works(MySQLdb.connections.Connection(**self._auth)) + + def test_create_cursor(self): + cursor = self.connection.cursor() + cursor.close() + + def test_run_query(self): + cursor = self.connection.cursor() + self.assert_cursor_works(cursor) + cursor.close() + + def test_run_bad_query(self): + cursor = self.connection.cursor() + try: + cursor.execute("garbage blah blah") + self.assert_(False) + except AssertionError: + raise + except Exception: + pass + cursor.close() + + def fill_up_table(self, conn): + curs = conn.cursor() + for i in range(1000): + curs.execute('insert into test_table (value_int) values (%s)' % i) + conn.commit() + + def test_yields(self): + conn = self.connection + self.set_up_dummy_table(conn) + self.fill_up_table(conn) + curs = conn.cursor() + results = [] + SHORT_QUERY = "select * from test_table" + evt = event.Event() + def a_query(): + self.assert_cursor_works(curs) + curs.execute(SHORT_QUERY) + results.append(2) + evt.send() + eventlet.spawn(a_query) + results.append(1) + self.assertEqual([1], results) + evt.wait() + self.assertEqual([1, 2], results) + + def test_visibility_from_other_connections(self): + conn = MySQLdb.connect(**self._auth) + conn2 = MySQLdb.connect(**self._auth) + curs = conn.cursor() + try: + curs2 = conn2.cursor() + curs2.execute("insert into gargleblatz (a) values (%s)" % (314159)) + self.assertEqual(curs2.rowcount, 1) + conn2.commit() + selection_query = "select * from gargleblatz" + curs2.execute(selection_query) + self.assertEqual(curs2.rowcount, 1) + del curs2, conn2 + # create a new connection, it should see the addition + conn3 = MySQLdb.connect(**self._auth) + curs3 = conn3.cursor() + curs3.execute(selection_query) + self.assertEqual(curs3.rowcount, 1) + # now, does the already-open connection see it? + curs.execute(selection_query) + self.assertEqual(curs.rowcount, 1) + del curs3, conn3 + finally: + # clean up my litter + curs.execute("delete from gargleblatz where a=314159") + conn.commit() + +from tests import patcher_test + +class MonkeyPatchTester(patcher_test.ProcessBase): + @skip_unless(mysql_requirement) + def test_monkey_patching(self): + output, lines = self.run_script(""" +from eventlet import patcher +import MySQLdb as m +from eventlet.green import MySQLdb as gm +patcher.monkey_patch(all=True, MySQLdb=True) +print "mysqltest", ",".join(sorted(patcher.already_patched.keys())) +print "connect", m.connect == gm.connect +""") + self.assertEqual(len(lines), 3) + self.assertEqual(lines[0].replace("psycopg", ""), + 'mysqltest MySQLdb,os,select,socket,thread,time') + self.assertEqual(lines[1], "connect True") diff --git a/tests/patcher_test.py b/tests/patcher_test.py index f409e86..466b385 100644 --- a/tests/patcher_test.py +++ b/tests/patcher_test.py @@ -47,6 +47,8 @@ class ProcessBase(LimitedTestCase): python_path = os.pathsep.join(sys.path + [self.tempdir]) new_env = os.environ.copy() new_env['PYTHONPATH'] = python_path + if not filename.endswith('.py'): + filename = filename + '.py' p = subprocess.Popen([sys.executable, os.path.join(self.tempdir, filename)], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=new_env) @@ -54,6 +56,12 @@ class ProcessBase(LimitedTestCase): lines = output.split("\n") return output, lines + def run_script(self, contents, modname=None): + if modname is None: + modname = "testmod" + self.write_to_tempfile(modname, contents) + return self.launch_subprocess(modname) + class ImportPatched(ProcessBase): def test_patch_a_module(self): @@ -157,6 +165,8 @@ print "already_patched", ",".join(sorted(patcher.already_patched.keys())) patched_modules = lines[0][len(ap):].strip() # psycopg might or might not be patched based on installed modules patched_modules = patched_modules.replace("psycopg,", "") + # ditto for MySQLdb + patched_modules = patched_modules.replace("MySQLdb,", "") self.assertEqual(patched_modules, expected, "Logic:%s\nExpected: %s != %s" %(call, expected, patched_modules))