diff --git a/openstack-common.conf b/openstack-common.conf index 24d2fc0f..8adc059c 100644 --- a/openstack-common.conf +++ b/openstack-common.conf @@ -6,6 +6,7 @@ module=importutils module=jsonutils module=timeutils module=uuidutils +module=network_utils script=tools/run_cross_tests.sh diff --git a/taskflow/engines/helpers.py b/taskflow/engines/helpers.py index a4435487..876577ff 100644 --- a/taskflow/engines/helpers.py +++ b/taskflow/engines/helpers.py @@ -21,6 +21,7 @@ import stevedore.driver from taskflow.openstack.common import importutils from taskflow.persistence import backends as p_backends +from taskflow.utils import misc from taskflow.utils import persistence_utils as p_utils from taskflow.utils import reflection @@ -80,6 +81,15 @@ def load(flow, store=None, flow_detail=None, book=None, if isinstance(engine_conf, six.string_types): engine_conf = {'engine': engine_conf} + engine_name = engine_conf['engine'] + try: + pieces = misc.parse_uri(engine_name) + except (TypeError, ValueError): + pass + else: + engine_name = pieces['scheme'] + engine_conf = misc.merge_uri(pieces, engine_conf.copy()) + if isinstance(backend, dict): backend = p_backends.fetch(backend) @@ -88,7 +98,7 @@ def load(flow, store=None, flow_detail=None, book=None, backend=backend) mgr = stevedore.driver.DriverManager( - namespace, engine_conf['engine'], + namespace, engine_name, invoke_on_load=True, invoke_kwds={ 'conf': engine_conf.copy(), diff --git a/taskflow/jobs/backends/__init__.py b/taskflow/jobs/backends/__init__.py index ad4dc060..b720024b 100644 --- a/taskflow/jobs/backends/__init__.py +++ b/taskflow/jobs/backends/__init__.py @@ -20,6 +20,7 @@ import six from stevedore import driver from taskflow import exceptions as exc +from taskflow.utils import misc # NOTE(harlowja): this is the entrypoint namespace, not the module namespace. @@ -33,11 +34,16 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs): specific kwargs) in the given entrypoint namespace and create it with the given name. """ - # NOTE(harlowja): this allows simpler syntax. if isinstance(conf, six.string_types): conf = {'board': conf} - board = conf['board'] + try: + pieces = misc.parse_uri(board) + except (TypeError, ValueError): + pass + else: + board = pieces['scheme'] + conf = misc.merge_uri(pieces, conf.copy()) LOG.debug('Looking for %r jobboard driver in %r', board, namespace) try: mgr = driver.DriverManager(namespace, board, diff --git a/taskflow/openstack/common/network_utils.py b/taskflow/openstack/common/network_utils.py new file mode 100644 index 00000000..fa812b29 --- /dev/null +++ b/taskflow/openstack/common/network_utils.py @@ -0,0 +1,108 @@ +# Copyright 2012 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Network-related utilities and helper functions. +""" + +# TODO(jd) Use six.moves once +# https://bitbucket.org/gutworth/six/pull-request/28 +# is merged +try: + import urllib.parse + SplitResult = urllib.parse.SplitResult +except ImportError: + import urlparse + SplitResult = urlparse.SplitResult + +from six.moves.urllib import parse + + +def parse_host_port(address, default_port=None): + """Interpret a string as a host:port pair. + + An IPv6 address MUST be escaped if accompanied by a port, + because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334 + means both [2001:db8:85a3::8a2e:370:7334] and + [2001:db8:85a3::8a2e:370]:7334. + + >>> parse_host_port('server01:80') + ('server01', 80) + >>> parse_host_port('server01') + ('server01', None) + >>> parse_host_port('server01', default_port=1234) + ('server01', 1234) + >>> parse_host_port('[::1]:80') + ('::1', 80) + >>> parse_host_port('[::1]') + ('::1', None) + >>> parse_host_port('[::1]', default_port=1234) + ('::1', 1234) + >>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234) + ('2001:db8:85a3::8a2e:370:7334', 1234) + + """ + if address[0] == '[': + # Escaped ipv6 + _host, _port = address[1:].split(']') + host = _host + if ':' in _port: + port = _port.split(':')[1] + else: + port = default_port + else: + if address.count(':') == 1: + host, port = address.split(':') + else: + # 0 means ipv4, >1 means ipv6. + # We prohibit unescaped ipv6 addresses with port. + host = address + port = default_port + + return (host, None if port is None else int(port)) + + +class ModifiedSplitResult(SplitResult): + """Split results class for urlsplit.""" + + # NOTE(dims): The functions below are needed for Python 2.6.x. + # We can remove these when we drop support for 2.6.x. + @property + def hostname(self): + netloc = self.netloc.split('@', 1)[-1] + host, port = parse_host_port(netloc) + return host + + @property + def port(self): + netloc = self.netloc.split('@', 1)[-1] + host, port = parse_host_port(netloc) + return port + + +def urlsplit(url, scheme='', allow_fragments=True): + """Parse a URL using urlparse.urlsplit(), splitting query and fragments. + This function papers over Python issue9374 when needed. + + The parameters are the same as urlparse.urlsplit. + """ + scheme, netloc, path, query, fragment = parse.urlsplit( + url, scheme, allow_fragments) + if allow_fragments and '#' in path: + path, fragment = path.split('#', 1) + if '?' in path: + path, query = path.split('?', 1) + return ModifiedSplitResult(scheme, netloc, + path, query, fragment) diff --git a/taskflow/persistence/backends/__init__.py b/taskflow/persistence/backends/__init__.py index 5cf30243..f89e60d4 100644 --- a/taskflow/persistence/backends/__init__.py +++ b/taskflow/persistence/backends/__init__.py @@ -15,20 +15,16 @@ # under the License. import logging -import re from stevedore import driver from taskflow import exceptions as exc +from taskflow.utils import misc # NOTE(harlowja): this is the entrypoint namespace, not the module namespace. BACKEND_NAMESPACE = 'taskflow.persistence' -# NOTE(imelnikov): regular expression to get scheme from URI, -# see RFC 3986 section 3.1 -SCHEME_REGEX = re.compile(r"^([A-Za-z]{1}[A-Za-z0-9+.-]*):") - LOG = logging.getLogger(__name__) @@ -36,14 +32,14 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs): """Fetches a given backend using the given configuration (and any backend specific kwargs) in the given entrypoint namespace. """ - connection = conf['connection'] - - match = SCHEME_REGEX.match(connection) - if match: - backend_name = match.group(1) + backend_name = conf['connection'] + try: + pieces = misc.parse_uri(backend_name) + except (TypeError, ValueError): + pass else: - backend_name = connection - + backend_name = pieces['scheme'] + conf = misc.merge_uri(pieces, conf.copy()) LOG.debug('Looking for %r backend driver in %r', backend_name, namespace) try: mgr = driver.DriverManager(namespace, backend_name, diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index d0bb0695..fbdf3ec6 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -456,6 +456,51 @@ class StopWatchUtilsTest(test.TestCase): self.assertGreater(0.01, watch.elapsed()) +class UriParseTest(test.TestCase): + def test_parse(self): + url = "zookeeper://192.168.0.1:2181/a/b/?c=d" + parsed = misc.parse_uri(url) + self.assertEqual('zookeeper', parsed.scheme) + self.assertEqual(2181, parsed.port) + self.assertEqual('192.168.0.1', parsed.hostname) + self.assertEqual('', parsed.fragment) + self.assertEqual('/a/b/', parsed.path) + self.assertEqual({'c': 'd'}, parsed.params) + + def test_multi_params(self): + url = "mysql://www.yahoo.com:3306/a/b/?c=d&c=e" + parsed = misc.parse_uri(url, query_duplicates=True) + self.assertEqual({'c': ['d', 'e']}, parsed.params) + + def test_port_provided(self): + url = "rabbitmq://www.yahoo.com:5672" + parsed = misc.parse_uri(url) + self.assertEqual('rabbitmq', parsed.scheme) + self.assertEqual('www.yahoo.com', parsed.hostname) + self.assertEqual(5672, parsed.port) + self.assertEqual('', parsed.path) + + def test_ipv6_host(self): + url = "rsync://[2001:db8:0:1]:873" + parsed = misc.parse_uri(url) + self.assertEqual('rsync', parsed.scheme) + self.assertEqual('2001:db8:0:1', parsed.hostname) + self.assertEqual(873, parsed.port) + + def test_user_password(self): + url = "rsync://test:test_pw@www.yahoo.com:873" + parsed = misc.parse_uri(url) + self.assertEqual('test', parsed.username) + self.assertEqual('test_pw', parsed.password) + self.assertEqual('www.yahoo.com', parsed.hostname) + + def test_user(self): + url = "rsync://test@www.yahoo.com:873" + parsed = misc.parse_uri(url) + self.assertEqual('test', parsed.username) + self.assertEqual(None, parsed.password) + + class ExcInfoUtilsTest(test.TestCase): def _make_ex_info(self): diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index 0e3a1c3d..0a592689 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -24,6 +24,7 @@ import functools import keyword import logging import os +import re import string import sys import threading @@ -31,15 +32,101 @@ import time import traceback import six +from six.moves.urllib import parse as urlparse from taskflow import exceptions as exc from taskflow.openstack.common import jsonutils +from taskflow.openstack.common import network_utils from taskflow.utils import reflection LOG = logging.getLogger(__name__) NUMERIC_TYPES = six.integer_types + (float,) +# NOTE(imelnikov): regular expression to get scheme from URI, +# see RFC 3986 section 3.1 +_SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):") + + +def merge_uri(uri_pieces, conf): + """Merges the username, password, hostname, and query params of a uri into + the given configuration (does not overwrite the configuration keys if they + already exist) and returns the adjusted configuration. + + NOTE(harlowja): does not merge the path, scheme or fragment. + """ + for k in ('username', 'password'): + if not uri_pieces[k]: + continue + conf.setdefault(k, uri_pieces[k]) + hostname = uri_pieces.get('hostname') + if hostname: + port = uri_pieces.get('port') + if port is not None: + hostname += ":%s" % (port) + conf.setdefault('hostname', hostname) + for (k, v) in six.iteritems(uri_pieces['params']): + conf.setdefault(k, v) + return conf + + +def parse_uri(uri, query_duplicates=False): + """Parses a uri into its components and returns a dictionary containing + those components. + """ + # Do some basic validation before continuing... + if not isinstance(uri, six.string_types): + raise TypeError("Can only parse string types to uri data, " + "and not an object of type %s" + % reflection.get_class_name(uri)) + match = _SCHEME_REGEX.match(uri) + if not match: + raise ValueError("Uri %r does not start with a RFC 3986 compliant" + " scheme" % (uri)) + parsed = network_utils.urlsplit(uri) + if parsed.query: + query_params = urlparse.parse_qsl(parsed.query) + if not query_duplicates: + query_params = dict(query_params) + else: + # Retain duplicates in a list for keys which have duplicates, but + # for items which are not duplicated, just associate the key with + # the value. + tmp_query_params = {} + for (k, v) in query_params: + if k in tmp_query_params: + p_v = tmp_query_params[k] + if isinstance(p_v, list): + p_v.append(v) + else: + p_v = [p_v, v] + tmp_query_params[k] = p_v + else: + tmp_query_params[k] = v + query_params = tmp_query_params + else: + query_params = {} + uri_pieces = { + 'scheme': parsed.scheme, + 'username': parsed.username, + 'password': parsed.password, + 'fragment': parsed.fragment, + 'path': parsed.path, + 'params': query_params, + } + for k in ('hostname', 'port'): + try: + uri_pieces[k] = getattr(parsed, k) + except (IndexError, ValueError): + # The underlying network_utils throws when the host string is empty + # which it may be in cases where it is not provided. + # + # NOTE(harlowja): when https://review.openstack.org/#/c/86921/ gets + # merged we can just remove this since that error will no longer + # occur. + uri_pieces[k] = None + return AttrDict(**uri_pieces) + def binary_encode(text, encoding='utf-8'): """Converts a string of into a binary type using given encoding.