From 7fe6bf0d6b3fa1dc0345ab0e9af4909c3f5b9549 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Thu, 21 Aug 2014 23:59:19 -0700 Subject: [PATCH] Remove attrdict and just use existing types In order to make it simpler (and less code) just prefer to use object types that already exist instead of trying to make dictionaries also behave like objects. For those that really need this kind of functionality: https://pypi.python.org/pypi/attrdict Change-Id: Ib7ddfa517f0500082fafac2c3e53fd6a158a6ddf --- taskflow/engines/helpers.py | 6 +- taskflow/engines/worker_based/proxy.py | 29 +++- taskflow/examples/fake_billing.py | 8 +- taskflow/jobs/backends/__init__.py | 6 +- taskflow/persistence/backends/__init__.py | 6 +- .../tests/unit/conductor/test_conductor.py | 11 +- taskflow/tests/unit/test_utils.py | 108 -------------- taskflow/utils/misc.py | 136 ++++-------------- 8 files changed, 74 insertions(+), 236 deletions(-) diff --git a/taskflow/engines/helpers.py b/taskflow/engines/helpers.py index a62365919..caf7ec131 100644 --- a/taskflow/engines/helpers.py +++ b/taskflow/engines/helpers.py @@ -54,12 +54,12 @@ def _extract_engine(**kwargs): kind = ENGINE_DEFAULT # See if it's a URI and if so, extract any further options... try: - pieces = misc.parse_uri(kind) + uri = misc.parse_uri(kind) except (TypeError, ValueError): pass else: - kind = pieces['scheme'] - options = misc.merge_uri(pieces, options.copy()) + kind = uri.scheme + options = misc.merge_uri(uri, options.copy()) # Merge in any leftover **kwargs into the options, this makes it so that # the provided **kwargs override any URI or engine_conf specific options. options.update(kwargs) diff --git a/taskflow/engines/worker_based/proxy.py b/taskflow/engines/worker_based/proxy.py index e0f9ce7d6..6f608f424 100644 --- a/taskflow/engines/worker_based/proxy.py +++ b/taskflow/engines/worker_based/proxy.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import logging import socket @@ -21,7 +22,6 @@ import kombu import six from taskflow.engines.worker_based import dispatcher -from taskflow.utils import misc from taskflow.utils import threading_utils LOG = logging.getLogger(__name__) @@ -30,6 +30,16 @@ LOG = logging.getLogger(__name__) # the socket can get "stuck", and is a best practice for Kombu consumers. DRAIN_EVENTS_PERIOD = 1 +# Helper objects returned when requested to get connection details, used +# instead of returning the raw results from the kombu connection objects +# themselves so that a person can not mutate those objects (which would be +# bad). +_ConnectionDetails = collections.namedtuple('_ConnectionDetails', + ['uri', 'transport']) +_TransportDetails = collections.namedtuple('_TransportDetails', + ['options', 'driver_type', + 'driver_name', 'driver_version']) + class Proxy(object): """A proxy processes messages from/to the named exchange.""" @@ -71,13 +81,18 @@ class Proxy(object): driver_version = self._conn.transport.driver_version() if driver_version and driver_version.lower() == 'n/a': driver_version = None - return misc.AttrDict( + if self._conn.transport_options: + transport_options = self._conn.transport_options.copy() + else: + transport_options = {} + transport = _TransportDetails( + options=transport_options, + driver_type=self._conn.transport.driver_type, + driver_name=self._conn.transport.driver_name, + driver_version=driver_version) + return _ConnectionDetails( uri=self._conn.as_uri(include_password=False), - transport=misc.AttrDict( - options=dict(self._conn.transport_options), - driver_type=self._conn.transport.driver_type, - driver_name=self._conn.transport.driver_name, - driver_version=driver_version)) + transport=transport) @property def is_running(self): diff --git a/taskflow/examples/fake_billing.py b/taskflow/examples/fake_billing.py index ac15dbae1..22c75cd91 100644 --- a/taskflow/examples/fake_billing.py +++ b/taskflow/examples/fake_billing.py @@ -148,6 +148,12 @@ class DeclareSuccess(task.Task): print("All data processed and sent to %s" % (sent_to)) +class DummyUser(object): + def __init__(self, user, id): + self.user = user + self.id = id + + # Resources (db handles and similar) of course can *not* be persisted so we # need to make sure that we pass this resource fetcher to the tasks constructor # so that the tasks have access to any needed resources (the resources are @@ -168,7 +174,7 @@ flow.add(sub_flow) # prepopulating this allows the tasks that dependent on the 'request' variable # to start processing (in this case this is the ExtractInputRequest task). store = { - 'request': misc.AttrDict(user="bob", id="1.35"), + 'request': DummyUser(user="bob", id="1.35"), } eng = engines.load(flow, engine='serial', store=store) diff --git a/taskflow/jobs/backends/__init__.py b/taskflow/jobs/backends/__init__.py index 099f0476e..94afa6e56 100644 --- a/taskflow/jobs/backends/__init__.py +++ b/taskflow/jobs/backends/__init__.py @@ -55,12 +55,12 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs): conf = {'board': conf} board = conf['board'] try: - pieces = misc.parse_uri(board) + uri = misc.parse_uri(board) except (TypeError, ValueError): pass else: - board = pieces['scheme'] - conf = misc.merge_uri(pieces, conf.copy()) + board = uri.scheme + conf = misc.merge_uri(uri, conf.copy()) LOG.debug('Looking for %r jobboard driver in %r', board, namespace) try: mgr = driver.DriverManager(namespace, board, diff --git a/taskflow/persistence/backends/__init__.py b/taskflow/persistence/backends/__init__.py index 6faabdef4..64b7cda10 100644 --- a/taskflow/persistence/backends/__init__.py +++ b/taskflow/persistence/backends/__init__.py @@ -52,12 +52,12 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs): """ backend_name = conf['connection'] try: - pieces = misc.parse_uri(backend_name) + uri = misc.parse_uri(backend_name) except (TypeError, ValueError): pass else: - backend_name = pieces['scheme'] - conf = misc.merge_uri(pieces, conf.copy()) + backend_name = uri.scheme + conf = misc.merge_uri(uri, conf.copy()) LOG.debug('Looking for %r backend driver in %r', backend_name, namespace) try: mgr = driver.DriverManager(namespace, backend_name, diff --git a/taskflow/tests/unit/conductor/test_conductor.py b/taskflow/tests/unit/conductor/test_conductor.py index 8f9c2d105..cf19fa840 100644 --- a/taskflow/tests/unit/conductor/test_conductor.py +++ b/taskflow/tests/unit/conductor/test_conductor.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import contextlib import threading @@ -28,7 +29,6 @@ from taskflow.persistence.backends import impl_memory from taskflow import states as st from taskflow import test from taskflow.tests import utils as test_utils -from taskflow.utils import misc from taskflow.utils import persistence_utils as pu from taskflow.utils import threading_utils @@ -58,6 +58,10 @@ def make_thread(conductor): class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase): + ComponentBundle = collections.namedtuple('ComponentBundle', + ['board', 'client', + 'persistence', 'conductor']) + def make_components(self, name='testing', wait_timeout=0.1): client = fake_client.FakeClient() persistence = impl_memory.MemoryBackend() @@ -66,10 +70,7 @@ class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase): persistence=persistence) conductor = stc.SingleThreadedConductor(name, board, persistence, wait_timeout=wait_timeout) - return misc.AttrDict(board=board, - client=client, - persistence=persistence, - conductor=conductor) + return self.ComponentBundle(board, client, persistence, conductor) def test_connection(self): components = self.make_components() diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index e8660bd32..384178107 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -387,109 +387,6 @@ class CachedPropertyTest(test.TestCase): self.assertEqual('b', a.b) -class AttrDictTest(test.TestCase): - def test_ok_create(self): - attrs = { - 'a': 1, - 'b': 2, - } - obj = misc.AttrDict(**attrs) - self.assertEqual(obj.a, 1) - self.assertEqual(obj.b, 2) - - def test_private_create(self): - attrs = { - '_a': 1, - } - self.assertRaises(AttributeError, misc.AttrDict, **attrs) - - def test_invalid_create(self): - attrs = { - # Python attributes can't start with a number. - '123_abc': 1, - } - self.assertRaises(AttributeError, misc.AttrDict, **attrs) - - def test_no_overwrite(self): - attrs = { - # Python attributes can't start with a number. - 'update': 1, - } - self.assertRaises(AttributeError, misc.AttrDict, **attrs) - - def test_back_todict(self): - attrs = { - 'a': 1, - } - obj = misc.AttrDict(**attrs) - self.assertEqual(obj.a, 1) - self.assertEqual(attrs, dict(obj)) - - def test_runtime_invalid_set(self): - - def bad_assign(obj): - obj._123 = 'b' - - attrs = { - 'a': 1, - } - obj = misc.AttrDict(**attrs) - self.assertEqual(obj.a, 1) - self.assertRaises(AttributeError, bad_assign, obj) - - def test_bypass_get(self): - attrs = { - 'a': 1, - } - obj = misc.AttrDict(**attrs) - self.assertEqual(1, obj['a']) - - def test_bypass_set_no_get(self): - - def bad_assign(obj): - obj._b = 'e' - - attrs = { - 'a': 1, - } - obj = misc.AttrDict(**attrs) - self.assertEqual(1, obj['a']) - obj['_b'] = 'c' - self.assertRaises(AttributeError, bad_assign, obj) - self.assertEqual('c', obj['_b']) - - -class IsValidAttributeNameTestCase(test.TestCase): - def test_a_is_ok(self): - self.assertTrue(misc.is_valid_attribute_name('a')) - - def test_name_can_be_longer(self): - self.assertTrue(misc.is_valid_attribute_name('foobarbaz')) - - def test_name_can_have_digits(self): - self.assertTrue(misc.is_valid_attribute_name('fo12')) - - def test_name_cannot_start_with_digit(self): - self.assertFalse(misc.is_valid_attribute_name('1z')) - - def test_hidden_names_are_forbidden(self): - self.assertFalse(misc.is_valid_attribute_name('_z')) - - def test_hidden_names_can_be_allowed(self): - self.assertTrue( - misc.is_valid_attribute_name('_z', allow_hidden=True)) - - def test_self_is_forbidden(self): - self.assertFalse(misc.is_valid_attribute_name('self')) - - def test_self_can_be_allowed(self): - self.assertTrue( - misc.is_valid_attribute_name('self', allow_self=True)) - - def test_no_unicode_please(self): - self.assertFalse(misc.is_valid_attribute_name('maƱana')) - - class UriParseTest(test.TestCase): def test_parse(self): url = "zookeeper://192.168.0.1:2181/a/b/?c=d" @@ -501,11 +398,6 @@ class UriParseTest(test.TestCase): self.assertEqual('/a/b/', parsed.path) self.assertEqual({'c': 'd'}, parsed.params) - def test_multi_params(self): - url = "mysql://www.yahoo.com:3306/a/b/?c=d&c=e" - parsed = misc.parse_uri(url, query_duplicates=True) - self.assertEqual({'c': ['d', 'e']}, parsed.params) - def test_port_provided(self): url = "rabbitmq://www.yahoo.com:5672" parsed = misc.parse_uri(url) diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index a46198812..3822188c7 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -19,11 +19,9 @@ import contextlib import datetime import errno import inspect -import keyword import logging import os import re -import string import sys import threading @@ -48,7 +46,23 @@ NUMERIC_TYPES = six.integer_types + (float,) _SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):") -def merge_uri(uri_pieces, conf): +# FIXME(harlowja): This should be removed with the next version of oslo.utils +# which now has this functionality built-in, until then we are deriving from +# there base class and adding this functionality on... +# +# The change was merged @ https://review.openstack.org/#/c/118881/ +class ModifiedSplitResult(netutils._ModifiedSplitResult): + """A split result that exposes the query parameters as a dictionary.""" + + @property + def params(self): + if self.query: + return dict(urlparse.parse_qsl(self.query)) + else: + return {} + + +def merge_uri(uri, conf): """Merges a parsed uri into the given configuration dictionary. Merges the username, password, hostname, and query params of a uri into @@ -57,22 +71,21 @@ def merge_uri(uri_pieces, conf): NOTE(harlowja): does not merge the path, scheme or fragment. """ - for k in ('username', 'password'): - if not uri_pieces[k]: + for (k, v) in [('username', uri.username), ('password', uri.password)]: + if not v: continue - conf.setdefault(k, uri_pieces[k]) - hostname = uri_pieces.get('hostname') - if hostname: - port = uri_pieces.get('port') - if port is not None: - hostname += ":%s" % (port) + conf.setdefault(k, v) + if uri.hostname: + hostname = uri.hostname + if uri.port is not None: + hostname += ":%s" % (uri.port) conf.setdefault('hostname', hostname) - for (k, v) in six.iteritems(uri_pieces['params']): + for (k, v) in six.iteritems(uri.params): conf.setdefault(k, v) return conf -def parse_uri(uri, query_duplicates=False): +def parse_uri(uri): """Parses a uri into its components.""" # Do some basic validation before continuing... if not isinstance(uri, six.string_types): @@ -83,38 +96,10 @@ def parse_uri(uri, query_duplicates=False): if not match: raise ValueError("Uri %r does not start with a RFC 3986 compliant" " scheme" % (uri)) - parsed = netutils.urlsplit(uri) - if parsed.query: - query_params = urlparse.parse_qsl(parsed.query) - if not query_duplicates: - query_params = dict(query_params) - else: - # Retain duplicates in a list for keys which have duplicates, but - # for items which are not duplicated, just associate the key with - # the value. - tmp_query_params = {} - for (k, v) in query_params: - if k in tmp_query_params: - p_v = tmp_query_params[k] - if isinstance(p_v, list): - p_v.append(v) - else: - p_v = [p_v, v] - tmp_query_params[k] = p_v - else: - tmp_query_params[k] = v - query_params = tmp_query_params - else: - query_params = {} - return AttrDict( - scheme=parsed.scheme, - username=parsed.username, - password=parsed.password, - fragment=parsed.fragment, - path=parsed.path, - params=query_params, - hostname=parsed.hostname, - port=parsed.port) + split = netutils.urlsplit(uri) + return ModifiedSplitResult(scheme=split.scheme, fragment=split.fragment, + path=split.path, netloc=split.netloc, + query=split.query) def binary_encode(text, encoding='utf-8'): @@ -271,67 +256,6 @@ def get_duplicate_keys(iterable, key=None): return duplicates -# NOTE(imelnikov): we should not use str.isalpha or str.isdigit -# as they are locale-dependant -_ASCII_WORD_SYMBOLS = frozenset(string.ascii_letters + string.digits + '_') - - -def is_valid_attribute_name(name, allow_self=False, allow_hidden=False): - """Checks that a string is a valid/invalid python attribute name.""" - return all(( - isinstance(name, six.string_types), - len(name) > 0, - (allow_self or not name.lower().startswith('self')), - (allow_hidden or not name.lower().startswith('_')), - - # NOTE(imelnikov): keywords should be forbidden. - not keyword.iskeyword(name), - - # See: http://docs.python.org/release/2.5.2/ref/grammar.txt - not (name[0] in string.digits), - all(symbol in _ASCII_WORD_SYMBOLS for symbol in name) - )) - - -class AttrDict(dict): - """Dictionary subclass that allows for attribute based access. - - This subclass allows for accessing a dictionaries keys and values by - accessing those keys as regular attributes. Keys that are not valid python - attribute names can not of course be acccessed/set (those keys must be - accessed/set by the traditional dictionary indexing operators instead). - """ - NO_ATTRS = tuple(reflection.get_member_names(dict)) - - @classmethod - def _is_valid_attribute_name(cls, name): - if not is_valid_attribute_name(name): - return False - # Make the name just be a simple string in latin-1 encoding in python3. - if name in cls.NO_ATTRS: - return False - return True - - def __init__(self, **kwargs): - for (k, v) in kwargs.items(): - if not self._is_valid_attribute_name(k): - raise AttributeError("Invalid attribute name: '%s'" % (k)) - self[k] = v - - def __getattr__(self, name): - if not self._is_valid_attribute_name(name): - raise AttributeError("Invalid attribute name: '%s'" % (name)) - try: - return self[name] - except KeyError: - raise AttributeError("No attributed named: '%s'" % (name)) - - def __setattr__(self, name, value): - if not self._is_valid_attribute_name(name): - raise AttributeError("Invalid attribute name: '%s'" % (name)) - self[name] = value - - class ExponentialBackoff(object): """An iterable object that will yield back an exponential delay sequence.