diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index 2232a0d0..1f133e12 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -14,17 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -import inspect import logging -import types from concurrent import futures -import six - from taskflow.engines.worker_based import endpoint from taskflow.engines.worker_based import server -from taskflow.openstack.common import importutils from taskflow import task as t_task from taskflow.utils import reflection from taskflow.utils import threading_utils as tu @@ -87,34 +82,7 @@ class Worker(object): @staticmethod def _derive_endpoints(tasks): """Derive endpoints from list of strings, classes or packages.""" - derived_tasks = set() - for item in tasks: - module = None - if isinstance(item, six.string_types): - try: - pkg, cls = item.split(':') - except ValueError: - module = importutils.import_module(item) - else: - obj = importutils.import_class('%s.%s' % (pkg, cls)) - if not reflection.is_subclass(obj, t_task.BaseTask): - raise TypeError("Item %s is not a BaseTask subclass" % - item) - derived_tasks.add(obj) - elif isinstance(item, types.ModuleType): - module = item - elif reflection.is_subclass(item, t_task.BaseTask): - derived_tasks.add(item) - else: - raise TypeError("Item %s unexpected type: %s" % - (item, type(item))) - - # derive tasks - if module is not None: - for name, obj in inspect.getmembers(module): - if reflection.is_subclass(obj, t_task.BaseTask): - derived_tasks.add(obj) - + derived_tasks = reflection.find_subclasses(tasks, t_task.BaseTask) return [endpoint.Endpoint(task) for task in derived_tasks] def run(self): diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py index 824352a8..a5d80b55 100644 --- a/taskflow/utils/reflection.py +++ b/taskflow/utils/reflection.py @@ -19,15 +19,59 @@ import types import six +from taskflow.openstack.common import importutils + + +def _get_members(obj, exclude_hidden): + """Yields the members of an object, filtering by hidden/not hidden.""" + for (name, value) in inspect.getmembers(obj): + if name.startswith("_") and exclude_hidden: + continue + yield (name, value) + + +def find_subclasses(locations, base_cls, exclude_hidden=True): + """Examines the given locations for types which are subclasses of the base + class type provided and returns the found subclasses. + + If a string is provided as one of the locations it will be imported and + examined if it is a subclass of the base class. If a module is given, + all of its members will be examined for attributes which are subclasses of + the base class. If a type itself is given it will be examined for being a + subclass of the base class. + """ + derived = set() + for item in locations: + module = None + if isinstance(item, six.string_types): + try: + pkg, cls = item.split(':') + except ValueError: + module = importutils.import_module(item) + else: + obj = importutils.import_class('%s.%s' % (pkg, cls)) + if not is_subclass(obj, base_cls): + raise TypeError("Item %s is not a %s subclass" % + (item, base_cls)) + derived.add(obj) + elif isinstance(item, types.ModuleType): + module = item + elif is_subclass(item, base_cls): + derived.add(item) + else: + raise TypeError("Item %s unexpected type: %s" % + (item, type(item))) + # If it's a module derive objects from it if we can. + if module is not None: + for (_name, obj) in _get_members(module, exclude_hidden): + if is_subclass(obj, base_cls): + derived.add(obj) + return derived + def get_member_names(obj, exclude_hidden=True): """Get all the member names for a object.""" - names = [] - for (name, _value) in inspect.getmembers(obj): - if exclude_hidden and name.startswith("_"): - continue - names.append(name) - return sorted(names) + return [name for (name, _obj) in _get_members(obj, exclude_hidden)] def get_class_name(obj):