From fa6bac3d2b7bde96b91e34b3afe4c67b2f4f6353 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Thu, 20 Feb 2014 18:22:09 -0800 Subject: [PATCH] Move endpoint subclass finding to reflection util The endpoint finding uses a utility method that can be generalized into finding subclasses of a given list of items, so instead of having this logic be in the worker just move it to reflection instead. Change-Id: Id79efd9514f7b2f7713a1426a7e6ed0861ecee70 --- taskflow/engines/worker_based/worker.py | 34 +-------------- taskflow/utils/reflection.py | 56 ++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 39 deletions(-) diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index 17e3e103..ef967d88 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -16,17 +16,12 @@ # License for the specific language governing permissions and limitations # under the License. -import inspect import logging -import types from concurrent import futures -import six - from taskflow.engines.worker_based import endpoint from taskflow.engines.worker_based import server -from taskflow.openstack.common import importutils from taskflow import task as t_task from taskflow.utils import reflection from taskflow.utils import threading_utils as tu @@ -84,34 +79,7 @@ class Worker(object): @staticmethod def _derive_endpoints(tasks): """Derive endpoints from list of strings, classes or packages.""" - derived_tasks = set() - for item in tasks: - module = None - if isinstance(item, six.string_types): - try: - pkg, cls = item.split(':') - except ValueError: - module = importutils.import_module(item) - else: - obj = importutils.import_class('%s.%s' % (pkg, cls)) - if not reflection.is_subclass(obj, t_task.BaseTask): - raise TypeError("Item %s is not a BaseTask subclass" % - item) - derived_tasks.add(obj) - elif isinstance(item, types.ModuleType): - module = item - elif reflection.is_subclass(item, t_task.BaseTask): - derived_tasks.add(item) - else: - raise TypeError("Item %s unexpected type: %s" % - (item, type(item))) - - # derive tasks - if module is not None: - for name, obj in inspect.getmembers(module): - if reflection.is_subclass(obj, t_task.BaseTask): - derived_tasks.add(obj) - + derived_tasks = reflection.find_subclasses(tasks, t_task.BaseTask) return [endpoint.Endpoint(task) for task in derived_tasks] def run(self): diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py index 5e6e8a6f..132575ce 100644 --- a/taskflow/utils/reflection.py +++ b/taskflow/utils/reflection.py @@ -19,15 +19,59 @@ import types import six +from taskflow.openstack.common import importutils + + +def _get_members(obj, exclude_hidden): + """Yields the members of an object, filtering by hidden/not hidden.""" + for (name, value) in inspect.getmembers(obj): + if name.startswith("_") and exclude_hidden: + continue + yield (name, value) + + +def find_subclasses(locations, base_cls, exclude_hidden=True): + """Examines the given locations for types which are subclasses of the base + class type provided and returns the found subclasses. + + If a string is provided as one of the locations it will be imported and + examined if it is a subclass of the base class. If a module is given, + all of its members will be examined for attributes which are subclasses of + the base class. If a type itself is given it will be examined for being a + subclass of the base class. + """ + derived = set() + for item in locations: + module = None + if isinstance(item, six.string_types): + try: + pkg, cls = item.split(':') + except ValueError: + module = importutils.import_module(item) + else: + obj = importutils.import_class('%s.%s' % (pkg, cls)) + if not is_subclass(obj, base_cls): + raise TypeError("Item %s is not a %s subclass" % + (item, base_cls)) + derived.add(obj) + elif isinstance(item, types.ModuleType): + module = item + elif is_subclass(item, base_cls): + derived.add(item) + else: + raise TypeError("Item %s unexpected type: %s" % + (item, type(item))) + # If it's a module derive objects from it if we can. + if module is not None: + for (_name, obj) in _get_members(module, exclude_hidden): + if is_subclass(obj, base_cls): + derived.add(obj) + return derived + def get_member_names(obj, exclude_hidden=True): """Get all the member names for a object.""" - names = [] - for (name, _value) in inspect.getmembers(obj): - if exclude_hidden and name.startswith("_"): - continue - names.append(name) - return sorted(names) + return [name for (name, _obj) in _get_members(obj, exclude_hidden)] def get_class_name(obj):