Allow array exponents (with len > 1) when base is dimensionless
Close #483
This commit is contained in:
		@@ -893,9 +893,21 @@ class _Quantity(SharedRegistryObject):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            if isinstance(getattr(other, '_magnitude', other), ndarray):
 | 
					            if isinstance(getattr(other, '_magnitude', other), ndarray):
 | 
				
			||||||
                # arrays are refused as exponent, because they would create
 | 
					                # arrays are refused as exponent, because they would create
 | 
				
			||||||
                #  len(array) quanitites of len(set(array)) different units
 | 
					                # len(array) quantities of len(set(array)) different units
 | 
				
			||||||
                if np.size(other) > 1:
 | 
					                # unless the base is dimensionless.
 | 
				
			||||||
                    raise DimensionalityError(self._units, 'dimensionless')
 | 
					                if self.dimensionless:
 | 
				
			||||||
 | 
					                    if getattr(other, 'dimensionless', False):
 | 
				
			||||||
 | 
					                        self._magnitude **= other.m_as('')
 | 
				
			||||||
 | 
					                        return self
 | 
				
			||||||
 | 
					                    elif not getattr(other, 'dimensionless', True):
 | 
				
			||||||
 | 
					                        raise DimensionalityError(other._units, 'dimensionless')
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        self._magnitude **= other
 | 
				
			||||||
 | 
					                        return self
 | 
				
			||||||
 | 
					                elif np.size(other) > 1:
 | 
				
			||||||
 | 
					                    raise DimensionalityError(self._units, 'dimensionless',
 | 
				
			||||||
 | 
					                                              extra_msg='Quantity array exponents are only allowed '
 | 
				
			||||||
 | 
					                                                        'if the base is dimensionless')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if other == 1:
 | 
					            if other == 1:
 | 
				
			||||||
                return self
 | 
					                return self
 | 
				
			||||||
@@ -931,8 +943,18 @@ class _Quantity(SharedRegistryObject):
 | 
				
			|||||||
            if isinstance(getattr(other, '_magnitude', other), ndarray):
 | 
					            if isinstance(getattr(other, '_magnitude', other), ndarray):
 | 
				
			||||||
                # arrays are refused as exponent, because they would create
 | 
					                # arrays are refused as exponent, because they would create
 | 
				
			||||||
                # len(array) quantities of len(set(array)) different units
 | 
					                # len(array) quantities of len(set(array)) different units
 | 
				
			||||||
                if np.size(other) > 1:
 | 
					                # unless the base is dimensionless.
 | 
				
			||||||
                    raise DimensionalityError(self._units, 'dimensionless')
 | 
					                if self.dimensionless:
 | 
				
			||||||
 | 
					                    if getattr(other, 'dimensionless', False):
 | 
				
			||||||
 | 
					                        return self.__class__(self.m ** other.m_as(''))
 | 
				
			||||||
 | 
					                    elif not getattr(other, 'dimensionless', True):
 | 
				
			||||||
 | 
					                        raise DimensionalityError(other._units, 'dimensionless')
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        return self.__class__(self.m ** other)
 | 
				
			||||||
 | 
					                elif np.size(other) > 1:
 | 
				
			||||||
 | 
					                    raise DimensionalityError(self._units, 'dimensionless',
 | 
				
			||||||
 | 
					                                              extra_msg='Quantity array exponents are only allowed '
 | 
				
			||||||
 | 
					                                                        'if the base is dimensionless')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            new_self = self
 | 
					            new_self = self
 | 
				
			||||||
            if other == 1:
 | 
					            if other == 1:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -547,3 +547,10 @@ class TestIssuesNP(QuantityTestCase):
 | 
				
			|||||||
        y = f(x)
 | 
					        y = f(x)
 | 
				
			||||||
        z = x * y
 | 
					        z = x * y
 | 
				
			||||||
        self.assertEquals(z, ureg.Quantity(1., 'meter * kilogram'))
 | 
					        self.assertEquals(z, ureg.Quantity(1., 'meter * kilogram'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_issue483(self):
 | 
				
			||||||
 | 
					        ureg = self.ureg
 | 
				
			||||||
 | 
					        a = np.asarray([1, 2, 3])
 | 
				
			||||||
 | 
					        q = [1, 2, 3] * ureg.dimensionless
 | 
				
			||||||
 | 
					        p = (q ** q).m
 | 
				
			||||||
 | 
					        np.testing.assert_array_equal(p, a ** a)
 | 
				
			||||||
@@ -427,7 +427,7 @@ class TestNDArrayQunatityMath(QuantityTestCase):
 | 
				
			|||||||
    @helpers.requires_numpy()
 | 
					    @helpers.requires_numpy()
 | 
				
			||||||
    def test_exponentiation_array_exp(self):
 | 
					    def test_exponentiation_array_exp(self):
 | 
				
			||||||
        arr = np.array(range(3), dtype=np.float)
 | 
					        arr = np.array(range(3), dtype=np.float)
 | 
				
			||||||
        q = self.Q_(arr, None)
 | 
					        q = self.Q_(arr, 'meter')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for op_ in [op.pow, op.ipow]:
 | 
					        for op_ in [op.pow, op.ipow]:
 | 
				
			||||||
            q_cp = copy.copy(q)
 | 
					            q_cp = copy.copy(q)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user