Implemented reference in wraps decorator.
We use an API based on strings prefixed with the equal sign. Each parameter can be labeled with a unique name. Parameters can reference other using labels to build up relationships. Close #195
This commit is contained in:
parent
457df519c7
commit
f70d749936
@ -1,7 +1,7 @@
|
||||
.. _wrapping:
|
||||
|
||||
Wrapping functions
|
||||
==================
|
||||
Wrapping and checking functions
|
||||
===============================
|
||||
|
||||
In some cases you might want to use pint with a pre-existing web service or library
|
||||
which is not units aware. Or you might want to write a fast implementation of a
|
||||
@ -137,6 +137,30 @@ Or if the function has multiple outputs:
|
||||
... (ureg.meter, ureg.radians))(pendulum_period_maxspeed)
|
||||
...
|
||||
|
||||
|
||||
Specifying relations between arguments
|
||||
--------------------------------------
|
||||
|
||||
In certain cases the actual units but just their relation. This is done using string
|
||||
starting with the equal sign `=`:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> @ureg.wraps('=A**2', ('=A', '=A'))
|
||||
... def sqsum(x, y):
|
||||
... return x * x + 2 * x * y + y * y
|
||||
|
||||
which can be read as the first argument (`x`) has certain units (we labeled them `A`),
|
||||
the second argument (`y`) has the same units as the first (`A` again). The return value
|
||||
has the unit of `x` squared (`A**2`)
|
||||
|
||||
You can use more than one labels.
|
||||
|
||||
>>> @ureg.wraps('=A**2*B', ('=A', '=A*B', '=B'))
|
||||
... def some_function(x, y, z):
|
||||
|
||||
|
||||
|
||||
Ignoring an argument or return value
|
||||
------------------------------------
|
||||
|
||||
|
217
pint/registry_helpers.py
Normal file
217
pint/registry_helpers.py
Normal file
@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
pint.registry_helpers
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Miscellaneous methods of the registry writen as separate functions.
|
||||
|
||||
:copyright: 2013 by Pint Authors, see AUTHORS for more details.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
from .compat import string_types, zip_longest
|
||||
from .errors import DimensionalityError
|
||||
from .util import to_units_container
|
||||
|
||||
|
||||
def _replace_units(original_units, values_by_name):
|
||||
"""Convert a unit compatible type to a UnitsContainer.
|
||||
|
||||
:param original_units: a UnitsContainer instance.
|
||||
:param values_by_name: a map between original names and the new values.
|
||||
"""
|
||||
q = 1
|
||||
for arg_name, exponent in original_units.items():
|
||||
q = q * values_by_name[arg_name] ** exponent
|
||||
|
||||
return to_units_container(q)
|
||||
|
||||
|
||||
def _to_units_container(a):
|
||||
"""Convert a unit compatible type to a UnitsContainer,
|
||||
checking if it is string field prefixed with an equal
|
||||
(which is considered a reference)
|
||||
|
||||
Return a tuple with the unit container and a boolean indicating if it was a reference.
|
||||
"""
|
||||
if isinstance(a, string_types) and '=' in a:
|
||||
return to_units_container(a.split('=', 1)[1]), True
|
||||
return to_units_container(a), False
|
||||
|
||||
|
||||
def _parse_wrap_args(args):
|
||||
|
||||
# Arguments which contain definitions
|
||||
# (i.e. names that appear alone and for the first time)
|
||||
defs_args = set()
|
||||
defs_args_ndx = set()
|
||||
|
||||
# Arguments which depend on others
|
||||
dependent_args_ndx = set()
|
||||
|
||||
# Arguments which have units.
|
||||
unit_args_ndx = set()
|
||||
|
||||
# _to_units_container
|
||||
args_as_uc = [_to_units_container(arg) for arg in args]
|
||||
|
||||
# Check for references in args, remove None values
|
||||
for ndx, (arg, is_ref) in enumerate(args_as_uc):
|
||||
if arg is None:
|
||||
continue
|
||||
elif is_ref:
|
||||
if len(arg) == 1:
|
||||
[(key, value)] = arg.items()
|
||||
if value == 1 and key not in defs_args:
|
||||
# This is the first time that
|
||||
# a variable is used => it is a definition.
|
||||
defs_args.add(key)
|
||||
defs_args_ndx.add(ndx)
|
||||
args_as_uc[ndx] = (key, True)
|
||||
else:
|
||||
# The variable was already found elsewhere,
|
||||
# we consider it a dependent variable.
|
||||
dependent_args_ndx.add(ndx)
|
||||
else:
|
||||
dependent_args_ndx.add(ndx)
|
||||
else:
|
||||
unit_args_ndx.add(ndx)
|
||||
|
||||
# Check that all valid dependent variables
|
||||
for ndx in dependent_args_ndx:
|
||||
arg, is_ref = args_as_uc[ndx]
|
||||
if not isinstance(arg, dict):
|
||||
continue
|
||||
if not set(arg.keys()) <= defs_args:
|
||||
raise ValueError('Found a missing token while wrapping a function: '
|
||||
'Not all variable referenced in %s are defined using !' % args[ndx])
|
||||
|
||||
def _converter(ureg, values, strict):
|
||||
new_values = list(value for value in values)
|
||||
|
||||
values_by_name = {}
|
||||
|
||||
# first pass: Grab named values
|
||||
for ndx in defs_args_ndx:
|
||||
values_by_name[args_as_uc[ndx][0]] = values[ndx]
|
||||
new_values[ndx] = values[ndx]._magnitude
|
||||
|
||||
# second pass: calculate derived values based on named values
|
||||
for ndx in dependent_args_ndx:
|
||||
new_values[ndx] = ureg._convert(values[ndx]._magnitude,
|
||||
values[ndx]._units,
|
||||
_replace_units(args_as_uc[ndx][0], values_by_name))
|
||||
|
||||
# third pass: convert other arguments
|
||||
for ndx in unit_args_ndx:
|
||||
|
||||
if isinstance(values[ndx], ureg.Quantity):
|
||||
new_values[ndx] = ureg._convert(values[ndx]._magnitude,
|
||||
values[ndx]._units,
|
||||
args_as_uc[ndx][0])
|
||||
else:
|
||||
if strict:
|
||||
raise ValueError('A wrapped function using strict=True requires '
|
||||
'quantity for all arguments with not None units. '
|
||||
'(error found for {0}, {1})'.format(args_as_uc[ndx][0], new_values[ndx]))
|
||||
|
||||
return new_values, values_by_name
|
||||
|
||||
return _converter
|
||||
|
||||
|
||||
def wraps(ureg, ret, args, strict=True):
|
||||
"""Wraps a function to become pint-aware.
|
||||
|
||||
Use it when a function requires a numerical value but in some specific
|
||||
units. The wrapper function will take a pint quantity, convert to the units
|
||||
specified in `args` and then call the wrapped function with the resulting
|
||||
magnitude.
|
||||
|
||||
The value returned by the wrapped function will be converted to the units
|
||||
specified in `ret`.
|
||||
|
||||
Use None to skip argument conversion.
|
||||
Set strict to False, to accept also numerical values.
|
||||
|
||||
:param ureg: a UnitRegistry instance.
|
||||
:param ret: output units.
|
||||
:param args: iterable of input units.
|
||||
:param strict: boolean to indicate that only quantities are accepted.
|
||||
:return: the wrapped function.
|
||||
:raises:
|
||||
:class:`ValueError` if strict and one of the arguments is not a Quantity.
|
||||
"""
|
||||
|
||||
if not isinstance(args, (list, tuple)):
|
||||
args = (args, )
|
||||
|
||||
converter = _parse_wrap_args(args)
|
||||
|
||||
if isinstance(ret, (list, tuple)):
|
||||
container, ret = True, ret.__class__([_to_units_container(arg) for arg in ret])
|
||||
elif isinstance(ret, string_types):
|
||||
container, ret = False, _to_units_container(ret)
|
||||
else:
|
||||
container = False
|
||||
|
||||
def decorator(func):
|
||||
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
|
||||
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
|
||||
|
||||
@functools.wraps(func, assigned=assigned, updated=updated)
|
||||
def wrapper(*values, **kw):
|
||||
|
||||
# In principle, the values are used as is
|
||||
# When then extract the magnitudes when needed.
|
||||
new_values, values_by_name = converter(ureg, values, strict)
|
||||
|
||||
result = func(*new_values, **kw)
|
||||
|
||||
if container:
|
||||
out_units = (_replace_units(r, values_by_name) if is_ref else r
|
||||
for (r, is_ref) in ret)
|
||||
return ret.__class__(res if unit is None else ureg.Quantity(res, unit)
|
||||
for unit, res in zip(out_units, result))
|
||||
|
||||
if ret is None:
|
||||
return result
|
||||
|
||||
return ureg.Quantity(result,
|
||||
_replace_units(ret[0], values_by_name) if ret[1] else ret[0])
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def check(ureg, *args):
|
||||
"""Decorator to for quantity type checking for function inputs.
|
||||
|
||||
Use it to ensure that the decorated function input parameters match
|
||||
the expected type of pint quantity.
|
||||
|
||||
Use None to skip argument checking.
|
||||
|
||||
:param ureg: a UnitRegistry instance.
|
||||
:param args: iterable of input units.
|
||||
:return: the wrapped function.
|
||||
:raises:
|
||||
:class:`DimensionalityError` if the parameters don't match dimensions
|
||||
"""
|
||||
dimensions = [ureg.get_dimensionality(dim) for dim in args]
|
||||
|
||||
def decorator(func):
|
||||
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
|
||||
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
|
||||
|
||||
@functools.wraps(func, assigned=assigned, updated=updated)
|
||||
def wrapper(*values, **kwargs):
|
||||
for dim, value in zip_longest(dimensions, values):
|
||||
if dim and value.dimensionality != dim:
|
||||
raise DimensionalityError(value, 'a quantity of',
|
||||
value.dimensionality, dim)
|
||||
return func(*values, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
@ -330,12 +330,45 @@ class TestRegistry(QuantityTestCase):
|
||||
h0 = ureg.wraps(None, [None, None])(hfunc)
|
||||
self.assertEqual(h0(3, 1), (3, 1))
|
||||
|
||||
h1 = ureg.wraps(['meter', 'cm'], [None, None])(hfunc)
|
||||
h1 = ureg.wraps(['meter', 'centimeter'], [None, None])(hfunc)
|
||||
self.assertEqual(h1(3, 1), [3 * ureg.meter, 1 * ureg.cm])
|
||||
|
||||
h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc)
|
||||
h2 = ureg.wraps(('meter', 'centimeter'), [None, None])(hfunc)
|
||||
self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm))
|
||||
|
||||
def test_wrap_referencing(self):
|
||||
|
||||
ureg = self.ureg
|
||||
|
||||
def gfunc(x, y):
|
||||
return x + y
|
||||
|
||||
def gfunc2(x, y):
|
||||
return x ** 2 + y
|
||||
|
||||
def gfunc3(x, y):
|
||||
return x ** 2 * y
|
||||
|
||||
rst = 3. * ureg.meter + 1. * ureg.centimeter
|
||||
|
||||
g0 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
|
||||
self.assertEqual(g0(3. * ureg.meter, 1. * ureg.centimeter), rst.to('meter'))
|
||||
|
||||
g1 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
|
||||
self.assertEqual(g1(3. * ureg.meter, 1. * ureg.centimeter), rst.to('centimeter'))
|
||||
|
||||
g2 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
|
||||
self.assertEqual(g2(3. * ureg.meter, 1. * ureg.centimeter), rst.to('meter'))
|
||||
|
||||
g3 = ureg.wraps('=A**2', ['=A', '=A**2'])(gfunc2)
|
||||
a = 3. * ureg.meter
|
||||
b = (2. * ureg.centimeter) ** 2
|
||||
self.assertEqual(g3(a, b), gfunc2(a, b))
|
||||
|
||||
g4 = ureg.wraps('=A**2 * B', ['=A', '=B'])(gfunc3)
|
||||
self.assertEqual(g4(3. * ureg.meter, 2. * ureg.second), ureg('(3*meter)**2 * 2 *second'))
|
||||
|
||||
|
||||
def test_check(self):
|
||||
def func(x):
|
||||
return x
|
||||
|
98
pint/unit.py
98
pint/unit.py
@ -14,7 +14,6 @@ from __future__ import division, unicode_literals, print_function, absolute_impo
|
||||
import os
|
||||
import math
|
||||
import itertools
|
||||
import functools
|
||||
import operator
|
||||
import pkg_resources
|
||||
from decimal import Decimal
|
||||
@ -25,6 +24,7 @@ from collections import defaultdict
|
||||
from tokenize import untokenize, NUMBER, STRING, NAME, OP
|
||||
from numbers import Number
|
||||
|
||||
from . import registry_helpers
|
||||
from .context import Context, ContextChain
|
||||
from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
|
||||
string_preprocessor, find_connected_nodes,
|
||||
@ -32,7 +32,7 @@ from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
|
||||
SharedRegistryObject, to_units_container,
|
||||
fix_str_conversions, SourceIterator)
|
||||
|
||||
from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type, zip_longest
|
||||
from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type
|
||||
from .formatting import siunitx_format_unit
|
||||
from .definitions import (Definition, UnitDefinition, PrefixDefinition,
|
||||
DimensionDefinition)
|
||||
@ -1267,99 +1267,9 @@ class UnitRegistry(object):
|
||||
|
||||
__call__ = parse_expression
|
||||
|
||||
def wraps(self, ret, args, strict=True):
|
||||
"""Wraps a function to become pint-aware.
|
||||
wraps = registry_helpers.wraps
|
||||
|
||||
Use it when a function requires a numerical value but in some specific
|
||||
units. The wrapper function will take a pint quantity, convert to the units
|
||||
specified in `args` and then call the wrapped function with the resulting
|
||||
magnitude.
|
||||
|
||||
The value returned by the wrapped function will be converted to the units
|
||||
specified in `ret`.
|
||||
|
||||
Use None to skip argument conversion.
|
||||
Set strict to False, to accept also numerical values.
|
||||
|
||||
:param ret: output units.
|
||||
:param args: iterable of input units.
|
||||
:param strict: boolean to indicate that only quantities are accepted.
|
||||
:return: the wrapped function.
|
||||
:raises:
|
||||
:class:`ValueError` if strict and one of the arguments is not a Quantity.
|
||||
"""
|
||||
|
||||
Q_ = self.Quantity
|
||||
|
||||
if not isinstance(args, (list, tuple)):
|
||||
args = (args, )
|
||||
|
||||
units = [to_units_container(arg, self) for arg in args]
|
||||
|
||||
if isinstance(ret, (list, tuple)):
|
||||
ret = ret.__class__([to_units_container(arg, self) for arg in ret])
|
||||
elif isinstance(ret, string_types):
|
||||
ret = self.parse_units(ret)
|
||||
|
||||
def decorator(func):
|
||||
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
|
||||
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
|
||||
@functools.wraps(func, assigned=assigned, updated=updated)
|
||||
def wrapper(*values, **kw):
|
||||
new_args = []
|
||||
for unit, value in zip(units, values):
|
||||
if unit is None:
|
||||
new_args.append(value)
|
||||
elif isinstance(value, Q_):
|
||||
new_args.append(self._convert(value._magnitude,
|
||||
value._units, unit))
|
||||
elif not strict:
|
||||
new_args.append(value)
|
||||
else:
|
||||
raise ValueError('A wrapped function using strict=True requires '
|
||||
'quantity for all arguments with not None units. '
|
||||
'(error found for {0}, {1})'.format(unit, value))
|
||||
|
||||
result = func(*new_args, **kw)
|
||||
|
||||
if isinstance(ret, (list, tuple)):
|
||||
return ret.__class__(res if unit is None else Q_(res, unit)
|
||||
for unit, res in zip(ret, result))
|
||||
elif ret is not None:
|
||||
return Q_(result, ret)
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def check(self, *args):
|
||||
"""Decorator to for quantity type checking for function inputs.
|
||||
|
||||
Use it to ensure that the decorated function input parameters match
|
||||
the expected type of pint quantity.
|
||||
|
||||
Use None to skip argument checking.
|
||||
|
||||
:param args: iterable of input units.
|
||||
:return: the wrapped function.
|
||||
:raises:
|
||||
:class:`DimensionalityError` if the parameters don't match dimensions
|
||||
"""
|
||||
dimensions = [self.get_dimensionality(dim) for dim in args]
|
||||
|
||||
def decorator(func):
|
||||
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
|
||||
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
|
||||
|
||||
@functools.wraps(func, assigned=assigned, updated=updated)
|
||||
def wrapper(*values, **kwargs):
|
||||
for dim, value in zip_longest(dimensions, values):
|
||||
if dim and value.dimensionality != dim:
|
||||
raise DimensionalityError(value, 'a quantity of',
|
||||
value.dimensionality, dim)
|
||||
return func(*values, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
check = registry_helpers.check
|
||||
|
||||
|
||||
def build_unit_class(registry):
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
from __future__ import division, unicode_literals, print_function, absolute_import
|
||||
|
||||
from decimal import Decimal
|
||||
import locale
|
||||
import sys
|
||||
import re
|
||||
|
Loading…
Reference in New Issue
Block a user