From 81505533675d3edff88ab0dc1ac04b0a558bbbbe Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Thu, 4 Dec 2014 16:11:56 -0800 Subject: [PATCH] Fix split on "+" for connection strings that specify dialects Fixes bug 1399486 Change-Id: I3b7e6331751f25d9c4221393e8329934925791e7 --- taskflow/persistence/backends/__init__.py | 5 +++++ .../unit/persistence/test_sql_persistence.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/taskflow/persistence/backends/__init__.py b/taskflow/persistence/backends/__init__.py index 64b7cda1..30279b90 100644 --- a/taskflow/persistence/backends/__init__.py +++ b/taskflow/persistence/backends/__init__.py @@ -58,6 +58,11 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs): else: backend_name = uri.scheme conf = misc.merge_uri(uri, conf.copy()) + # If the backend is like 'mysql+pymysql://...' which informs the + # backend to use a dialect (supported by sqlalchemy at least) we just want + # to look at the first component to find our entrypoint backend name... + if backend_name.find("+") != -1: + backend_name = backend_name.split("+", 1)[0] LOG.debug('Looking for %r backend driver in %r', backend_name, namespace) try: mgr = driver.DriverManager(namespace, backend_name, diff --git a/taskflow/tests/unit/persistence/test_sql_persistence.py b/taskflow/tests/unit/persistence/test_sql_persistence.py index 229ef310..8489160d 100644 --- a/taskflow/tests/unit/persistence/test_sql_persistence.py +++ b/taskflow/tests/unit/persistence/test_sql_persistence.py @@ -145,6 +145,14 @@ class BackendPersistenceTestMixin(base.PersistenceTestMixin): def _get_connection(self): return self.backend.get_connection() + def test_entrypoint(self): + # Test that the entrypoint fetching also works (even with dialects) + # using the same configuration we used in setUp() but not using + # the impl_sqlalchemy SQLAlchemyBackend class directly... + with contextlib.closing(backends.fetch(self.db_conf)) as backend: + with contextlib.closing(backend.get_connection()): + pass + @abc.abstractmethod def _init_db(self): """Sets up the database, and returns the uri to that database.""" @@ -158,17 +166,17 @@ class BackendPersistenceTestMixin(base.PersistenceTestMixin): self.backend = None try: self.db_uri = self._init_db() + self.db_conf = { + 'connection': self.db_uri + } # Since we are using random database names, we need to make sure # and remove our random database when we are done testing. self.addCleanup(self._remove_db) - conf = { - 'connection': self.db_uri - } except Exception as e: self.skipTest("Failed to create temporary database;" " testing being skipped due to: %s" % (e)) try: - self.backend = impl_sqlalchemy.SQLAlchemyBackend(conf) + self.backend = impl_sqlalchemy.SQLAlchemyBackend(self.db_conf) self.addCleanup(self.backend.close) with contextlib.closing(self._get_connection()) as conn: conn.upgrade()