501: Fixes for issues #95, #468, #483, #462 r=hgrecco
This commit is contained in:
bors[bot]
2017-04-15 03:36:48 +00:00
5 changed files with 64 additions and 16 deletions

View File

@@ -893,9 +893,21 @@ class _Quantity(SharedRegistryObject):
if isinstance(getattr(other, '_magnitude', other), ndarray):
# arrays are refused as exponent, because they would create
# len(array) quanitites of len(set(array)) different units
if np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless')
# len(array) quantities of len(set(array)) different units
# unless the base is dimensionless.
if self.dimensionless:
if getattr(other, 'dimensionless', False):
self._magnitude **= other.m_as('')
return self
elif not getattr(other, 'dimensionless', True):
raise DimensionalityError(other._units, 'dimensionless')
else:
self._magnitude **= other
return self
elif np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless',
extra_msg='Quantity array exponents are only allowed '
'if the base is dimensionless')
if other == 1:
return self
@@ -930,9 +942,19 @@ class _Quantity(SharedRegistryObject):
if isinstance(getattr(other, '_magnitude', other), ndarray):
# arrays are refused as exponent, because they would create
# len(array) quantities of len(set(array)) different units
if np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless')
# len(array) quantities of len(set(array)) different units
# unless the base is dimensionless.
if self.dimensionless:
if getattr(other, 'dimensionless', False):
return self.__class__(self.m ** other.m_as(''))
elif not getattr(other, 'dimensionless', True):
raise DimensionalityError(other._units, 'dimensionless')
else:
return self.__class__(self.m ** other)
elif np.size(other) > 1:
raise DimensionalityError(self._units, 'dimensionless',
extra_msg='Quantity array exponents are only allowed '
'if the base is dimensionless')
new_self = self
if other == 1:
@@ -1052,6 +1074,7 @@ class _Quantity(SharedRegistryObject):
#: will set on output.
__set_units = {'cos': '', 'sin': '', 'tan': '',
'cosh': '', 'sinh': '', 'tanh': '',
'log': '', 'exp': '',
'arccos': __radian, 'arcsin': __radian,
'arctan': __radian, 'arctan2': __radian,
'arccosh': __radian, 'arcsinh': __radian,

View File

@@ -225,9 +225,10 @@ class BaseRegistry(meta.with_metaclass(_Meta)):
"""
if isinstance(definition, string_types):
definition = Definition.from_string(definition)
self._define(definition)
for line in definition.split('\n'):
self._define(Definition.from_string(line))
else:
self._define(definition)
def _define(self, definition):
"""Add unit to the registry.

View File

@@ -29,7 +29,7 @@ def _replace_units(original_units, values_by_name):
return getattr(q, "_units", UnitsContainer({}))
def _to_units_container(a):
def _to_units_container(a, registry=None):
"""Convert a unit compatible type to a UnitsContainer,
checking if it is string field prefixed with an equal
(which is considered a reference)
@@ -38,10 +38,10 @@ def _to_units_container(a):
"""
if isinstance(a, string_types) and '=' in a:
return to_units_container(a.split('=', 1)[1]), True
return to_units_container(a), False
return to_units_container(a, registry), False
def _parse_wrap_args(args):
def _parse_wrap_args(args, registry=None):
# Arguments which contain definitions
# (i.e. names that appear alone and for the first time)
@@ -55,7 +55,7 @@ def _parse_wrap_args(args):
unit_args_ndx = set()
# _to_units_container
args_as_uc = [_to_units_container(arg) for arg in args]
args_as_uc = [_to_units_container(arg, registry) for arg in args]
# Check for references in args, remove None values
for ndx, (arg, is_ref) in enumerate(args_as_uc):
@@ -154,9 +154,9 @@ def wraps(ureg, ret, args, strict=True):
converter = _parse_wrap_args(args)
if isinstance(ret, (list, tuple)):
container, ret = True, ret.__class__([_to_units_container(arg) for arg in ret])
container, ret = True, ret.__class__([_to_units_container(arg, ureg) for arg in ret])
else:
container, ret = False, _to_units_container(ret)
container, ret = False, _to_units_container(ret, ureg)
def decorator(func):
assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))

View File

@@ -530,3 +530,27 @@ class TestIssuesNP(QuantityTestCase):
'1 count')
self.assertEqual('{0:~}'.format(1 * self.ureg('MiB')),
'1 MiB')
def test_issue482(self):
q = self.ureg.Quantity(1, self.ureg.dimensionless)
qe = np.exp(q)
self.assertIsInstance(qe, self.ureg.Quantity)
def test_issue468(self):
ureg = UnitRegistry()
@ureg.wraps(('kg'), 'meter')
def f(x):
return x
x = ureg.Quantity(1., 'meter')
y = f(x)
z = x * y
self.assertEquals(z, ureg.Quantity(1., 'meter * kilogram'))
def test_issue483(self):
ureg = self.ureg
a = np.asarray([1, 2, 3])
q = [1, 2, 3] * ureg.dimensionless
p = (q ** q).m
np.testing.assert_array_equal(p, a ** a)

View File

@@ -427,7 +427,7 @@ class TestNDArrayQunatityMath(QuantityTestCase):
@helpers.requires_numpy()
def test_exponentiation_array_exp(self):
arr = np.array(range(3), dtype=np.float)
q = self.Q_(arr, None)
q = self.Q_(arr, 'meter')
for op_ in [op.pow, op.ipow]:
q_cp = copy.copy(q)