diff --git a/test_utils.py b/test_utils.py index 27d3c8d..909fe93 100644 --- a/test_utils.py +++ b/test_utils.py @@ -10,6 +10,7 @@ # License for the specific language governing permissions and limitations # under the License. +import logging import sys import six @@ -131,3 +132,75 @@ class PrintTestCase(test_utils.TestCase): if isinstance(output, six.binary_type): output = output.decode('utf-8') self.assertIn(name, output) + + +class TestPositional(test_utils.TestCase): + + @utils.positional(1) + def no_vars(self): + # positional doesn't enforce anything here + return True + + @utils.positional(3, utils.positional.EXCEPT) + def mixed_except(self, arg, kwarg1=None, kwarg2=None): + # self, arg, and kwarg1 may be passed positionally + return (arg, kwarg1, kwarg2) + + @utils.positional(3, utils.positional.WARN) + def mixed_warn(self, arg, kwarg1=None, kwarg2=None): + # self, arg, and kwarg1 may be passed positionally, only a warning + # is emitted + return (arg, kwarg1, kwarg2) + + def test_nothing(self): + self.assertTrue(self.no_vars()) + + def test_mixed_except(self): + self.assertEqual((1, 2, 3), self.mixed_except(1, 2, kwarg2=3)) + self.assertEqual((1, 2, 3), self.mixed_except(1, kwarg1=2, kwarg2=3)) + self.assertEqual((1, None, None), self.mixed_except(1)) + self.assertRaises(TypeError, self.mixed_except, 1, 2, 3) + + def test_mixed_warn(self): + logger_message = six.moves.cStringIO() + handler = logging.StreamHandler(logger_message) + handler.setLevel(logging.DEBUG) + + logger = logging.getLogger(utils.__name__) + level = logger.getEffectiveLevel() + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + + self.addCleanup(logger.removeHandler, handler) + self.addCleanup(logger.setLevel, level) + + self.mixed_warn(1, 2, 3) + + self.assertIn('takes at most 3 positional', logger_message.getvalue()) + + @utils.positional(enforcement=utils.positional.EXCEPT) + def inspect_func(self, arg, kwarg=None): + return (arg, kwarg) + + def test_inspect_positions(self): + self.assertEqual((1, None), self.inspect_func(1)) + self.assertEqual((1, 2), self.inspect_func(1, kwarg=2)) + self.assertRaises(TypeError, self.inspect_func) + self.assertRaises(TypeError, self.inspect_func, 1, 2) + + @utils.positional.classmethod(1) + def class_method(cls, a, b): + return (cls, a, b) + + @utils.positional.method(1) + def normal_method(self, a, b): + self.assertIsInstance(self, TestPositional) + return (self, a, b) + + def test_class_method(self): + self.assertEqual((TestPositional, 1, 2), self.class_method(1, b=2)) + self.assertRaises(TypeError, self.class_method, 1, 2) + + def test_normal_method(self): + self.assertEqual((self, 1, 2), self.normal_method(1, b=2)) + self.assertRaises(TypeError, self.normal_method, 1, 2)