diff --git a/taskflow/jobs/backends/__init__.py b/taskflow/jobs/backends/__init__.py index f818905f..ebbe62fb 100644 --- a/taskflow/jobs/backends/__init__.py +++ b/taskflow/jobs/backends/__init__.py @@ -16,7 +16,6 @@ import contextlib -import six from stevedore import driver from taskflow import exceptions as exc @@ -51,16 +50,7 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs): is ``{'a': 'b', 'c': 'd'}`` to the constructor of that board instance (also including the name specified). """ - if isinstance(conf, six.string_types): - conf = {'board': conf} - board = conf['board'] - try: - uri = misc.parse_uri(board) - except (TypeError, ValueError): - pass - else: - board = uri.scheme - conf = misc.merge_uri(uri, conf.copy()) + board, conf = misc.extract_driver_and_conf(conf, 'board') LOG.debug('Looking for %r jobboard driver in %r', board, namespace) try: mgr = driver.DriverManager(namespace, board, diff --git a/taskflow/persistence/backends/__init__.py b/taskflow/persistence/backends/__init__.py index 50f24167..294febe1 100644 --- a/taskflow/persistence/backends/__init__.py +++ b/taskflow/persistence/backends/__init__.py @@ -16,7 +16,6 @@ import contextlib -import six from stevedore import driver from taskflow import exceptions as exc @@ -51,30 +50,21 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs): is ``{'a': 'b', 'c': 'd'}`` to the constructor of that persistence backend instance. """ - if isinstance(conf, six.string_types): - conf = {'connection': conf} - backend_name = conf['connection'] - try: - uri = misc.parse_uri(backend_name) - except (TypeError, ValueError): - pass - else: - backend_name = uri.scheme - conf = misc.merge_uri(uri, conf.copy()) + backend, conf = misc.extract_driver_and_conf(conf, 'connection') # If the backend is like 'mysql+pymysql://...' which informs the # backend to use a dialect (supported by sqlalchemy at least) we just want # to look at the first component to find our entrypoint backend name... - if backend_name.find("+") != -1: - backend_name = backend_name.split("+", 1)[0] - LOG.debug('Looking for %r backend driver in %r', backend_name, namespace) + if backend.find("+") != -1: + backend = backend.split("+", 1)[0] + LOG.debug('Looking for %r backend driver in %r', backend, namespace) try: - mgr = driver.DriverManager(namespace, backend_name, + mgr = driver.DriverManager(namespace, backend, invoke_on_load=True, invoke_args=(conf,), invoke_kwds=kwargs) return mgr.driver except RuntimeError as e: - raise exc.NotFound("Could not find backend %s: %s" % (backend_name, e)) + raise exc.NotFound("Could not find backend %s: %s" % (backend, e)) @contextlib.contextmanager diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index 1b804400..eb32abd2 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -112,6 +112,19 @@ def countdown_iter(start_at, decr=1): start_at -= decr +def extract_driver_and_conf(conf, conf_key): + """Common function to get a driver name and its configuration.""" + if isinstance(conf, six.string_types): + conf = {conf_key: conf} + maybe_uri = conf[conf_key] + try: + uri = parse_uri(maybe_uri) + except (TypeError, ValueError): + return (maybe_uri, conf) + else: + return (uri.scheme, merge_uri(uri, conf.copy())) + + def reverse_enumerate(items): """Like reversed(enumerate(items)) but with less copying/cloning...""" for i in countdown_iter(len(items)):