diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py index 1f3186bf..b2790657 100644 --- a/taskflow/utils/threading_utils.py +++ b/taskflow/utils/threading_utils.py @@ -82,16 +82,19 @@ _ThreadBuilder = collections.namedtuple('_ThreadBuilder', ['thread_factory', 'before_start', 'after_start', 'before_join', 'after_join']) -_ThreadBuilder.callables = tuple([ - # Attribute name -> none allowed as a valid value... - ('thread_factory', False), - ('before_start', True), - ('after_start', True), - ('before_join', True), - ('after_join', True), +_ThreadBuilder.fields = tuple([ + 'thread_factory', + 'before_start', + 'after_start', + 'before_join', + 'after_join', ]) +def no_op(*args, **kwargs): + """Function that does nothing.""" + + class ThreadBundle(object): """A group/bundle of threads that start/stop together.""" @@ -110,13 +113,19 @@ class ThreadBundle(object): in dead-lock since the lock on this object is not meant to be (and is not) reentrant... """ + if before_start is None: + before_start = no_op + if after_start is None: + after_start = no_op + if before_join is None: + before_join = no_op + if after_join is None: + after_join = no_op builder = _ThreadBuilder(thread_factory, before_start, after_start, before_join, after_join) - for attr_name, none_allowed in builder.callables: + for attr_name in builder.fields: cb = getattr(builder, attr_name) - if cb is None and none_allowed: - continue if not six.callable(cb): raise ValueError("Provided callback for argument" " '%s' must be callable" % attr_name) @@ -130,11 +139,6 @@ class ThreadBundle(object): False, ]) - @staticmethod - def _trigger_callback(callback, thread): - if callback is not None: - callback(thread) - def start(self): """Creates & starts all associated threads (that are not running).""" count = 0 @@ -145,11 +149,11 @@ class ThreadBundle(object): continue if not thread: self._threads[i][1] = thread = builder.thread_factory() - self._trigger_callback(builder.before_start, thread) + builder.before_start(thread) thread.start() count += 1 try: - self._trigger_callback(builder.after_start, thread) + builder.after_start(thread) finally: # Just incase the 'after_start' callback blows up make sure # we always set this... @@ -164,11 +168,11 @@ class ThreadBundle(object): for i, (builder, thread, started) in it: if not thread or not started: continue - self._trigger_callback(builder.before_join, thread) + builder.before_join(thread) thread.join() count += 1 try: - self._trigger_callback(builder.after_join, thread) + builder.after_join(thread) finally: # Just incase the 'after_join' callback blows up make sure # we always set/reset these...