From eddd4c76d3f347f080a07d9d552714dd1d9b42b5 Mon Sep 17 00:00:00 2001 From: kaido Date: Thu, 9 Jul 2015 15:38:43 -0700 Subject: [PATCH 1/2] Add initial draft of @wrap.check decorator --- docs/wrapping.rst | 19 ++++++++++++++++++ pint/testsuite/test_unit.py | 39 +++++++++++++++++++++++++++++++++++++ pint/unit.py | 30 ++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/docs/wrapping.rst b/docs/wrapping.rst index 49995aa..c46e1ea 100644 --- a/docs/wrapping.rst +++ b/docs/wrapping.rst @@ -147,5 +147,24 @@ To avoid the conversion of an argument or return value, use None >>> mypp3 = ureg.wraps((ureg.second, None), ureg.meter)(pendulum_period_error) +Checking units +============== + +When you want pint quantities to be used as inputs to your functions, pint provides a wrapper to ensure units are of +correct type - or more precisely, they match the expected dimensionality of the physical quantity. + +Similar to wraps(), you can pass None to skip checking of some parameters, but the return parameter type is not checked. + +.. doctest:: + + >>> mypp = ureg.check('[length]')(pendulum_period) + +In the decorator format: + +.. doctest:: + + >>>@ureg.wraps('[length]') + ... def pendulum_period(length): + ... return 2*math.pi*math.sqrt(length/G) diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 5e80f49..08ae9f9 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -335,6 +335,45 @@ class TestRegistry(QuantityTestCase): h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc) self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm)) + def test_check(self): + def func(x): + return x + + ureg = self.ureg + + f0 = ureg.check('[length]')(func) + self.assertRaises(AttributeError, f0, 3.) + self.assertEqual(f0(3. * ureg.centimeter), 0.03 * ureg.meter) + self.assertRaises(TypeError, f0, 3. * ureg.kilogram) + + f0b = ureg.check(ureg.meter)(func) + self.assertRaises(AttributeError, f0b, 3.) + self.assertEqual(f0b(3. * ureg.centimeter), 0.03 * ureg.meter) + self.assertRaises(TypeError, f0b, 3. * ureg.kilogram) + + def gfunc(x, y): + return x / y + + g0 = ureg.check(None, None)(gfunc) + self.assertEqual(g0(6, 2), 3) + self.assertEqual(g0(6 * ureg.parsec, 2), 3 * ureg.parsec) + + g1 = ureg.check('[speed]', '[time]')(gfunc) + self.assertRaises(AttributeError, g1, 3.0, 1) + self.assertRaises(TypeError, g1, 1 * ureg.parsec, 1 * ureg.angstrom) + self.assertRaises(TypeError, g1, 1 * ureg.km / ureg.hour, 1 * ureg.hour, 3.0) + self.assertEquals(g1(3.6 * ureg.km / ureg.hour, 1 * ureg.second), 1 * ureg.meter / ureg.second ** 2) + + g2 = ureg.check('[speed]')(gfunc) + self.assertRaises(AttributeError, g2, 3.0, 1) + self.assertRaises(TypeError, g2, 2 * ureg.parsec) + self.assertRaises(TypeError, g2, 2 * ureg.parsec, 1.0) + self.assertEquals(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour) + + g3 = ureg.check('[speed]', '[time]', '[mass]')(gfunc) + self.assertRaises(TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom) + self.assertRaises(TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram) + def test_to_ref_vs_to(self): self.ureg.autoconvert_offset_to_baseunit = True q = 8. * self.ureg.inch diff --git a/pint/unit.py b/pint/unit.py index b88f310..cbbc34e 100644 --- a/pint/unit.py +++ b/pint/unit.py @@ -1145,6 +1145,36 @@ class UnitRegistry(object): return wrapper return decorator + def check(self, *args): + """Decorator to for quantity type checking for function inputs. + + Use it to ensure that the decorated function input parameters match + the expected type of pint quantity. + + Use None to skip argument checking. + + :param args: iterable of input units. + :return: the wrapped function. + :raises: + :class:`TypeError` if the parameters don't match dimensions + """ + dimensions = [self.get_dimensionality(dim) for dim in args] + + def decorator(func): + assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr)) + updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr)) + + @functools.wraps(func, assigned=assigned, updated=updated) + def wrapper(*values, **kwargs): + for dim, value in itertools.izip_longest(dimensions, values): + if dim and value.dimensionality != dim: + raise TypeError( + 'Expected units of %s, got %s' % + (dim, value.dimensionality)) + return func(*values, **kwargs) + return wrapper + return decorator + def build_unit_class(registry): From bb28afef5c308010edda68091f435c77500c75ad Mon Sep 17 00:00:00 2001 From: Kaido Kert Date: Thu, 9 Jul 2015 19:16:51 -0700 Subject: [PATCH 2/2] Fix some of the #283 issues --- docs/wrapping.rst | 3 +-- pint/compat/__init__.py | 5 +++++ pint/testsuite/test_unit.py | 18 +++++++++--------- pint/unit.py | 9 ++++----- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/docs/wrapping.rst b/docs/wrapping.rst index c46e1ea..b2ed162 100644 --- a/docs/wrapping.rst +++ b/docs/wrapping.rst @@ -150,7 +150,6 @@ To avoid the conversion of an argument or return value, use None Checking units ============== - When you want pint quantities to be used as inputs to your functions, pint provides a wrapper to ensure units are of correct type - or more precisely, they match the expected dimensionality of the physical quantity. @@ -164,7 +163,7 @@ In the decorator format: .. doctest:: - >>>@ureg.wraps('[length]') + >>>@ureg.check('[length]') ... def pendulum_period(length): ... return 2*math.pi*math.sqrt(length/G) diff --git a/pint/compat/__init__.py b/pint/compat/__init__.py index 043f1bf..7442967 100644 --- a/pint/compat/__init__.py +++ b/pint/compat/__init__.py @@ -75,6 +75,11 @@ try: except ImportError: from .nullhandler import NullHandler +try: + from itertools import zip_longest +except ImportError: + from itertools import izip_longest as zip_longest + try: import numpy as np from numpy import ndarray diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 08ae9f9..edbb2fa 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -344,12 +344,12 @@ class TestRegistry(QuantityTestCase): f0 = ureg.check('[length]')(func) self.assertRaises(AttributeError, f0, 3.) self.assertEqual(f0(3. * ureg.centimeter), 0.03 * ureg.meter) - self.assertRaises(TypeError, f0, 3. * ureg.kilogram) + self.assertRaises(DimensionalityError, f0, 3. * ureg.kilogram) f0b = ureg.check(ureg.meter)(func) self.assertRaises(AttributeError, f0b, 3.) self.assertEqual(f0b(3. * ureg.centimeter), 0.03 * ureg.meter) - self.assertRaises(TypeError, f0b, 3. * ureg.kilogram) + self.assertRaises(DimensionalityError, f0b, 3. * ureg.kilogram) def gfunc(x, y): return x / y @@ -360,19 +360,19 @@ class TestRegistry(QuantityTestCase): g1 = ureg.check('[speed]', '[time]')(gfunc) self.assertRaises(AttributeError, g1, 3.0, 1) - self.assertRaises(TypeError, g1, 1 * ureg.parsec, 1 * ureg.angstrom) + self.assertRaises(DimensionalityError, g1, 1 * ureg.parsec, 1 * ureg.angstrom) self.assertRaises(TypeError, g1, 1 * ureg.km / ureg.hour, 1 * ureg.hour, 3.0) - self.assertEquals(g1(3.6 * ureg.km / ureg.hour, 1 * ureg.second), 1 * ureg.meter / ureg.second ** 2) + self.assertEqual(g1(3.6 * ureg.km / ureg.hour, 1 * ureg.second), 1 * ureg.meter / ureg.second ** 2) g2 = ureg.check('[speed]')(gfunc) self.assertRaises(AttributeError, g2, 3.0, 1) - self.assertRaises(TypeError, g2, 2 * ureg.parsec) - self.assertRaises(TypeError, g2, 2 * ureg.parsec, 1.0) - self.assertEquals(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour) + self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec) + self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec, 1.0) + self.assertEqual(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour) g3 = ureg.check('[speed]', '[time]', '[mass]')(gfunc) - self.assertRaises(TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom) - self.assertRaises(TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram) + self.assertRaises(DimensionalityError, g3, 1 * ureg.parsec, 1 * ureg.angstrom) + self.assertRaises(DimensionalityError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram) def test_to_ref_vs_to(self): self.ureg.autoconvert_offset_to_baseunit = True diff --git a/pint/unit.py b/pint/unit.py index cbbc34e..7ca1591 100644 --- a/pint/unit.py +++ b/pint/unit.py @@ -29,7 +29,7 @@ from .util import (logger, pi_theorem, solve_dependencies, ParserHelper, string_preprocessor, find_connected_nodes, find_shortest_path, UnitsContainer, _is_dim, SharedRegistryObject, to_units_container) -from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type +from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type, zip_longest from .definitions import (Definition, UnitDefinition, PrefixDefinition, DimensionDefinition) from .converters import ScaleConverter @@ -1166,11 +1166,10 @@ class UnitRegistry(object): @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*values, **kwargs): - for dim, value in itertools.izip_longest(dimensions, values): + for dim, value in zip_longest(dimensions, values): if dim and value.dimensionality != dim: - raise TypeError( - 'Expected units of %s, got %s' % - (dim, value.dimensionality)) + raise DimensionalityError(value, 'a quantity of', + value.dimensionality, dim) return func(*values, **kwargs) return wrapper return decorator