533: Bug fixing and some Refactoring r=hgrecco
This commit is contained in:
bors[bot] 2017-06-28 01:15:41 +00:00
commit 2906e54141
4 changed files with 69 additions and 53 deletions

View File

@ -910,6 +910,33 @@ class NonMultiplicativeRegistry(BaseRegistry):
except KeyError:
raise UndefinedUnitError(u)
def _validate_and_extract(self, units):
nonmult_units = [(u, e) for u, e in units.items()
if not self._is_multiplicative(u)]
# Let's validate source offset units
if len(nonmult_units) > 1:
# More than one src offset unit is not allowed
raise ValueError('more than one offset unit.')
elif len(nonmult_units) == 1:
# A single src offset unit is present. Extract it
# But check that:
# - the exponent is 1
# - is not used in multiplicative context
nonmult_unit, exponent = nonmult_units.pop()
if exponent != 1:
raise ValueError('offset units in higher order.')
if len(units) > 1 and not self.autoconvert_offset_to_baseunit:
raise ValueError('offset unit used in multiplicative context.')
return nonmult_unit
return None
def _convert(self, value, src, dst, inplace=False):
"""Convert value from some source to destination units.
@ -927,13 +954,19 @@ class NonMultiplicativeRegistry(BaseRegistry):
# Conversion needs to consider if non-multiplicative (AKA offset
# units) are involved. Conversion is only possible if src and dst
# have at most one offset unit per dimension.
src_offset_units = [(u, e) for u, e in src.items()
if not self._is_multiplicative(u)]
dst_offset_units = [(u, e) for u, e in dst.items()
if not self._is_multiplicative(u)]
# have at most one offset unit per dimension. Other rules are applied
# by validate and extract.
try:
src_offset_unit = self._validate_and_extract(src)
except ValueError as ex:
raise DimensionalityError(src, dst, extra_msg=' - In source units, %s ' % ex)
if not (src_offset_units or dst_offset_units):
try:
dst_offset_unit = self._validate_and_extract(dst)
except ValueError as ex:
raise DimensionalityError(src, dst, extra_msg=' - In destination units, %s ' % ex)
if not (src_offset_unit or dst_offset_unit):
return super(NonMultiplicativeRegistry, self)._convert(value, src, dst, inplace)
src_dim = self._get_dimensionality(src)
@ -944,53 +977,21 @@ class NonMultiplicativeRegistry(BaseRegistry):
if src_dim != dst_dim:
raise DimensionalityError(src, dst, src_dim, dst_dim)
# For offset units we need to check if the conversion is allowed.
if src_offset_units or dst_offset_units:
# Validate that not more than one offset unit is present
if len(src_offset_units) > 1 or len(dst_offset_units) > 1:
raise DimensionalityError(
src, dst, src_dim, dst_dim,
extra_msg=' - more than one offset unit.')
# validate that offset unit is not used in multiplicative context
if ((len(src_offset_units) == 1 and len(src) > 1)
or (len(dst_offset_units) == 1 and len(dst) > 1)
and not self.autoconvert_offset_to_baseunit):
raise DimensionalityError(
src, dst, src_dim, dst_dim,
extra_msg=' - offset unit used in multiplicative context.')
# Validate that order of offset unit is exactly one.
if src_offset_units:
if src_offset_units[0][1] != 1:
raise DimensionalityError(
src, dst, src_dim, dst_dim,
extra_msg=' - offset units in higher order.')
else:
if dst_offset_units[0][1] != 1:
raise DimensionalityError(
src, dst, src_dim, dst_dim,
extra_msg=' - offset units in higher order.')
# Here we convert only the offset quantities. Any remaining scaled
# quantities will be converted later.
# TODO: Shouldn't this (until factor, units) be inside the If above?
# clean src from offset units by converting to reference
for u, e in src_offset_units:
value = self._units[u].converter.to_reference(value, inplace)
src = src.remove([u for u, e in src_offset_units])
if src_offset_unit:
value = self._units[src_offset_unit].converter.to_reference(value, inplace)
src = src.remove([src_offset_unit])
# clean dst units from offset units
dst = dst.remove([u for u, e in dst_offset_units])
dst = dst.remove([dst_offset_unit])
# Convert non multiplicative units to the dst.
value = super(NonMultiplicativeRegistry, self)._convert(value, src, dst, inplace, False)
# Finally convert to offset units specified in destination
for u, e in dst_offset_units:
value = self._units[u].converter.from_reference(value, inplace)
if dst_offset_unit:
value = self._units[dst_offset_unit].converter.from_reference(value, inplace)
return value

View File

@ -201,7 +201,7 @@ def check(ureg, *args):
:raises:
:class:`DimensionalityError` if the parameters don't match dimensions
"""
dimensions = [ureg.get_dimensionality(dim) for dim in args]
dimensions = [ureg.get_dimensionality(dim) if dim is not None else None for dim in args]
def decorator(func):
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
@ -210,9 +210,14 @@ def check(ureg, *args):
@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*values, **kwargs):
for dim, value in zip_longest(dimensions, values):
if dim and value.dimensionality != dim:
if dim is None:
continue
val_dim = ureg.get_dimensionality(value)
if val_dim != dim:
raise DimensionalityError(value, 'a quantity of',
value.dimensionality, dim)
val_dim, dim)
return func(*values, **kwargs)
return wrapper
return decorator

View File

@ -561,3 +561,13 @@ class TestIssuesNP(QuantityTestCase):
q = [1, 2, 3] * ureg.dimensionless
p = (q ** q).m
np.testing.assert_array_equal(p, a ** a)
def test_issue532(self):
ureg = self.ureg
@ureg.check(ureg(''))
def f(x):
return 2 * x
self.assertEqual(f(ureg.Quantity(1, '')), 2)
self.assertRaises(DimensionalityError, f, ureg.Quantity(1, 'm'))

View File

@ -391,12 +391,12 @@ class TestRegistry(QuantityTestCase):
ureg = self.ureg
f0 = ureg.check('[length]')(func)
self.assertRaises(AttributeError, f0, 3.)
self.assertRaises(DimensionalityError, f0, 3.)
self.assertEqual(f0(3. * ureg.centimeter), 0.03 * ureg.meter)
self.assertRaises(DimensionalityError, f0, 3. * ureg.kilogram)
f0b = ureg.check(ureg.meter)(func)
self.assertRaises(AttributeError, f0b, 3.)
self.assertRaises(DimensionalityError, f0b, 3.)
self.assertEqual(f0b(3. * ureg.centimeter), 0.03 * ureg.meter)
self.assertRaises(DimensionalityError, f0b, 3. * ureg.kilogram)
@ -408,13 +408,13 @@ class TestRegistry(QuantityTestCase):
self.assertEqual(g0(6 * ureg.parsec, 2), 3 * ureg.parsec)
g1 = ureg.check('[speed]', '[time]')(gfunc)
self.assertRaises(AttributeError, g1, 3.0, 1)
self.assertRaises(DimensionalityError, g1, 3.0, 1)
self.assertRaises(DimensionalityError, g1, 1 * ureg.parsec, 1 * ureg.angstrom)
self.assertRaises(TypeError, g1, 1 * ureg.km / ureg.hour, 1 * ureg.hour, 3.0)
self.assertEqual(g1(3.6 * ureg.km / ureg.hour, 1 * ureg.second), 1 * ureg.meter / ureg.second ** 2)
g2 = ureg.check('[speed]')(gfunc)
self.assertRaises(AttributeError, g2, 3.0, 1)
self.assertRaises(DimensionalityError, g2, 3.0, 1)
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec)
self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec, 1.0)
self.assertEqual(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour)