diff --git a/tests/__init__.py b/tests/__init__.py index 84f21ee..0ed0338 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -155,3 +155,38 @@ def silence_warnings(func): warnings.simplefilter('default', DeprecationWarning) wrapper.__name__ = func.__name__ return wrapper + + +def get_database_auth(): + """Retrieves a dict of connection parameters for connecting to test databases. + + Authentication parameters are highly-machine specific, so + get_database_auth gets its information from either environment + variables or a config file. The environment variable is + "EVENTLET_DB_TEST_AUTH" and it should contain a json object. If + this environment variable is present, it's used and config files + are ignored. If it's not present, it looks in the local directory + (tests) and in the user's home directory for a file named + ".test_dbauth", which contains a json map of parameters to the + connect function. + """ + import os + import simplejson + if 'EVENTLET_DB_TEST_AUTH' in os.environ: + return simplejson.loads(os.environ.get('EVENTLET_DB_TEST_AUTH')) + files = [os.path.join(os.path.dirname(__file__), '.test_dbauth'), + os.path.join(os.path.expanduser('~'), '.test_dbauth')] + for f in files: + try: + auth_utf8 = simplejson.load(open(f)) + # have to convert unicode objects to str objects because mysqldb is dum + # using a doubly-nested list comprehension because we know that the structure + # of the structure is a two-level dict + return dict([(str(modname), dict([(str(k), str(v)) + for k, v in connectargs.items()])) + for modname, connectargs in auth_utf8.items()]) + except (IOError, ImportError): + pass + return {'MySQLdb':{'host': 'localhost','user': 'root','passwd': ''}, + 'psycopg2':{'user':'test'}} + diff --git a/tests/db_pool_test.py b/tests/db_pool_test.py index c0b1093..ed6f09b 100644 --- a/tests/db_pool_test.py +++ b/tests/db_pool_test.py @@ -4,7 +4,7 @@ import os import traceback from unittest import TestCase, main -from tests import skipped, skip_unless, skip_with_pyevent +from tests import skipped, skip_unless, skip_with_pyevent, get_database_auth from eventlet import event from eventlet import db_pool import eventlet @@ -472,29 +472,7 @@ class RawConnectionPool(DBConnectionPool): connect_timeout=connect_timeout, **self._auth) - -def get_auth(): - """Looks in the local directory and in the user's home directory - for a file named ".test_dbauth", which contains a json map of - parameters to the connect function. - """ - files = [os.path.join(os.path.dirname(__file__), '.test_dbauth'), - os.path.join(os.path.expanduser('~'), '.test_dbauth')] - for f in files: - try: - import simplejson - auth_utf8 = simplejson.load(open(f)) - # have to convert unicode objects to str objects because mysqldb is dum - # using a doubly-nested list comprehension because we know that the structure - # of the structure is a two-level dict - return dict([(str(modname), dict([(str(k), str(v)) - for k, v in connectargs.items()])) - for modname, connectargs in auth_utf8.items()]) - except (IOError, ImportError): - pass - return {'MySQLdb':{'host': 'localhost','user': 'root','passwd': ''}, - 'psycopg2':{'user':'test'}} - +get_auth = get_database_auth def mysql_requirement(_f): verbose = os.environ.get('eventlet_test_mysql_verbose') diff --git a/tests/patcher_psycopg_test.py b/tests/patcher_psycopg_test.py index 2bb6395..f63801e 100644 --- a/tests/patcher_psycopg_test.py +++ b/tests/patcher_psycopg_test.py @@ -1,6 +1,7 @@ import os from tests import patcher_test +from tests import get_database_auth psycopg_test_file = """ import os @@ -36,7 +37,13 @@ print "done" class PatchingPsycopg(patcher_test.Patcher): def test_psycopg_pached(self): if 'PSYCOPG_TEST_DSN' not in os.environ: - os.environ['PSYCOPG_TEST_DSN'] = 'dbname=postgres' + # construct a non-json dsn for the subprocess + psycopg_auth = get_database_auth()['psycopg2'] + if isinstance(psycopg_auth,str): + dsn = psycopg_auth + else: + dsn = " ".join(["%s=%s" % (k,v) for k,v, in psycopg_auth.iteritems()]) + os.environ['PSYCOPG_TEST_DSN'] = dsn self.write_to_tempfile("psycopg_patcher", psycopg_test_file) output, lines = self.launch_subprocess('psycopg_patcher.py') if lines[0].startswith('Psycopg not monkeypatched'):