diff --git a/pint/unit.py b/pint/unit.py index 27f6b5e..9802822 100644 --- a/pint/unit.py +++ b/pint/unit.py @@ -936,18 +936,11 @@ class UnitRegistry(object): if check_nonmult and input_units in self._base_units_cache: return copy.deepcopy(self._base_units_cache[input_units]) - factor = 1. - units = UnitsContainer() - for key, value in input_units.items(): - key = self.get_name(key) - reg = self._units[key] - if reg.is_base: - units.add(key, value) - else: - fac, uni = self.get_base_units(reg.reference, check_nonmult=False) - if factor is not None: - factor *= (reg.converter.scale * fac) ** value - units *= uni ** value + accumulators = [1., defaultdict(lambda: 0.0)] + self._get_base_units(input_units, 1.0, accumulators) + + factor = accumulators[0] + units = UnitsContainer(dict((k, v) for k, v in accumulators[1].items() if v != 0.)) # Check if any of the final units is non multiplicative and return None instead. if check_nonmult: @@ -957,6 +950,18 @@ class UnitRegistry(object): return factor, units + def _get_base_units(self, ref, exp, accumulators): + for key, value in ref.items(): + key = self.get_name(key) + reg = self._units[key] + exp2 = exp*value + if reg.is_base: + accumulators[1][key] += exp2 + else: + accumulators[0] *= reg.converter.scale ** exp2 + if reg.reference != None: + self._get_base_units(reg.reference, exp2, accumulators) + def get_compatible_units(self, input_units): if not input_units: return 1., UnitsContainer()