Merge "Integrate urlparse for configuration augmentation"

This commit is contained in:
Jenkins
2014-05-01 06:54:52 +00:00
committed by Gerrit Code Review
7 changed files with 268 additions and 15 deletions

View File

@@ -6,6 +6,7 @@ module=importutils
module=jsonutils
module=timeutils
module=uuidutils
module=network_utils
script=tools/run_cross_tests.sh

View File

@@ -21,6 +21,7 @@ import stevedore.driver
from taskflow.openstack.common import importutils
from taskflow.persistence import backends as p_backends
from taskflow.utils import misc
from taskflow.utils import persistence_utils as p_utils
from taskflow.utils import reflection
@@ -80,6 +81,15 @@ def load(flow, store=None, flow_detail=None, book=None,
if isinstance(engine_conf, six.string_types):
engine_conf = {'engine': engine_conf}
engine_name = engine_conf['engine']
try:
pieces = misc.parse_uri(engine_name)
except (TypeError, ValueError):
pass
else:
engine_name = pieces['scheme']
engine_conf = misc.merge_uri(pieces, engine_conf.copy())
if isinstance(backend, dict):
backend = p_backends.fetch(backend)
@@ -88,7 +98,7 @@ def load(flow, store=None, flow_detail=None, book=None,
backend=backend)
mgr = stevedore.driver.DriverManager(
namespace, engine_conf['engine'],
namespace, engine_name,
invoke_on_load=True,
invoke_kwds={
'conf': engine_conf.copy(),

View File

@@ -20,6 +20,7 @@ import six
from stevedore import driver
from taskflow import exceptions as exc
from taskflow.utils import misc
# NOTE(harlowja): this is the entrypoint namespace, not the module namespace.
@@ -33,11 +34,16 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs):
specific kwargs) in the given entrypoint namespace and create it with the
given name.
"""
# NOTE(harlowja): this allows simpler syntax.
if isinstance(conf, six.string_types):
conf = {'board': conf}
board = conf['board']
try:
pieces = misc.parse_uri(board)
except (TypeError, ValueError):
pass
else:
board = pieces['scheme']
conf = misc.merge_uri(pieces, conf.copy())
LOG.debug('Looking for %r jobboard driver in %r', board, namespace)
try:
mgr = driver.DriverManager(namespace, board,

View File

@@ -0,0 +1,108 @@
# Copyright 2012 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Network-related utilities and helper functions.
"""
# TODO(jd) Use six.moves once
# https://bitbucket.org/gutworth/six/pull-request/28
# is merged
try:
import urllib.parse
SplitResult = urllib.parse.SplitResult
except ImportError:
import urlparse
SplitResult = urlparse.SplitResult
from six.moves.urllib import parse
def parse_host_port(address, default_port=None):
"""Interpret a string as a host:port pair.
An IPv6 address MUST be escaped if accompanied by a port,
because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334
means both [2001:db8:85a3::8a2e:370:7334] and
[2001:db8:85a3::8a2e:370]:7334.
>>> parse_host_port('server01:80')
('server01', 80)
>>> parse_host_port('server01')
('server01', None)
>>> parse_host_port('server01', default_port=1234)
('server01', 1234)
>>> parse_host_port('[::1]:80')
('::1', 80)
>>> parse_host_port('[::1]')
('::1', None)
>>> parse_host_port('[::1]', default_port=1234)
('::1', 1234)
>>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234)
('2001:db8:85a3::8a2e:370:7334', 1234)
"""
if address[0] == '[':
# Escaped ipv6
_host, _port = address[1:].split(']')
host = _host
if ':' in _port:
port = _port.split(':')[1]
else:
port = default_port
else:
if address.count(':') == 1:
host, port = address.split(':')
else:
# 0 means ipv4, >1 means ipv6.
# We prohibit unescaped ipv6 addresses with port.
host = address
port = default_port
return (host, None if port is None else int(port))
class ModifiedSplitResult(SplitResult):
"""Split results class for urlsplit."""
# NOTE(dims): The functions below are needed for Python 2.6.x.
# We can remove these when we drop support for 2.6.x.
@property
def hostname(self):
netloc = self.netloc.split('@', 1)[-1]
host, port = parse_host_port(netloc)
return host
@property
def port(self):
netloc = self.netloc.split('@', 1)[-1]
host, port = parse_host_port(netloc)
return port
def urlsplit(url, scheme='', allow_fragments=True):
"""Parse a URL using urlparse.urlsplit(), splitting query and fragments.
This function papers over Python issue9374 when needed.
The parameters are the same as urlparse.urlsplit.
"""
scheme, netloc, path, query, fragment = parse.urlsplit(
url, scheme, allow_fragments)
if allow_fragments and '#' in path:
path, fragment = path.split('#', 1)
if '?' in path:
path, query = path.split('?', 1)
return ModifiedSplitResult(scheme, netloc,
path, query, fragment)

View File

@@ -15,20 +15,16 @@
# under the License.
import logging
import re
from stevedore import driver
from taskflow import exceptions as exc
from taskflow.utils import misc
# NOTE(harlowja): this is the entrypoint namespace, not the module namespace.
BACKEND_NAMESPACE = 'taskflow.persistence'
# NOTE(imelnikov): regular expression to get scheme from URI,
# see RFC 3986 section 3.1
SCHEME_REGEX = re.compile(r"^([A-Za-z]{1}[A-Za-z0-9+.-]*):")
LOG = logging.getLogger(__name__)
@@ -36,14 +32,14 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs):
"""Fetches a given backend using the given configuration (and any backend
specific kwargs) in the given entrypoint namespace.
"""
connection = conf['connection']
match = SCHEME_REGEX.match(connection)
if match:
backend_name = match.group(1)
backend_name = conf['connection']
try:
pieces = misc.parse_uri(backend_name)
except (TypeError, ValueError):
pass
else:
backend_name = connection
backend_name = pieces['scheme']
conf = misc.merge_uri(pieces, conf.copy())
LOG.debug('Looking for %r backend driver in %r', backend_name, namespace)
try:
mgr = driver.DriverManager(namespace, backend_name,

View File

@@ -456,6 +456,51 @@ class StopWatchUtilsTest(test.TestCase):
self.assertGreater(0.01, watch.elapsed())
class UriParseTest(test.TestCase):
def test_parse(self):
url = "zookeeper://192.168.0.1:2181/a/b/?c=d"
parsed = misc.parse_uri(url)
self.assertEqual('zookeeper', parsed.scheme)
self.assertEqual(2181, parsed.port)
self.assertEqual('192.168.0.1', parsed.hostname)
self.assertEqual('', parsed.fragment)
self.assertEqual('/a/b/', parsed.path)
self.assertEqual({'c': 'd'}, parsed.params)
def test_multi_params(self):
url = "mysql://www.yahoo.com:3306/a/b/?c=d&c=e"
parsed = misc.parse_uri(url, query_duplicates=True)
self.assertEqual({'c': ['d', 'e']}, parsed.params)
def test_port_provided(self):
url = "rabbitmq://www.yahoo.com:5672"
parsed = misc.parse_uri(url)
self.assertEqual('rabbitmq', parsed.scheme)
self.assertEqual('www.yahoo.com', parsed.hostname)
self.assertEqual(5672, parsed.port)
self.assertEqual('', parsed.path)
def test_ipv6_host(self):
url = "rsync://[2001:db8:0:1]:873"
parsed = misc.parse_uri(url)
self.assertEqual('rsync', parsed.scheme)
self.assertEqual('2001:db8:0:1', parsed.hostname)
self.assertEqual(873, parsed.port)
def test_user_password(self):
url = "rsync://test:test_pw@www.yahoo.com:873"
parsed = misc.parse_uri(url)
self.assertEqual('test', parsed.username)
self.assertEqual('test_pw', parsed.password)
self.assertEqual('www.yahoo.com', parsed.hostname)
def test_user(self):
url = "rsync://test@www.yahoo.com:873"
parsed = misc.parse_uri(url)
self.assertEqual('test', parsed.username)
self.assertEqual(None, parsed.password)
class ExcInfoUtilsTest(test.TestCase):
def _make_ex_info(self):

View File

@@ -24,6 +24,7 @@ import functools
import keyword
import logging
import os
import re
import string
import sys
import threading
@@ -31,15 +32,101 @@ import time
import traceback
import six
from six.moves.urllib import parse as urlparse
from taskflow import exceptions as exc
from taskflow.openstack.common import jsonutils
from taskflow.openstack.common import network_utils
from taskflow.utils import reflection
LOG = logging.getLogger(__name__)
NUMERIC_TYPES = six.integer_types + (float,)
# NOTE(imelnikov): regular expression to get scheme from URI,
# see RFC 3986 section 3.1
_SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):")
def merge_uri(uri_pieces, conf):
"""Merges the username, password, hostname, and query params of a uri into
the given configuration (does not overwrite the configuration keys if they
already exist) and returns the adjusted configuration.
NOTE(harlowja): does not merge the path, scheme or fragment.
"""
for k in ('username', 'password'):
if not uri_pieces[k]:
continue
conf.setdefault(k, uri_pieces[k])
hostname = uri_pieces.get('hostname')
if hostname:
port = uri_pieces.get('port')
if port is not None:
hostname += ":%s" % (port)
conf.setdefault('hostname', hostname)
for (k, v) in six.iteritems(uri_pieces['params']):
conf.setdefault(k, v)
return conf
def parse_uri(uri, query_duplicates=False):
"""Parses a uri into its components and returns a dictionary containing
those components.
"""
# Do some basic validation before continuing...
if not isinstance(uri, six.string_types):
raise TypeError("Can only parse string types to uri data, "
"and not an object of type %s"
% reflection.get_class_name(uri))
match = _SCHEME_REGEX.match(uri)
if not match:
raise ValueError("Uri %r does not start with a RFC 3986 compliant"
" scheme" % (uri))
parsed = network_utils.urlsplit(uri)
if parsed.query:
query_params = urlparse.parse_qsl(parsed.query)
if not query_duplicates:
query_params = dict(query_params)
else:
# Retain duplicates in a list for keys which have duplicates, but
# for items which are not duplicated, just associate the key with
# the value.
tmp_query_params = {}
for (k, v) in query_params:
if k in tmp_query_params:
p_v = tmp_query_params[k]
if isinstance(p_v, list):
p_v.append(v)
else:
p_v = [p_v, v]
tmp_query_params[k] = p_v
else:
tmp_query_params[k] = v
query_params = tmp_query_params
else:
query_params = {}
uri_pieces = {
'scheme': parsed.scheme,
'username': parsed.username,
'password': parsed.password,
'fragment': parsed.fragment,
'path': parsed.path,
'params': query_params,
}
for k in ('hostname', 'port'):
try:
uri_pieces[k] = getattr(parsed, k)
except (IndexError, ValueError):
# The underlying network_utils throws when the host string is empty
# which it may be in cases where it is not provided.
#
# NOTE(harlowja): when https://review.openstack.org/#/c/86921/ gets
# merged we can just remove this since that error will no longer
# occur.
uri_pieces[k] = None
return AttrDict(**uri_pieces)
def binary_encode(text, encoding='utf-8'):
"""Converts a string of into a binary type using given encoding.