Merge "Remove attrdict and just use existing types"
This commit is contained in:
commit
2fcf67d6b9
|
@ -54,12 +54,12 @@ def _extract_engine(**kwargs):
|
|||
kind = ENGINE_DEFAULT
|
||||
# See if it's a URI and if so, extract any further options...
|
||||
try:
|
||||
pieces = misc.parse_uri(kind)
|
||||
uri = misc.parse_uri(kind)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
kind = pieces['scheme']
|
||||
options = misc.merge_uri(pieces, options.copy())
|
||||
kind = uri.scheme
|
||||
options = misc.merge_uri(uri, options.copy())
|
||||
# Merge in any leftover **kwargs into the options, this makes it so that
|
||||
# the provided **kwargs override any URI or engine_conf specific options.
|
||||
options.update(kwargs)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import socket
|
||||
|
||||
|
@ -21,7 +22,6 @@ import kombu
|
|||
import six
|
||||
|
||||
from taskflow.engines.worker_based import dispatcher
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import threading_utils
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
@ -30,6 +30,16 @@ LOG = logging.getLogger(__name__)
|
|||
# the socket can get "stuck", and is a best practice for Kombu consumers.
|
||||
DRAIN_EVENTS_PERIOD = 1
|
||||
|
||||
# Helper objects returned when requested to get connection details, used
|
||||
# instead of returning the raw results from the kombu connection objects
|
||||
# themselves so that a person can not mutate those objects (which would be
|
||||
# bad).
|
||||
_ConnectionDetails = collections.namedtuple('_ConnectionDetails',
|
||||
['uri', 'transport'])
|
||||
_TransportDetails = collections.namedtuple('_TransportDetails',
|
||||
['options', 'driver_type',
|
||||
'driver_name', 'driver_version'])
|
||||
|
||||
|
||||
class Proxy(object):
|
||||
"""A proxy processes messages from/to the named exchange."""
|
||||
|
@ -71,13 +81,18 @@ class Proxy(object):
|
|||
driver_version = self._conn.transport.driver_version()
|
||||
if driver_version and driver_version.lower() == 'n/a':
|
||||
driver_version = None
|
||||
return misc.AttrDict(
|
||||
if self._conn.transport_options:
|
||||
transport_options = self._conn.transport_options.copy()
|
||||
else:
|
||||
transport_options = {}
|
||||
transport = _TransportDetails(
|
||||
options=transport_options,
|
||||
driver_type=self._conn.transport.driver_type,
|
||||
driver_name=self._conn.transport.driver_name,
|
||||
driver_version=driver_version)
|
||||
return _ConnectionDetails(
|
||||
uri=self._conn.as_uri(include_password=False),
|
||||
transport=misc.AttrDict(
|
||||
options=dict(self._conn.transport_options),
|
||||
driver_type=self._conn.transport.driver_type,
|
||||
driver_name=self._conn.transport.driver_name,
|
||||
driver_version=driver_version))
|
||||
transport=transport)
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
|
|
|
@ -148,6 +148,12 @@ class DeclareSuccess(task.Task):
|
|||
print("All data processed and sent to %s" % (sent_to))
|
||||
|
||||
|
||||
class DummyUser(object):
|
||||
def __init__(self, user, id):
|
||||
self.user = user
|
||||
self.id = id
|
||||
|
||||
|
||||
# Resources (db handles and similar) of course can *not* be persisted so we
|
||||
# need to make sure that we pass this resource fetcher to the tasks constructor
|
||||
# so that the tasks have access to any needed resources (the resources are
|
||||
|
@ -168,7 +174,7 @@ flow.add(sub_flow)
|
|||
# prepopulating this allows the tasks that dependent on the 'request' variable
|
||||
# to start processing (in this case this is the ExtractInputRequest task).
|
||||
store = {
|
||||
'request': misc.AttrDict(user="bob", id="1.35"),
|
||||
'request': DummyUser(user="bob", id="1.35"),
|
||||
}
|
||||
eng = engines.load(flow, engine='serial', store=store)
|
||||
|
||||
|
|
|
@ -55,12 +55,12 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs):
|
|||
conf = {'board': conf}
|
||||
board = conf['board']
|
||||
try:
|
||||
pieces = misc.parse_uri(board)
|
||||
uri = misc.parse_uri(board)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
board = pieces['scheme']
|
||||
conf = misc.merge_uri(pieces, conf.copy())
|
||||
board = uri.scheme
|
||||
conf = misc.merge_uri(uri, conf.copy())
|
||||
LOG.debug('Looking for %r jobboard driver in %r', board, namespace)
|
||||
try:
|
||||
mgr = driver.DriverManager(namespace, board,
|
||||
|
|
|
@ -52,12 +52,12 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs):
|
|||
"""
|
||||
backend_name = conf['connection']
|
||||
try:
|
||||
pieces = misc.parse_uri(backend_name)
|
||||
uri = misc.parse_uri(backend_name)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
backend_name = pieces['scheme']
|
||||
conf = misc.merge_uri(pieces, conf.copy())
|
||||
backend_name = uri.scheme
|
||||
conf = misc.merge_uri(uri, conf.copy())
|
||||
LOG.debug('Looking for %r backend driver in %r', backend_name, namespace)
|
||||
try:
|
||||
mgr = driver.DriverManager(namespace, backend_name,
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import threading
|
||||
|
||||
|
@ -28,7 +29,6 @@ from taskflow.persistence.backends import impl_memory
|
|||
from taskflow import states as st
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils as test_utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
from taskflow.utils import threading_utils
|
||||
|
||||
|
@ -58,6 +58,10 @@ def make_thread(conductor):
|
|||
|
||||
|
||||
class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
|
||||
ComponentBundle = collections.namedtuple('ComponentBundle',
|
||||
['board', 'client',
|
||||
'persistence', 'conductor'])
|
||||
|
||||
def make_components(self, name='testing', wait_timeout=0.1):
|
||||
client = fake_client.FakeClient()
|
||||
persistence = impl_memory.MemoryBackend()
|
||||
|
@ -66,10 +70,7 @@ class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
|
|||
persistence=persistence)
|
||||
conductor = stc.SingleThreadedConductor(name, board, persistence,
|
||||
wait_timeout=wait_timeout)
|
||||
return misc.AttrDict(board=board,
|
||||
client=client,
|
||||
persistence=persistence,
|
||||
conductor=conductor)
|
||||
return self.ComponentBundle(board, client, persistence, conductor)
|
||||
|
||||
def test_connection(self):
|
||||
components = self.make_components()
|
||||
|
|
|
@ -387,109 +387,6 @@ class CachedPropertyTest(test.TestCase):
|
|||
self.assertEqual('b', a.b)
|
||||
|
||||
|
||||
class AttrDictTest(test.TestCase):
|
||||
def test_ok_create(self):
|
||||
attrs = {
|
||||
'a': 1,
|
||||
'b': 2,
|
||||
}
|
||||
obj = misc.AttrDict(**attrs)
|
||||
self.assertEqual(obj.a, 1)
|
||||
self.assertEqual(obj.b, 2)
|
||||
|
||||
def test_private_create(self):
|
||||
attrs = {
|
||||
'_a': 1,
|
||||
}
|
||||
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
|
||||
|
||||
def test_invalid_create(self):
|
||||
attrs = {
|
||||
# Python attributes can't start with a number.
|
||||
'123_abc': 1,
|
||||
}
|
||||
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
|
||||
|
||||
def test_no_overwrite(self):
|
||||
attrs = {
|
||||
# Python attributes can't start with a number.
|
||||
'update': 1,
|
||||
}
|
||||
self.assertRaises(AttributeError, misc.AttrDict, **attrs)
|
||||
|
||||
def test_back_todict(self):
|
||||
attrs = {
|
||||
'a': 1,
|
||||
}
|
||||
obj = misc.AttrDict(**attrs)
|
||||
self.assertEqual(obj.a, 1)
|
||||
self.assertEqual(attrs, dict(obj))
|
||||
|
||||
def test_runtime_invalid_set(self):
|
||||
|
||||
def bad_assign(obj):
|
||||
obj._123 = 'b'
|
||||
|
||||
attrs = {
|
||||
'a': 1,
|
||||
}
|
||||
obj = misc.AttrDict(**attrs)
|
||||
self.assertEqual(obj.a, 1)
|
||||
self.assertRaises(AttributeError, bad_assign, obj)
|
||||
|
||||
def test_bypass_get(self):
|
||||
attrs = {
|
||||
'a': 1,
|
||||
}
|
||||
obj = misc.AttrDict(**attrs)
|
||||
self.assertEqual(1, obj['a'])
|
||||
|
||||
def test_bypass_set_no_get(self):
|
||||
|
||||
def bad_assign(obj):
|
||||
obj._b = 'e'
|
||||
|
||||
attrs = {
|
||||
'a': 1,
|
||||
}
|
||||
obj = misc.AttrDict(**attrs)
|
||||
self.assertEqual(1, obj['a'])
|
||||
obj['_b'] = 'c'
|
||||
self.assertRaises(AttributeError, bad_assign, obj)
|
||||
self.assertEqual('c', obj['_b'])
|
||||
|
||||
|
||||
class IsValidAttributeNameTestCase(test.TestCase):
|
||||
def test_a_is_ok(self):
|
||||
self.assertTrue(misc.is_valid_attribute_name('a'))
|
||||
|
||||
def test_name_can_be_longer(self):
|
||||
self.assertTrue(misc.is_valid_attribute_name('foobarbaz'))
|
||||
|
||||
def test_name_can_have_digits(self):
|
||||
self.assertTrue(misc.is_valid_attribute_name('fo12'))
|
||||
|
||||
def test_name_cannot_start_with_digit(self):
|
||||
self.assertFalse(misc.is_valid_attribute_name('1z'))
|
||||
|
||||
def test_hidden_names_are_forbidden(self):
|
||||
self.assertFalse(misc.is_valid_attribute_name('_z'))
|
||||
|
||||
def test_hidden_names_can_be_allowed(self):
|
||||
self.assertTrue(
|
||||
misc.is_valid_attribute_name('_z', allow_hidden=True))
|
||||
|
||||
def test_self_is_forbidden(self):
|
||||
self.assertFalse(misc.is_valid_attribute_name('self'))
|
||||
|
||||
def test_self_can_be_allowed(self):
|
||||
self.assertTrue(
|
||||
misc.is_valid_attribute_name('self', allow_self=True))
|
||||
|
||||
def test_no_unicode_please(self):
|
||||
self.assertFalse(misc.is_valid_attribute_name('mañana'))
|
||||
|
||||
|
||||
class UriParseTest(test.TestCase):
|
||||
def test_parse(self):
|
||||
url = "zookeeper://192.168.0.1:2181/a/b/?c=d"
|
||||
|
@ -501,11 +398,6 @@ class UriParseTest(test.TestCase):
|
|||
self.assertEqual('/a/b/', parsed.path)
|
||||
self.assertEqual({'c': 'd'}, parsed.params)
|
||||
|
||||
def test_multi_params(self):
|
||||
url = "mysql://www.yahoo.com:3306/a/b/?c=d&c=e"
|
||||
parsed = misc.parse_uri(url, query_duplicates=True)
|
||||
self.assertEqual({'c': ['d', 'e']}, parsed.params)
|
||||
|
||||
def test_port_provided(self):
|
||||
url = "rabbitmq://www.yahoo.com:5672"
|
||||
parsed = misc.parse_uri(url)
|
||||
|
|
|
@ -19,11 +19,9 @@ import contextlib
|
|||
import datetime
|
||||
import errno
|
||||
import inspect
|
||||
import keyword
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import threading
|
||||
|
||||
|
@ -48,7 +46,23 @@ NUMERIC_TYPES = six.integer_types + (float,)
|
|||
_SCHEME_REGEX = re.compile(r"^([A-Za-z][A-Za-z0-9+.-]*):")
|
||||
|
||||
|
||||
def merge_uri(uri_pieces, conf):
|
||||
# FIXME(harlowja): This should be removed with the next version of oslo.utils
|
||||
# which now has this functionality built-in, until then we are deriving from
|
||||
# there base class and adding this functionality on...
|
||||
#
|
||||
# The change was merged @ https://review.openstack.org/#/c/118881/
|
||||
class ModifiedSplitResult(netutils._ModifiedSplitResult):
|
||||
"""A split result that exposes the query parameters as a dictionary."""
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
if self.query:
|
||||
return dict(urlparse.parse_qsl(self.query))
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def merge_uri(uri, conf):
|
||||
"""Merges a parsed uri into the given configuration dictionary.
|
||||
|
||||
Merges the username, password, hostname, and query params of a uri into
|
||||
|
@ -57,22 +71,21 @@ def merge_uri(uri_pieces, conf):
|
|||
|
||||
NOTE(harlowja): does not merge the path, scheme or fragment.
|
||||
"""
|
||||
for k in ('username', 'password'):
|
||||
if not uri_pieces[k]:
|
||||
for (k, v) in [('username', uri.username), ('password', uri.password)]:
|
||||
if not v:
|
||||
continue
|
||||
conf.setdefault(k, uri_pieces[k])
|
||||
hostname = uri_pieces.get('hostname')
|
||||
if hostname:
|
||||
port = uri_pieces.get('port')
|
||||
if port is not None:
|
||||
hostname += ":%s" % (port)
|
||||
conf.setdefault(k, v)
|
||||
if uri.hostname:
|
||||
hostname = uri.hostname
|
||||
if uri.port is not None:
|
||||
hostname += ":%s" % (uri.port)
|
||||
conf.setdefault('hostname', hostname)
|
||||
for (k, v) in six.iteritems(uri_pieces['params']):
|
||||
for (k, v) in six.iteritems(uri.params):
|
||||
conf.setdefault(k, v)
|
||||
return conf
|
||||
|
||||
|
||||
def parse_uri(uri, query_duplicates=False):
|
||||
def parse_uri(uri):
|
||||
"""Parses a uri into its components."""
|
||||
# Do some basic validation before continuing...
|
||||
if not isinstance(uri, six.string_types):
|
||||
|
@ -83,38 +96,10 @@ def parse_uri(uri, query_duplicates=False):
|
|||
if not match:
|
||||
raise ValueError("Uri %r does not start with a RFC 3986 compliant"
|
||||
" scheme" % (uri))
|
||||
parsed = netutils.urlsplit(uri)
|
||||
if parsed.query:
|
||||
query_params = urlparse.parse_qsl(parsed.query)
|
||||
if not query_duplicates:
|
||||
query_params = dict(query_params)
|
||||
else:
|
||||
# Retain duplicates in a list for keys which have duplicates, but
|
||||
# for items which are not duplicated, just associate the key with
|
||||
# the value.
|
||||
tmp_query_params = {}
|
||||
for (k, v) in query_params:
|
||||
if k in tmp_query_params:
|
||||
p_v = tmp_query_params[k]
|
||||
if isinstance(p_v, list):
|
||||
p_v.append(v)
|
||||
else:
|
||||
p_v = [p_v, v]
|
||||
tmp_query_params[k] = p_v
|
||||
else:
|
||||
tmp_query_params[k] = v
|
||||
query_params = tmp_query_params
|
||||
else:
|
||||
query_params = {}
|
||||
return AttrDict(
|
||||
scheme=parsed.scheme,
|
||||
username=parsed.username,
|
||||
password=parsed.password,
|
||||
fragment=parsed.fragment,
|
||||
path=parsed.path,
|
||||
params=query_params,
|
||||
hostname=parsed.hostname,
|
||||
port=parsed.port)
|
||||
split = netutils.urlsplit(uri)
|
||||
return ModifiedSplitResult(scheme=split.scheme, fragment=split.fragment,
|
||||
path=split.path, netloc=split.netloc,
|
||||
query=split.query)
|
||||
|
||||
|
||||
def binary_encode(text, encoding='utf-8'):
|
||||
|
@ -271,67 +256,6 @@ def get_duplicate_keys(iterable, key=None):
|
|||
return duplicates
|
||||
|
||||
|
||||
# NOTE(imelnikov): we should not use str.isalpha or str.isdigit
|
||||
# as they are locale-dependant
|
||||
_ASCII_WORD_SYMBOLS = frozenset(string.ascii_letters + string.digits + '_')
|
||||
|
||||
|
||||
def is_valid_attribute_name(name, allow_self=False, allow_hidden=False):
|
||||
"""Checks that a string is a valid/invalid python attribute name."""
|
||||
return all((
|
||||
isinstance(name, six.string_types),
|
||||
len(name) > 0,
|
||||
(allow_self or not name.lower().startswith('self')),
|
||||
(allow_hidden or not name.lower().startswith('_')),
|
||||
|
||||
# NOTE(imelnikov): keywords should be forbidden.
|
||||
not keyword.iskeyword(name),
|
||||
|
||||
# See: http://docs.python.org/release/2.5.2/ref/grammar.txt
|
||||
not (name[0] in string.digits),
|
||||
all(symbol in _ASCII_WORD_SYMBOLS for symbol in name)
|
||||
))
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
"""Dictionary subclass that allows for attribute based access.
|
||||
|
||||
This subclass allows for accessing a dictionaries keys and values by
|
||||
accessing those keys as regular attributes. Keys that are not valid python
|
||||
attribute names can not of course be acccessed/set (those keys must be
|
||||
accessed/set by the traditional dictionary indexing operators instead).
|
||||
"""
|
||||
NO_ATTRS = tuple(reflection.get_member_names(dict))
|
||||
|
||||
@classmethod
|
||||
def _is_valid_attribute_name(cls, name):
|
||||
if not is_valid_attribute_name(name):
|
||||
return False
|
||||
# Make the name just be a simple string in latin-1 encoding in python3.
|
||||
if name in cls.NO_ATTRS:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for (k, v) in kwargs.items():
|
||||
if not self._is_valid_attribute_name(k):
|
||||
raise AttributeError("Invalid attribute name: '%s'" % (k))
|
||||
self[k] = v
|
||||
|
||||
def __getattr__(self, name):
|
||||
if not self._is_valid_attribute_name(name):
|
||||
raise AttributeError("Invalid attribute name: '%s'" % (name))
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError:
|
||||
raise AttributeError("No attributed named: '%s'" % (name))
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if not self._is_valid_attribute_name(name):
|
||||
raise AttributeError("Invalid attribute name: '%s'" % (name))
|
||||
self[name] = value
|
||||
|
||||
|
||||
class ExponentialBackoff(object):
|
||||
"""An iterable object that will yield back an exponential delay sequence.
|
||||
|
||||
|
|
Loading…
Reference in New Issue