Merge "Remove attrdict and just use existing types"

This commit is contained in:
Jenkins
2014-11-19 19:01:31 +00:00
committed by Gerrit Code Review
8 changed files with 74 additions and 236 deletions

View File

@@ -54,12 +54,12 @@ def _extract_engine(**kwargs):
kind = ENGINE_DEFAULT kind = ENGINE_DEFAULT
# See if it's a URI and if so, extract any further options... # See if it's a URI and if so, extract any further options...
try: try:
pieces = misc.parse_uri(kind) uri = misc.parse_uri(kind)
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
else: else:
kind = pieces['scheme'] kind = uri.scheme
options = misc.merge_uri(pieces, options.copy()) options = misc.merge_uri(uri, options.copy())
# Merge in any leftover **kwargs into the options, this makes it so that # Merge in any leftover **kwargs into the options, this makes it so that
# the provided **kwargs override any URI or engine_conf specific options. # the provided **kwargs override any URI or engine_conf specific options.
options.update(kwargs) options.update(kwargs)

View File

@@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
import logging import logging
import socket import socket
@@ -21,7 +22,6 @@ import kombu
import six import six
from taskflow.engines.worker_based import dispatcher from taskflow.engines.worker_based import dispatcher
from taskflow.utils import misc
from taskflow.utils import threading_utils from taskflow.utils import threading_utils
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -30,6 +30,16 @@ LOG = logging.getLogger(__name__)
# the socket can get "stuck", and is a best practice for Kombu consumers. # the socket can get "stuck", and is a best practice for Kombu consumers.
DRAIN_EVENTS_PERIOD = 1 DRAIN_EVENTS_PERIOD = 1
# Helper objects returned when requested to get connection details, used
# instead of returning the raw results from the kombu connection objects
# themselves so that a person can not mutate those objects (which would be
# bad).
_ConnectionDetails = collections.namedtuple('_ConnectionDetails',
['uri', 'transport'])
_TransportDetails = collections.namedtuple('_TransportDetails',
['options', 'driver_type',
'driver_name', 'driver_version'])
class Proxy(object): class Proxy(object):
"""A proxy processes messages from/to the named exchange.""" """A proxy processes messages from/to the named exchange."""
@@ -71,13 +81,18 @@ class Proxy(object):
driver_version = self._conn.transport.driver_version() driver_version = self._conn.transport.driver_version()
if driver_version and driver_version.lower() == 'n/a': if driver_version and driver_version.lower() == 'n/a':
driver_version = None driver_version = None
return misc.AttrDict( if self._conn.transport_options:
uri=self._conn.as_uri(include_password=False), transport_options = self._conn.transport_options.copy()
transport=misc.AttrDict( else:
options=dict(self._conn.transport_options), transport_options = {}
transport = _TransportDetails(
options=transport_options,
driver_type=self._conn.transport.driver_type, driver_type=self._conn.transport.driver_type,
driver_name=self._conn.transport.driver_name, driver_name=self._conn.transport.driver_name,
driver_version=driver_version)) driver_version=driver_version)
return _ConnectionDetails(
uri=self._conn.as_uri(include_password=False),
transport=transport)
@property @property
def is_running(self): def is_running(self):

View File

@@ -148,6 +148,12 @@ class DeclareSuccess(task.Task):
print("All data processed and sent to %s" % (sent_to)) print("All data processed and sent to %s" % (sent_to))
class DummyUser(object):
def __init__(self, user, id):
self.user = user
self.id = id
# Resources (db handles and similar) of course can *not* be persisted so we # Resources (db handles and similar) of course can *not* be persisted so we
# need to make sure that we pass this resource fetcher to the tasks constructor # need to make sure that we pass this resource fetcher to the tasks constructor
# so that the tasks have access to any needed resources (the resources are # so that the tasks have access to any needed resources (the resources are
@@ -168,7 +174,7 @@ flow.add(sub_flow)
# prepopulating this allows the tasks that dependent on the 'request' variable # prepopulating this allows the tasks that dependent on the 'request' variable
# to start processing (in this case this is the ExtractInputRequest task). # to start processing (in this case this is the ExtractInputRequest task).
store = { store = {
'request': misc.AttrDict(user="bob", id="1.35"), 'request': DummyUser(user="bob", id="1.35"),
} }
eng = engines.load(flow, engine='serial', store=store) eng = engines.load(flow, engine='serial', store=store)

View File

@@ -55,12 +55,12 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs):
conf = {'board': conf} conf = {'board': conf}
board = conf['board'] board = conf['board']
try: try:
pieces = misc.parse_uri(board) uri = misc.parse_uri(board)
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
else: else:
board = pieces['scheme'] board = uri.scheme
conf = misc.merge_uri(pieces, conf.copy()) conf = misc.merge_uri(uri, conf.copy())
LOG.debug('Looking for %r jobboard driver in %r', board, namespace) LOG.debug('Looking for %r jobboard driver in %r', board, namespace)
try: try:
mgr = driver.DriverManager(namespace, board, mgr = driver.DriverManager(namespace, board,

View File

@@ -52,12 +52,12 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs):
""" """
backend_name = conf['connection'] backend_name = conf['connection']
try: try:
pieces = misc.parse_uri(backend_name) uri = misc.parse_uri(backend_name)
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
else: else:
backend_name = pieces['scheme'] backend_name = uri.scheme
conf = misc.merge_uri(pieces, conf.copy()) conf = misc.merge_uri(uri, conf.copy())
LOG.debug('Looking for %r backend driver in %r', backend_name, namespace) LOG.debug('Looking for %r backend driver in %r', backend_name, namespace)
try: try:
mgr = driver.DriverManager(namespace, backend_name, mgr = driver.DriverManager(namespace, backend_name,

View File

@@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
import contextlib import contextlib
import threading import threading
@@ -28,7 +29,6 @@ from taskflow.persistence.backends import impl_memory
from taskflow import states as st from taskflow import states as st
from taskflow import test from taskflow import test
from taskflow.tests import utils as test_utils from taskflow.tests import utils as test_utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu from taskflow.utils import persistence_utils as pu
from taskflow.utils import threading_utils from taskflow.utils import threading_utils
@@ -58,6 +58,10 @@ def make_thread(conductor):
class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase): class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
ComponentBundle = collections.namedtuple('ComponentBundle',
['board', 'client',
'persistence', 'conductor'])
def make_components(self, name='testing', wait_timeout=0.1): def make_components(self, name='testing', wait_timeout=0.1):
client = fake_client.FakeClient() client = fake_client.FakeClient()
persistence = impl_memory.MemoryBackend() persistence = impl_memory.MemoryBackend()
@@ -66,10 +70,7 @@ class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
persistence=persistence) persistence=persistence)
conductor = stc.SingleThreadedConductor(name, board, persistence, conductor = stc.SingleThreadedConductor(name, board, persistence,
wait_timeout=wait_timeout) wait_timeout=wait_timeout)
return misc.AttrDict(board=board, return self.ComponentBundle(board, client, persistence, conductor)
client=client,
persistence=persistence,
conductor=conductor)
def test_connection(self): def test_connection(self):
components = self.make_components() components = self.make_components()

View File

@@ -387,109 +387,6 @@ class CachedPropertyTest(test.TestCase):
self.assertEqual('b', a.b) self.assertEqual('b', a.b)
class AttrDictTest(test.TestCase):
def test_ok_create(self):
attrs = {
'a': 1,
'b': 2,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(obj.a, 1)
self.assertEqual(obj.b, 2)
def test_private_create(self):
attrs = {
'_a': 1,
}
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
def test_invalid_create(self):
attrs = {
# Python attributes can't start with a number.
'123_abc': 1,
}
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
def test_no_overwrite(self):
attrs = {
# Python attributes can't start with a number.
'update': 1,
}
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
def test_back_todict(self):
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(obj.a, 1)
self.assertEqual(attrs, dict(obj))
def test_runtime_invalid_set(self):
def bad_assign(obj):
obj._123 = 'b'
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(obj.a, 1)
self.assertRaises(AttributeError, bad_assign, obj)
def test_bypass_get(self):
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(1, obj['a'])
def test_bypass_set_no_get(self):
def bad_assign(obj):
obj._b = 'e'
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(1, obj['a'])
obj['_b'] = 'c'
self.assertRaises(AttributeError, bad_assign, obj)
self.assertEqual('c', obj['_b'])
class IsValidAttributeNameTestCase(test.TestCase):
def test_a_is_ok(self):
self.assertTrue(misc.is_valid_attribute_name('a'))
def test_name_can_be_longer(self):
self.assertTrue(misc.is_valid_attribute_name('foobarbaz'))
def test_name_can_have_digits(self):
self.assertTrue(misc.is_valid_attribute_name('fo12'))
def test_name_cannot_start_with_digit(self):
self.assertFalse(misc.is_valid_attribute_name('1z'))
def test_hidden_names_are_forbidden(self):
self.assertFalse(misc.is_valid_attribute_name('_z'))
def test_hidden_names_can_be_allowed(self):
self.assertTrue(
misc.is_valid_attribute_name('_z', allow_hidden=True))
def test_self_is_forbidden(self):
self.assertFalse(misc.is_valid_attribute_name('self'))
def test_self_can_be_allowed(self):
self.assertTrue(
misc.is_valid_attribute_name('self', allow_self=True))
def test_no_unicode_please(self):
self.assertFalse(misc.is_valid_attribute_name('mañana'))
class UriParseTest(test.TestCase): class UriParseTest(test.TestCase):
def test_parse(self): def test_parse(self):
url = "zookeeper://192.168.0.1:2181/a/b/?c=d" url = "zookeeper://192.168.0.1:2181/a/b/?c=d"
@@ -501,11 +398,6 @@ class UriParseTest(test.TestCase):
self.assertEqual('/a/b/', parsed.path) self.assertEqual('/a/b/', parsed.path)
self.assertEqual({'c': 'd'}, parsed.params) self.assertEqual({'c': 'd'}, parsed.params)
def test_multi_params(self):
url = "mysql://www.yahoo.com:3306/a/b/?c=d&c=e"
parsed = misc.parse_uri(url, query_duplicates=True)
self.assertEqual({'c': ['d', 'e']}, parsed.params)
def test_port_provided(self): def test_port_provided(self):
url = "rabbitmq://www.yahoo.com:5672" url = "rabbitmq://www.yahoo.com:5672"
parsed = misc.parse_uri(url) parsed = misc.parse_uri(url)

View File

@@ -19,11 +19,9 @@ import contextlib
import datetime import datetime
import errno import errno
import inspect import inspect
import keyword
import logging import logging
import os import os
import re import re
import string
import sys import sys
import threading import threading
@@ -48,7 +46,23 @@ NUMERIC_TYPES = six.integer_types + (float,)
_SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):") _SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):")
def merge_uri(uri_pieces, conf): # FIXME(harlowja): This should be removed with the next version of oslo.utils
# which now has this functionality built-in, until then we are deriving from
# there base class and adding this functionality on...
#
# The change was merged @ https://review.openstack.org/#/c/118881/
class ModifiedSplitResult(netutils._ModifiedSplitResult):
"""A split result that exposes the query parameters as a dictionary."""
@property
def params(self):
if self.query:
return dict(urlparse.parse_qsl(self.query))
else:
return {}
def merge_uri(uri, conf):
"""Merges a parsed uri into the given configuration dictionary. """Merges a parsed uri into the given configuration dictionary.
Merges the username, password, hostname, and query params of a uri into Merges the username, password, hostname, and query params of a uri into
@@ -57,22 +71,21 @@ def merge_uri(uri_pieces, conf):
NOTE(harlowja): does not merge the path, scheme or fragment. NOTE(harlowja): does not merge the path, scheme or fragment.
""" """
for k in ('username', 'password'): for (k, v) in [('username', uri.username), ('password', uri.password)]:
if not uri_pieces[k]: if not v:
continue continue
conf.setdefault(k, uri_pieces[k]) conf.setdefault(k, v)
hostname = uri_pieces.get('hostname') if uri.hostname:
if hostname: hostname = uri.hostname
port = uri_pieces.get('port') if uri.port is not None:
if port is not None: hostname += ":%s" % (uri.port)
hostname += ":%s" % (port)
conf.setdefault('hostname', hostname) conf.setdefault('hostname', hostname)
for (k, v) in six.iteritems(uri_pieces['params']): for (k, v) in six.iteritems(uri.params):
conf.setdefault(k, v) conf.setdefault(k, v)
return conf return conf
def parse_uri(uri, query_duplicates=False): def parse_uri(uri):
"""Parses a uri into its components.""" """Parses a uri into its components."""
# Do some basic validation before continuing... # Do some basic validation before continuing...
if not isinstance(uri, six.string_types): if not isinstance(uri, six.string_types):
@@ -83,38 +96,10 @@ def parse_uri(uri, query_duplicates=False):
if not match: if not match:
raise ValueError("Uri %r does not start with a RFC 3986 compliant" raise ValueError("Uri %r does not start with a RFC 3986 compliant"
" scheme" % (uri)) " scheme" % (uri))
parsed = netutils.urlsplit(uri) split = netutils.urlsplit(uri)
if parsed.query: return ModifiedSplitResult(scheme=split.scheme, fragment=split.fragment,
query_params = urlparse.parse_qsl(parsed.query) path=split.path, netloc=split.netloc,
if not query_duplicates: query=split.query)
query_params = dict(query_params)
else:
# Retain duplicates in a list for keys which have duplicates, but
# for items which are not duplicated, just associate the key with
# the value.
tmp_query_params = {}
for (k, v) in query_params:
if k in tmp_query_params:
p_v = tmp_query_params[k]
if isinstance(p_v, list):
p_v.append(v)
else:
p_v = [p_v, v]
tmp_query_params[k] = p_v
else:
tmp_query_params[k] = v
query_params = tmp_query_params
else:
query_params = {}
return AttrDict(
scheme=parsed.scheme,
username=parsed.username,
password=parsed.password,
fragment=parsed.fragment,
path=parsed.path,
params=query_params,
hostname=parsed.hostname,
port=parsed.port)
def binary_encode(text, encoding='utf-8'): def binary_encode(text, encoding='utf-8'):
@@ -271,67 +256,6 @@ def get_duplicate_keys(iterable, key=None):
return duplicates return duplicates
# NOTE(imelnikov): we should not use str.isalpha or str.isdigit
# as they are locale-dependant
_ASCII_WORD_SYMBOLS = frozenset(string.ascii_letters + string.digits + '_')
def is_valid_attribute_name(name, allow_self=False, allow_hidden=False):
"""Checks that a string is a valid/invalid python attribute name."""
return all((
isinstance(name, six.string_types),
len(name) > 0,
(allow_self or not name.lower().startswith('self')),
(allow_hidden or not name.lower().startswith('_')),
# NOTE(imelnikov): keywords should be forbidden.
not keyword.iskeyword(name),
# See: http://docs.python.org/release/2.5.2/ref/grammar.txt
not (name[0] in string.digits),
all(symbol in _ASCII_WORD_SYMBOLS for symbol in name)
))
class AttrDict(dict):
"""Dictionary subclass that allows for attribute based access.
This subclass allows for accessing a dictionaries keys and values by
accessing those keys as regular attributes. Keys that are not valid python
attribute names can not of course be acccessed/set (those keys must be
accessed/set by the traditional dictionary indexing operators instead).
"""
NO_ATTRS = tuple(reflection.get_member_names(dict))
@classmethod
def _is_valid_attribute_name(cls, name):
if not is_valid_attribute_name(name):
return False
# Make the name just be a simple string in latin-1 encoding in python3.
if name in cls.NO_ATTRS:
return False
return True
def __init__(self, **kwargs):
for (k, v) in kwargs.items():
if not self._is_valid_attribute_name(k):
raise AttributeError("Invalid attribute name: '%s'" % (k))
self[k] = v
def __getattr__(self, name):
if not self._is_valid_attribute_name(name):
raise AttributeError("Invalid attribute name: '%s'" % (name))
try:
return self[name]
except KeyError:
raise AttributeError("No attributed named: '%s'" % (name))
def __setattr__(self, name, value):
if not self._is_valid_attribute_name(name):
raise AttributeError("Invalid attribute name: '%s'" % (name))
self[name] = value
class ExponentialBackoff(object): class ExponentialBackoff(object):
"""An iterable object that will yield back an exponential delay sequence. """An iterable object that will yield back an exponential delay sequence.