Merge "Remove attrdict and just use existing types"

This commit is contained in:
Jenkins 2014-11-19 19:01:31 +00:00 committed by Gerrit Code Review
commit 2fcf67d6b9
8 changed files with 74 additions and 236 deletions

View File

@ -54,12 +54,12 @@ def _extract_engine(**kwargs):
kind = ENGINE_DEFAULT
# See if it's a URI and if so, extract any further options...
try:
pieces = misc.parse_uri(kind)
uri = misc.parse_uri(kind)
except (TypeError, ValueError):
pass
else:
kind = pieces['scheme']
options = misc.merge_uri(pieces, options.copy())
kind = uri.scheme
options = misc.merge_uri(uri, options.copy())
# Merge in any leftover **kwargs into the options, this makes it so that
# the provided **kwargs override any URI or engine_conf specific options.
options.update(kwargs)

View File

@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
import logging
import socket
@ -21,7 +22,6 @@ import kombu
import six
from taskflow.engines.worker_based import dispatcher
from taskflow.utils import misc
from taskflow.utils import threading_utils
LOG = logging.getLogger(__name__)
@ -30,6 +30,16 @@ LOG = logging.getLogger(__name__)
# the socket can get "stuck", and is a best practice for Kombu consumers.
DRAIN_EVENTS_PERIOD = 1
# Helper objects returned when requested to get connection details, used
# instead of returning the raw results from the kombu connection objects
# themselves so that a person can not mutate those objects (which would be
# bad).
_ConnectionDetails = collections.namedtuple('_ConnectionDetails',
['uri', 'transport'])
_TransportDetails = collections.namedtuple('_TransportDetails',
['options', 'driver_type',
'driver_name', 'driver_version'])
class Proxy(object):
"""A proxy processes messages from/to the named exchange."""
@ -71,13 +81,18 @@ class Proxy(object):
driver_version = self._conn.transport.driver_version()
if driver_version and driver_version.lower() == 'n/a':
driver_version = None
return misc.AttrDict(
if self._conn.transport_options:
transport_options = self._conn.transport_options.copy()
else:
transport_options = {}
transport = _TransportDetails(
options=transport_options,
driver_type=self._conn.transport.driver_type,
driver_name=self._conn.transport.driver_name,
driver_version=driver_version)
return _ConnectionDetails(
uri=self._conn.as_uri(include_password=False),
transport=misc.AttrDict(
options=dict(self._conn.transport_options),
driver_type=self._conn.transport.driver_type,
driver_name=self._conn.transport.driver_name,
driver_version=driver_version))
transport=transport)
@property
def is_running(self):

View File

@ -148,6 +148,12 @@ class DeclareSuccess(task.Task):
print("All data processed and sent to %s" % (sent_to))
class DummyUser(object):
def __init__(self, user, id):
self.user = user
self.id = id
# Resources (db handles and similar) of course can *not* be persisted so we
# need to make sure that we pass this resource fetcher to the tasks constructor
# so that the tasks have access to any needed resources (the resources are
@ -168,7 +174,7 @@ flow.add(sub_flow)
# prepopulating this allows the tasks that dependent on the 'request' variable
# to start processing (in this case this is the ExtractInputRequest task).
store = {
'request': misc.AttrDict(user="bob", id="1.35"),
'request': DummyUser(user="bob", id="1.35"),
}
eng = engines.load(flow, engine='serial', store=store)

View File

@ -55,12 +55,12 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs):
conf = {'board': conf}
board = conf['board']
try:
pieces = misc.parse_uri(board)
uri = misc.parse_uri(board)
except (TypeError, ValueError):
pass
else:
board = pieces['scheme']
conf = misc.merge_uri(pieces, conf.copy())
board = uri.scheme
conf = misc.merge_uri(uri, conf.copy())
LOG.debug('Looking for %r jobboard driver in %r', board, namespace)
try:
mgr = driver.DriverManager(namespace, board,

View File

@ -52,12 +52,12 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs):
"""
backend_name = conf['connection']
try:
pieces = misc.parse_uri(backend_name)
uri = misc.parse_uri(backend_name)
except (TypeError, ValueError):
pass
else:
backend_name = pieces['scheme']
conf = misc.merge_uri(pieces, conf.copy())
backend_name = uri.scheme
conf = misc.merge_uri(uri, conf.copy())
LOG.debug('Looking for %r backend driver in %r', backend_name, namespace)
try:
mgr = driver.DriverManager(namespace, backend_name,

View File

@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
import contextlib
import threading
@ -28,7 +29,6 @@ from taskflow.persistence.backends import impl_memory
from taskflow import states as st
from taskflow import test
from taskflow.tests import utils as test_utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
from taskflow.utils import threading_utils
@ -58,6 +58,10 @@ def make_thread(conductor):
class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
ComponentBundle = collections.namedtuple('ComponentBundle',
['board', 'client',
'persistence', 'conductor'])
def make_components(self, name='testing', wait_timeout=0.1):
client = fake_client.FakeClient()
persistence = impl_memory.MemoryBackend()
@ -66,10 +70,7 @@ class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
persistence=persistence)
conductor = stc.SingleThreadedConductor(name, board, persistence,
wait_timeout=wait_timeout)
return misc.AttrDict(board=board,
client=client,
persistence=persistence,
conductor=conductor)
return self.ComponentBundle(board, client, persistence, conductor)
def test_connection(self):
components = self.make_components()

View File

@ -387,109 +387,6 @@ class CachedPropertyTest(test.TestCase):
self.assertEqual('b', a.b)
class AttrDictTest(test.TestCase):
def test_ok_create(self):
attrs = {
'a': 1,
'b': 2,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(obj.a, 1)
self.assertEqual(obj.b, 2)
def test_private_create(self):
attrs = {
'_a': 1,
}
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
def test_invalid_create(self):
attrs = {
# Python attributes can't start with a number.
'123_abc': 1,
}
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
def test_no_overwrite(self):
attrs = {
# Python attributes can't start with a number.
'update': 1,
}
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
def test_back_todict(self):
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(obj.a, 1)
self.assertEqual(attrs, dict(obj))
def test_runtime_invalid_set(self):
def bad_assign(obj):
obj._123 = 'b'
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(obj.a, 1)
self.assertRaises(AttributeError, bad_assign, obj)
def test_bypass_get(self):
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(1, obj['a'])
def test_bypass_set_no_get(self):
def bad_assign(obj):
obj._b = 'e'
attrs = {
'a': 1,
}
obj = misc.AttrDict(**attrs)
self.assertEqual(1, obj['a'])
obj['_b'] = 'c'
self.assertRaises(AttributeError, bad_assign, obj)
self.assertEqual('c', obj['_b'])
class IsValidAttributeNameTestCase(test.TestCase):
def test_a_is_ok(self):
self.assertTrue(misc.is_valid_attribute_name('a'))
def test_name_can_be_longer(self):
self.assertTrue(misc.is_valid_attribute_name('foobarbaz'))
def test_name_can_have_digits(self):
self.assertTrue(misc.is_valid_attribute_name('fo12'))
def test_name_cannot_start_with_digit(self):
self.assertFalse(misc.is_valid_attribute_name('1z'))
def test_hidden_names_are_forbidden(self):
self.assertFalse(misc.is_valid_attribute_name('_z'))
def test_hidden_names_can_be_allowed(self):
self.assertTrue(
misc.is_valid_attribute_name('_z', allow_hidden=True))
def test_self_is_forbidden(self):
self.assertFalse(misc.is_valid_attribute_name('self'))
def test_self_can_be_allowed(self):
self.assertTrue(
misc.is_valid_attribute_name('self', allow_self=True))
def test_no_unicode_please(self):
self.assertFalse(misc.is_valid_attribute_name('mañana'))
class UriParseTest(test.TestCase):
def test_parse(self):
url = "zookeeper://192.168.0.1:2181/a/b/?c=d"
@ -501,11 +398,6 @@ class UriParseTest(test.TestCase):
self.assertEqual('/a/b/', parsed.path)
self.assertEqual({'c': 'd'}, parsed.params)
def test_multi_params(self):
url = "mysql://www.yahoo.com:3306/a/b/?c=d&c=e"
parsed = misc.parse_uri(url, query_duplicates=True)
self.assertEqual({'c': ['d', 'e']}, parsed.params)
def test_port_provided(self):
url = "rabbitmq://www.yahoo.com:5672"
parsed = misc.parse_uri(url)

View File

@ -19,11 +19,9 @@ import contextlib
import datetime
import errno
import inspect
import keyword
import logging
import os
import re
import string
import sys
import threading
@ -48,7 +46,23 @@ NUMERIC_TYPES = six.integer_types + (float,)
_SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):")
def merge_uri(uri_pieces, conf):
# FIXME(harlowja): This should be removed with the next version of oslo.utils
# which now has this functionality built-in, until then we are deriving from
# there base class and adding this functionality on...
#
# The change was merged @ https://review.openstack.org/#/c/118881/
class ModifiedSplitResult(netutils._ModifiedSplitResult):
"""A split result that exposes the query parameters as a dictionary."""
@property
def params(self):
if self.query:
return dict(urlparse.parse_qsl(self.query))
else:
return {}
def merge_uri(uri, conf):
"""Merges a parsed uri into the given configuration dictionary.
Merges the username, password, hostname, and query params of a uri into
@ -57,22 +71,21 @@ def merge_uri(uri_pieces, conf):
NOTE(harlowja): does not merge the path, scheme or fragment.
"""
for k in ('username', 'password'):
if not uri_pieces[k]:
for (k, v) in [('username', uri.username), ('password', uri.password)]:
if not v:
continue
conf.setdefault(k, uri_pieces[k])
hostname = uri_pieces.get('hostname')
if hostname:
port = uri_pieces.get('port')
if port is not None:
hostname += ":%s" % (port)
conf.setdefault(k, v)
if uri.hostname:
hostname = uri.hostname
if uri.port is not None:
hostname += ":%s" % (uri.port)
conf.setdefault('hostname', hostname)
for (k, v) in six.iteritems(uri_pieces['params']):
for (k, v) in six.iteritems(uri.params):
conf.setdefault(k, v)
return conf
def parse_uri(uri, query_duplicates=False):
def parse_uri(uri):
"""Parses a uri into its components."""
# Do some basic validation before continuing...
if not isinstance(uri, six.string_types):
@ -83,38 +96,10 @@ def parse_uri(uri, query_duplicates=False):
if not match:
raise ValueError("Uri %r does not start with a RFC 3986 compliant"
" scheme" % (uri))
parsed = netutils.urlsplit(uri)
if parsed.query:
query_params = urlparse.parse_qsl(parsed.query)
if not query_duplicates:
query_params = dict(query_params)
else:
# Retain duplicates in a list for keys which have duplicates, but
# for items which are not duplicated, just associate the key with
# the value.
tmp_query_params = {}
for (k, v) in query_params:
if k in tmp_query_params:
p_v = tmp_query_params[k]
if isinstance(p_v, list):
p_v.append(v)
else:
p_v = [p_v, v]
tmp_query_params[k] = p_v
else:
tmp_query_params[k] = v
query_params = tmp_query_params
else:
query_params = {}
return AttrDict(
scheme=parsed.scheme,
username=parsed.username,
password=parsed.password,
fragment=parsed.fragment,
path=parsed.path,
params=query_params,
hostname=parsed.hostname,
port=parsed.port)
split = netutils.urlsplit(uri)
return ModifiedSplitResult(scheme=split.scheme, fragment=split.fragment,
path=split.path, netloc=split.netloc,
query=split.query)
def binary_encode(text, encoding='utf-8'):
@ -271,67 +256,6 @@ def get_duplicate_keys(iterable, key=None):
return duplicates
# NOTE(imelnikov): we should not use str.isalpha or str.isdigit
# as they are locale-dependant
_ASCII_WORD_SYMBOLS = frozenset(string.ascii_letters + string.digits + '_')
def is_valid_attribute_name(name, allow_self=False, allow_hidden=False):
"""Checks that a string is a valid/invalid python attribute name."""
return all((
isinstance(name, six.string_types),
len(name) > 0,
(allow_self or not name.lower().startswith('self')),
(allow_hidden or not name.lower().startswith('_')),
# NOTE(imelnikov): keywords should be forbidden.
not keyword.iskeyword(name),
# See: http://docs.python.org/release/2.5.2/ref/grammar.txt
not (name[0] in string.digits),
all(symbol in _ASCII_WORD_SYMBOLS for symbol in name)
))
class AttrDict(dict):
"""Dictionary subclass that allows for attribute based access.
This subclass allows for accessing a dictionaries keys and values by
accessing those keys as regular attributes. Keys that are not valid python
attribute names can not of course be acccessed/set (those keys must be
accessed/set by the traditional dictionary indexing operators instead).
"""
NO_ATTRS = tuple(reflection.get_member_names(dict))
@classmethod
def _is_valid_attribute_name(cls, name):
if not is_valid_attribute_name(name):
return False
# Make the name just be a simple string in latin-1 encoding in python3.
if name in cls.NO_ATTRS:
return False
return True
def __init__(self, **kwargs):
for (k, v) in kwargs.items():
if not self._is_valid_attribute_name(k):
raise AttributeError("Invalid attribute name: '%s'" % (k))
self[k] = v
def __getattr__(self, name):
if not self._is_valid_attribute_name(name):
raise AttributeError("Invalid attribute name: '%s'" % (name))
try:
return self[name]
except KeyError:
raise AttributeError("No attributed named: '%s'" % (name))
def __setattr__(self, name, value):
if not self._is_valid_attribute_name(name):
raise AttributeError("Invalid attribute name: '%s'" % (name))
self[name] = value
class ExponentialBackoff(object):
"""An iterable object that will yield back an exponential delay sequence.