 aea1f401eb
			
		
	
	aea1f401eb
	
	
	
		
			
			* Added missing period for doc strings * Correct syntax errors * Remove H402 from flake8 ignore list Change-Id: Ia8592bf99378e3658d6cca2ceb148bf9eb0b5de8
		
			
				
	
	
		
			500 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			500 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| # vim: tabstop=4 shiftwidth=4 softtabstop=4
 | |
| 
 | |
| #    Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
 | |
| #
 | |
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | |
| #    not use this file except in compliance with the License. You may obtain
 | |
| #    a copy of the License at
 | |
| #
 | |
| #         http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #    Unless required by applicable law or agreed to in writing, software
 | |
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | |
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | |
| #    License for the specific language governing permissions and limitations
 | |
| #    under the License.
 | |
| 
 | |
| import collections
 | |
| import functools
 | |
| import sys
 | |
| import time
 | |
| 
 | |
| from taskflow import states
 | |
| from taskflow import test
 | |
| from taskflow.tests import utils as test_utils
 | |
| from taskflow.utils import lock_utils
 | |
| from taskflow.utils import misc
 | |
| from taskflow.utils import reflection
 | |
| 
 | |
| 
 | |
def mere_function(a, b):
    """Do nothing; a fixture taking two positional arguments."""
 | |
| 
 | |
| 
 | |
def function_with_defs(a, b, optional=None):
    """Do nothing; a fixture with two required and one defaulted argument."""
 | |
| 
 | |
| 
 | |
def function_with_kwargs(a, b, **kwargs):
    """Do nothing; a fixture that also accepts arbitrary keyword arguments."""
 | |
| 
 | |
| 
 | |
class Class(object):
    """Fixture with one instance, one static and one class method."""

    def method(self, c, d):
        """Instance-method fixture that does nothing."""

    @staticmethod
    def static_method(e, f):
        """Static-method fixture that does nothing."""

    @classmethod
    def class_method(cls, g, h):
        """Class-method fixture that does nothing."""
 | |
| 
 | |
| 
 | |
class CallableClass(object):
    """Fixture whose instances are callable with two arguments."""

    def __call__(self, i, j):
        """Accept ``i`` and ``j`` and do nothing."""
 | |
| 
 | |
| 
 | |
class ClassWithInit(object):
    """Fixture whose constructor takes two arguments and ignores them."""

    def __init__(self, k, l):
        """Accept ``k`` and ``l`` without storing them."""
 | |
| 
 | |
| 
 | |
class CallbackEqualityTest(test.TestCase):
    """Exercises reflection.is_same_callback with several callable pairs."""

    def test_different_simple_callbacks(self):
        """Two separately defined functions must not compare as the same."""

        def a():
            return None

        def b():
            return None

        same = reflection.is_same_callback(a, b)
        self.assertFalse(same)

    def test_static_instance_callbacks(self):
        """A static method reached via two instances is the same callback."""

        class A(object):

            @staticmethod
            def b(a, b, c):
                return None

        first = A()
        second = A()

        same = reflection.is_same_callback(first.b, second.b)
        self.assertTrue(same)

    def test_different_instance_callbacks(self):
        """Bound methods of two instances differ unless strict is False."""

        class A(object):
            # Every instance claims equality with anything, so only the
            # non-strict comparison below is expected to report a match.
            def __eq__(self, other):
                return True

            def b(self):
                return None

        left = A()
        right = A()

        strict_match = reflection.is_same_callback(left.b, right.b)
        self.assertFalse(strict_match)
        lenient_match = reflection.is_same_callback(left.b, right.b,
                                                    strict=False)
        self.assertTrue(lenient_match)
 | |
| 
 | |
| 
 | |
class GetCallableNameTest(test.TestCase):
    """Checks the dotted names returned by reflection.get_callable_name."""

    @staticmethod
    def _in_this_module(*parts):
        """Return this module's name joined with ``parts`` using dots."""
        return '.'.join((__name__,) + parts)

    def test_mere_function(self):
        """A module-level function is named <module>.<function>."""
        found = reflection.get_callable_name(mere_function)
        self.assertEqual(found, self._in_this_module('mere_function'))

    def test_method(self):
        """A method looked up on the class is named <module>.<method>."""
        found = reflection.get_callable_name(Class.method)
        self.assertEqual(found, self._in_this_module('method'))

    def test_instance_method(self):
        """A bound method is named <module>.<class>.<method>."""
        found = reflection.get_callable_name(Class().method)
        self.assertEqual(found, self._in_this_module('Class', 'method'))

    def test_static_method(self):
        """A static method is named <module>.<method>."""
        # NOTE(imelnikov): static methods are just functions, the class
        # name is not recorded anywhere in them.
        found = reflection.get_callable_name(Class.static_method)
        self.assertEqual(found, self._in_this_module('static_method'))

    def test_class_method(self):
        """A class method is named <module>.<class>.<method>."""
        found = reflection.get_callable_name(Class.class_method)
        self.assertEqual(found,
                         self._in_this_module('Class', 'class_method'))

    def test_constructor(self):
        """A class itself is named <module>.<class>."""
        found = reflection.get_callable_name(Class)
        self.assertEqual(found, self._in_this_module('Class'))

    def test_callable_class(self):
        """A callable instance is named after its class."""
        found = reflection.get_callable_name(CallableClass())
        self.assertEqual(found, self._in_this_module('CallableClass'))

    def test_callable_class_call(self):
        """The bound __call__ is named <module>.<class>.__call__."""
        found = reflection.get_callable_name(CallableClass().__call__)
        self.assertEqual(found,
                         self._in_this_module('CallableClass', '__call__'))
 | |
| 
 | |
| 
 | |
class NotifierTest(test.TestCase):
    """Exercises registration and dispatch on misc.TransitionNotifier."""

    def test_notify_called(self):
        """A callback registered for ANY is invoked once per notify."""
        received = []

        def call_me(state, details):
            received.append((state, details))

        notifier = misc.TransitionNotifier()
        notifier.register(misc.TransitionNotifier.ANY, call_me)
        for _ in range(2):
            notifier.notify(states.SUCCESS, {})

        self.assertEqual(2, len(received))
        self.assertEqual(1, len(notifier))

    def test_notify_register_deregister(self):
        """Deregistering every registered callback empties the notifier."""

        def call_me(state, details):
            return None

        class A(object):
            def call_me_too(self, state, details):
                return None

        any_state = misc.TransitionNotifier.ANY
        notifier = misc.TransitionNotifier()
        notifier.register(any_state, call_me)
        a = A()
        notifier.register(any_state, a.call_me_too)

        self.assertEqual(2, len(notifier))
        notifier.deregister(any_state, call_me)
        notifier.deregister(any_state, a.call_me_too)
        self.assertEqual(0, len(notifier))

    def test_notify_reset(self):
        """reset() drops the registered callbacks."""

        def call_me(state, details):
            return None

        notifier = misc.TransitionNotifier()
        notifier.register(misc.TransitionNotifier.ANY, call_me)
        self.assertEqual(1, len(notifier))

        notifier.reset()
        self.assertEqual(0, len(notifier))

    def test_bad_notify(self):
        """Registering with a 'details' entry in kwargs raises KeyError."""

        def call_me(state, details):
            return None

        notifier = misc.TransitionNotifier()
        bad_kwargs = {'details': 5}
        self.assertRaises(KeyError, notifier.register,
                          misc.TransitionNotifier.ANY, call_me,
                          kwargs=bad_kwargs)

    def test_selective_notify(self):
        """A state-specific callback fires only for its own state."""
        any_state = misc.TransitionNotifier.ANY
        recorded = collections.defaultdict(list)

        def call_me_on(registered_state, state, details):
            recorded[registered_state].append((state, details))

        notifier = misc.TransitionNotifier()
        on_success = functools.partial(call_me_on, states.SUCCESS)
        notifier.register(states.SUCCESS, on_success)
        on_any = functools.partial(call_me_on, any_state)
        notifier.register(any_state, on_any)

        self.assertEqual(2, len(notifier))
        notifier.notify(states.SUCCESS, {})

        self.assertEqual(1, len(recorded[any_state]))
        self.assertEqual(1, len(recorded[states.SUCCESS]))

        # A FAILURE notification should reach only the ANY callback.
        notifier.notify(states.FAILURE, {})
        self.assertEqual(2, len(recorded[any_state]))
        self.assertEqual(1, len(recorded[states.SUCCESS]))
        self.assertEqual(2, len(recorded))
 | |
| 
 | |
| 
 | |
class GetCallableArgsTest(test.TestCase):
    """Checks the argument names from reflection.get_callable_args."""

    def test_mere_function(self):
        """A plain function reports all of its positional arguments."""
        arg_names = reflection.get_callable_args(mere_function)
        self.assertEqual(['a', 'b'], arg_names)

    def test_function_with_defaults(self):
        """Arguments with defaults are included by default."""
        arg_names = reflection.get_callable_args(function_with_defs)
        self.assertEqual(['a', 'b', 'optional'], arg_names)

    def test_required_only(self):
        """required_only=True leaves out arguments that have defaults."""
        arg_names = reflection.get_callable_args(function_with_defs,
                                                 required_only=True)
        self.assertEqual(['a', 'b'], arg_names)

    def test_method(self):
        """A method looked up on the class still reports ``self``."""
        arg_names = reflection.get_callable_args(Class.method)
        self.assertEqual(['self', 'c', 'd'], arg_names)

    def test_instance_method(self):
        """A bound method does not report ``self``."""
        arg_names = reflection.get_callable_args(Class().method)
        self.assertEqual(['c', 'd'], arg_names)

    def test_class_method(self):
        """A class method does not report ``cls``."""
        arg_names = reflection.get_callable_args(Class.class_method)
        self.assertEqual(['g', 'h'], arg_names)

    def test_class_constructor(self):
        """A class reports its __init__ arguments without ``self``."""
        arg_names = reflection.get_callable_args(ClassWithInit)
        self.assertEqual(['k', 'l'], arg_names)

    def test_class_with_call(self):
        """A callable instance reports its __call__ arguments."""
        arg_names = reflection.get_callable_args(CallableClass())
        self.assertEqual(['i', 'j'], arg_names)

    def test_decorators_work(self):
        """A function wrapped by lock_utils.locked keeps its arguments."""

        @lock_utils.locked
        def locked_fun(x, y):
            return None

        arg_names = reflection.get_callable_args(locked_fun)
        self.assertEqual(['x', 'y'], arg_names)
 | |
| 
 | |
| 
 | |
class AcceptsKwargsTest(test.TestCase):
    """Checks reflection.accepts_kwargs against the module fixtures."""

    def test_no_kwargs(self):
        """A function without ``**kwargs`` yields exactly False."""
        accepted = reflection.accepts_kwargs(mere_function)
        self.assertEqual(accepted, False)

    def test_with_kwargs(self):
        """A function declaring ``**kwargs`` yields exactly True."""
        accepted = reflection.accepts_kwargs(function_with_kwargs)
        self.assertEqual(accepted, True)
 | |
| 
 | |
| 
 | |
class GetClassNameTest(test.TestCase):
    """Checks the names returned by reflection.get_class_name."""

    def test_std_exception(self):
        """A builtin exception class is named without a module prefix."""
        found = reflection.get_class_name(RuntimeError)
        self.assertEqual(found, 'RuntimeError')

    def test_global_class(self):
        """A library class is named with its full module path."""
        found = reflection.get_class_name(misc.Failure)
        self.assertEqual(found, 'taskflow.utils.misc.Failure')

    def test_class(self):
        """A class from this module is named <module>.<class>."""
        wanted = '.'.join((__name__, 'Class'))
        found = reflection.get_class_name(Class)
        self.assertEqual(found, wanted)

    def test_instance(self):
        """An instance is named after its class."""
        wanted = '.'.join((__name__, 'Class'))
        found = reflection.get_class_name(Class())
        self.assertEqual(found, wanted)

    def test_int(self):
        """A builtin value is named after its builtin type."""
        found = reflection.get_class_name(42)
        self.assertEqual(found, 'int')
 | |
| 
 | |
| 
 | |
class GetAllClassNamesTest(test.TestCase):
    """Checks reflection.get_all_class_names against RuntimeError."""

    def test_std_class(self):
        """All names match test_utils.RUNTIME_ERROR_CLASSES."""
        found = list(reflection.get_all_class_names(RuntimeError))
        self.assertEqual(found, test_utils.RUNTIME_ERROR_CLASSES)

    def test_std_class_up_to(self):
        """With up_to=Exception the last two expected names are dropped."""
        wanted = test_utils.RUNTIME_ERROR_CLASSES[:-2]
        found = list(reflection.get_all_class_names(RuntimeError,
                                                    up_to=Exception))
        self.assertEqual(found, wanted)
 | |
| 
 | |
| 
 | |
class AttrDictTest(test.TestCase):
    """Exercises creation, attribute and item access on misc.AttrDict."""

    def test_ok_create(self):
        """Valid keys become readable attributes."""
        obj = misc.AttrDict(**{'a': 1, 'b': 2})
        self.assertEqual(obj.a, 1)
        self.assertEqual(obj.b, 2)

    def test_private_create(self):
        """A key starting with an underscore is rejected on creation."""
        private_attrs = {'_a': 1}
        self.assertRaises(AttributeError, misc.AttrDict, **private_attrs)

    def test_invalid_create(self):
        """A key that is not a valid identifier is rejected on creation."""
        # Python attributes can't start with a number.
        invalid_attrs = {'123_abc': 1}
        self.assertRaises(AttributeError, misc.AttrDict, **invalid_attrs)

    def test_no_overwrite(self):
        """The key 'update' is rejected on creation."""
        # NOTE(review): 'update' is the name of a standard dict method;
        # presumably AttrDict refuses keys that would shadow one of its
        # own attributes -- confirm against misc.AttrDict.
        clashing_attrs = {'update': 1}
        self.assertRaises(AttributeError, misc.AttrDict, **clashing_attrs)

    def test_back_todict(self):
        """Converting back with dict() yields the original mapping."""
        source = {'a': 1}
        obj = misc.AttrDict(**source)
        self.assertEqual(obj.a, 1)
        self.assertEqual(source, dict(obj))

    def test_runtime_invalid_set(self):
        """Assigning the attribute ``_123`` after creation is rejected."""
        obj = misc.AttrDict(**{'a': 1})
        self.assertEqual(obj.a, 1)
        self.assertRaises(AttributeError, setattr, obj, '_123', 'b')

    def test_bypass_get(self):
        """Values stay reachable through item access."""
        obj = misc.AttrDict(**{'a': 1})
        self.assertEqual(1, obj['a'])

    def test_bypass_set_no_get(self):
        """Item assignment accepts a key that attribute assignment rejects."""
        obj = misc.AttrDict(**{'a': 1})
        self.assertEqual(1, obj['a'])
        obj['_b'] = 'c'
        self.assertRaises(AttributeError, setattr, obj, '_b', 'e')
        # The rejected attribute write must not have replaced the item.
        self.assertEqual('c', obj['_b'])
 | |
| 
 | |
| 
 | |
class IsValidAttributeNameTestCase(test.TestCase):
    """Checks which names misc.is_valid_attribute_name accepts."""

    def test_a_is_ok(self):
        """A single letter is accepted."""
        verdict = misc.is_valid_attribute_name('a')
        self.assertTrue(verdict)

    def test_name_can_be_longer(self):
        """A longer all-letter name is accepted."""
        verdict = misc.is_valid_attribute_name('foobarbaz')
        self.assertTrue(verdict)

    def test_name_can_have_digits(self):
        """Digits after the first character are accepted."""
        verdict = misc.is_valid_attribute_name('fo12')
        self.assertTrue(verdict)

    def test_name_cannot_start_with_digit(self):
        """A leading digit is rejected."""
        verdict = misc.is_valid_attribute_name('1z')
        self.assertFalse(verdict)

    def test_hidden_names_are_forbidden(self):
        """A leading underscore is rejected by default."""
        verdict = misc.is_valid_attribute_name('_z')
        self.assertFalse(verdict)

    def test_hidden_names_can_be_allowed(self):
        """A leading underscore is accepted with allow_hidden=True."""
        verdict = misc.is_valid_attribute_name('_z', allow_hidden=True)
        self.assertTrue(verdict)

    def test_self_is_forbidden(self):
        """The name 'self' is rejected by default."""
        verdict = misc.is_valid_attribute_name('self')
        self.assertFalse(verdict)

    def test_self_can_be_allowed(self):
        """The name 'self' is accepted with allow_self=True."""
        verdict = misc.is_valid_attribute_name('self', allow_self=True)
        self.assertTrue(verdict)

    def test_no_unicode_please(self):
        """A name containing a non-ASCII character is rejected."""
        verdict = misc.is_valid_attribute_name('mañana')
        self.assertFalse(verdict)
 | |
| 
 | |
| 
 | |
class StopWatchUtilsTest(test.TestCase):
    """Exercises start/stop/resume, expiry and timing on misc.StopWatch."""

    def test_no_states(self):
        """stop() and resume() on a never-started watch raise RuntimeError."""
        watch = misc.StopWatch()
        for operation in (watch.stop, watch.resume):
            self.assertRaises(RuntimeError, operation)

    def test_expiry(self):
        """A watch built with 0.1 reports expired after a 0.2s sleep."""
        watch = misc.StopWatch(0.1)
        watch.start()
        time.sleep(0.2)
        has_expired = watch.expired()
        self.assertTrue(has_expired)

    def test_no_expiry(self):
        """A watch built with 0.1 is not expired right after start()."""
        watch = misc.StopWatch(0.1)
        watch.start()
        has_expired = watch.expired()
        self.assertFalse(has_expired)

    def test_elapsed(self):
        """Compares elapsed() with 0.19 after sleeping for 0.2s."""
        watch = misc.StopWatch()
        watch.start()
        time.sleep(0.2)
        # NOTE(harlowja): Allow for a slight variation by using 0.19.
        # NOTE(review): under stock unittest semantics this asserts
        # 0.19 >= elapsed; whether that is the intended direction depends
        # on how test.TestCase defines assertGreaterEqual -- confirm.
        self.assertGreaterEqual(0.19, watch.elapsed())

    def test_pause_resume(self):
        """elapsed() freezes while stopped and moves again after resume()."""
        watch = misc.StopWatch()
        watch.start()
        time.sleep(0.05)
        watch.stop()
        frozen = watch.elapsed()
        time.sleep(0.05)
        self.assertAlmostEqual(frozen, watch.elapsed())
        watch.resume()
        self.assertNotEqual(frozen, watch.elapsed())

    def test_context_manager(self):
        """Compares elapsed() with 0.01 after a 0.05s sleep in a with block."""
        with misc.StopWatch() as watch:
            time.sleep(0.05)
        # NOTE(review): under stock unittest semantics this asserts
        # 0.01 > elapsed; the direction depends on how test.TestCase
        # defines assertGreater -- confirm.
        self.assertGreater(0.01, watch.elapsed())
 | |
| 
 | |
| 
 | |
class ExcInfoUtilsTest(test.TestCase):
    """Exercises misc.copy_exc_info and misc.are_equal_exc_info_tuples."""

    def _make_ex_info(self):
        """Raise and catch a RuntimeError, returning its sys.exc_info()."""
        try:
            raise RuntimeError('Woot!')
        except Exception:
            captured = sys.exc_info()
            return captured

    def test_copy_none(self):
        """Copying None gives None back."""
        copied = misc.copy_exc_info(None)
        self.assertIsNone(copied)

    def test_copy_exc_info(self):
        """A copy has a new tuple and value but the same type and traceback."""
        original = self._make_ex_info()
        copied = misc.copy_exc_info(original)
        self.assertIsNot(copied, original)
        self.assertIs(copied[0], RuntimeError)
        self.assertIsNot(copied[1], original[1])
        self.assertIs(copied[2], original[2])

    def test_none_equals(self):
        """None compares equal to None."""
        equal = misc.are_equal_exc_info_tuples(None, None)
        self.assertTrue(equal)

    def test_none_ne_tuple(self):
        """None does not compare equal to a real exc_info tuple."""
        original = self._make_ex_info()
        equal = misc.are_equal_exc_info_tuples(None, original)
        self.assertFalse(equal)

    def test_tuple_nen_none(self):
        """A real exc_info tuple does not compare equal to None."""
        original = self._make_ex_info()
        equal = misc.are_equal_exc_info_tuples(original, None)
        self.assertFalse(equal)

    def test_tuple_equals_itself(self):
        """An exc_info tuple compares equal to itself."""
        original = self._make_ex_info()
        equal = misc.are_equal_exc_info_tuples(original, original)
        self.assertTrue(equal)

    def test_typle_equals_copy(self):
        """An exc_info tuple compares equal to its copy."""
        original = self._make_ex_info()
        copied = misc.copy_exc_info(original)
        equal = misc.are_equal_exc_info_tuples(original, copied)
        self.assertTrue(equal)
 |