Implemented parameterized contexts

Transformation functions can include additional keyword arguments
to specify parameters. The values for these parameters can be specified
as default values for the context or given when entering the context.

See Issue #65
This commit is contained in:
Hernan Grecco
2013-10-09 19:58:51 -03:00
parent abf8de2dd3
commit 6163525abd
3 changed files with 385 additions and 59 deletions

View File

@@ -99,7 +99,7 @@ class ChainMap(MutableMapping):
__copy__ = copy
def new_child(self, m=None): # like Django's Context.push()
def new_child(self, m=None): # like Django's _Context.push()
'''
New ChainMap with a new map followed by all previous maps. If no
map is provided, an empty dict is used.
@@ -109,7 +109,7 @@ class ChainMap(MutableMapping):
return self.__class__(m, *self.maps)
@property
def parents(self): # like Django's Context.pop()
def parents(self): # like Django's _Context.pop()
'New ChainMap from maps[1:].'
return self.__class__(*self.maps[1:])

View File

@@ -3,23 +3,76 @@
from __future__ import division, unicode_literals, print_function, absolute_import
import unittest
from collections import defaultdict
from pint import UnitRegistry
from pint.unit import UnitsContainer, _freeze
from pint.unit import UnitsContainer, _freeze, _Context
def add_ctxs(ureg):
a, b = _freeze(UnitsContainer({'[length]': 1})), _freeze(UnitsContainer({'[time]': -1}))
d = {}
d[(a, b)] = lambda x: ureg.speed_of_light / x
d[(b, a)] = lambda x: ureg.speed_of_light / x
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[time]': -1})
d = _Context()
d.add_transformation(a, b, lambda x: ureg.speed_of_light / x)
d.add_transformation(b, a, lambda x: ureg.speed_of_light / x)
ureg._contexts['sp'] = d
a, b = _freeze(UnitsContainer({'[length]': 1})), _freeze(UnitsContainer({'[current]': -1}))
d = {}
d[(a, b)] = lambda x: 1 / x
d[(b, a)] = lambda x: 1 / x
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[current]': -1})
d = _Context()
d.add_transformation(a, b, lambda x: 1 / x)
d.add_transformation(b, a, lambda x: 1 / x)
ureg._contexts['ab'] = d
def add_arg_ctxs(ureg):
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[time]': -1})
d = _Context()
d.add_transformation(a, b, lambda x, n: ureg.speed_of_light / x / n)
d.add_transformation(b, a, lambda x, n: ureg.speed_of_light / x / n)
ureg._contexts['sp'] = d
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[current]': -1})
d = _Context()
d.add_transformation(a, b, lambda x: 1 / x)
d.add_transformation(b, a, lambda x: 1 / x)
ureg._contexts['ab'] = d
def add_argdef_ctxs(ureg):
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[time]': -1})
d = _Context(n=1)
assert d.defaults == dict(n=1)
d.add_transformation(a, b, lambda x, n: ureg.speed_of_light / x / n)
d.add_transformation(b, a, lambda x, n: ureg.speed_of_light / x / n)
ureg._contexts['sp'] = d
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[current]': -1})
d = _Context()
d.add_transformation(a, b, lambda x: 1 / x)
d.add_transformation(b, a, lambda x: 1 / x)
ureg._contexts['ab'] = d
def add_sharedargdef_ctxs(ureg):
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[time]': -1})
d = _Context(n=1)
assert d.defaults == dict(n=1)
d.add_transformation(a, b, lambda x, n: ureg.speed_of_light / x / n)
d.add_transformation(b, a, lambda x, n: ureg.speed_of_light / x / n)
ureg._contexts['sp'] = d
a, b = UnitsContainer({'[length]': 1}), UnitsContainer({'[current]': 1})
d = _Context(n=0)
d.add_transformation(a, b, lambda x, n: ureg.ampere * ureg.meter * n / x)
d.add_transformation(b, a, lambda x, n: ureg.ampere * ureg.meter * n / x)
ureg._contexts['ab'] = d
@@ -31,10 +84,50 @@ class TestContexts(unittest.TestCase):
add_ctxs(ureg)
with ureg.context('sp'):
self.assertTrue(ureg._active_ctx)
self.assertTrue(ureg._active_ctx_graph)
self.assertTrue(ureg._active_ctx.graph)
self.assertFalse(ureg._active_ctx)
self.assertFalse(ureg._active_ctx_graph)
self.assertFalse(ureg._active_ctx.graph)
def test_graph(self):
ureg = UnitRegistry()
add_ctxs(ureg)
l = _freeze({'[length]': 1.})
t = _freeze({'[time]': -1.})
c = _freeze({'[current]': -1.})
g_sp = defaultdict(set)
g_sp.update({l: {t, },
t: {l, }})
g_ab = defaultdict(set)
g_ab.update({l: {c, },
c: {l, }})
g = defaultdict(set)
g.update({l: {t, c},
t: {l, },
c: {l, }})
with ureg.context('sp'):
self.assertEqual(ureg._active_ctx.graph, g_sp)
with ureg.context('ab'):
self.assertEqual(ureg._active_ctx.graph, g_ab)
with ureg.context('sp'):
with ureg.context('ab'):
self.assertEqual(ureg._active_ctx.graph, g)
with ureg.context('ab'):
with ureg.context('sp'):
self.assertEqual(ureg._active_ctx.graph, g)
with ureg.context('sp', 'ab'):
self.assertEqual(ureg._active_ctx.graph, g)
with ureg.context('ab', 'sp'):
self.assertEqual(ureg._active_ctx.graph, g)
def test_known_nested_context(self):
ureg = UnitRegistry()
@@ -42,21 +135,21 @@ class TestContexts(unittest.TestCase):
with ureg.context('sp'):
x = dict(ureg._active_ctx)
y = dict(ureg._active_ctx_graph)
y = dict(ureg._active_ctx.graph)
self.assertTrue(ureg._active_ctx)
self.assertTrue(ureg._active_ctx_graph)
self.assertTrue(ureg._active_ctx.graph)
with ureg.context('ab'):
self.assertTrue(ureg._active_ctx)
self.assertTrue(ureg._active_ctx_graph)
self.assertTrue(ureg._active_ctx.graph)
self.assertNotEqual(x, ureg._active_ctx)
self.assertNotEqual(y, ureg._active_ctx_graph)
self.assertNotEqual(y, ureg._active_ctx.graph)
self.assertEqual(x, ureg._active_ctx)
self.assertEqual(y, ureg._active_ctx_graph)
self.assertEqual(y, ureg._active_ctx.graph)
self.assertFalse(ureg._active_ctx)
self.assertFalse(ureg._active_ctx_graph)
self.assertFalse(ureg._active_ctx.graph)
def test_unknown_context(self):
ureg = UnitRegistry()
@@ -70,7 +163,7 @@ class TestContexts(unittest.TestCase):
value = False
self.assertTrue(value)
self.assertFalse(ureg._active_ctx)
self.assertFalse(ureg._active_ctx_graph)
self.assertFalse(ureg._active_ctx.graph)
def test_unknown_nested_context(self):
ureg = UnitRegistry()
@@ -78,7 +171,7 @@ class TestContexts(unittest.TestCase):
with ureg.context('sp'):
x = dict(ureg._active_ctx)
y = dict(ureg._active_ctx_graph)
y = dict(ureg._active_ctx.graph)
try:
with ureg.context('la'):
pass
@@ -90,10 +183,10 @@ class TestContexts(unittest.TestCase):
self.assertTrue(value)
self.assertEqual(x, ureg._active_ctx)
self.assertEqual(y, ureg._active_ctx_graph)
self.assertEqual(y, ureg._active_ctx.graph)
self.assertFalse(ureg._active_ctx)
self.assertFalse(ureg._active_ctx_graph)
self.assertFalse(ureg._active_ctx.graph)
def test_one_context(self):
@@ -142,3 +235,105 @@ class TestContexts(unittest.TestCase):
with ureg.context('sp'):
self.assertEqual(q.to('Hz'), s)
self.assertRaises(ValueError, q.to, 'Hz')
def test_context_with_arg(self):
ureg = UnitRegistry()
add_arg_ctxs(ureg)
q = 500 * ureg.meter
s = (ureg.speed_of_light / q).to('Hz')
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp', n=1):
self.assertEqual(q.to('Hz'), s)
with ureg.context('ab'):
self.assertEqual(q.to('Hz'), s)
self.assertEqual(q.to('Hz'), s)
with ureg.context('ab'):
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp', n=1):
self.assertEqual(q.to('Hz'), s)
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp'):
self.assertRaises(TypeError, q.to, 'Hz')
def test_context_with_arg_def(self):
ureg = UnitRegistry()
add_argdef_ctxs(ureg)
q = 500 * ureg.meter
s = (ureg.speed_of_light / q).to('Hz')
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp'):
self.assertEqual(q.to('Hz'), s)
with ureg.context('ab'):
self.assertEqual(q.to('Hz'), s)
self.assertEqual(q.to('Hz'), s)
with ureg.context('ab'):
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp'):
self.assertEqual(q.to('Hz'), s)
self.assertRaises(ValueError, q.to, 'Hz')
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp', n=2):
self.assertEqual(q.to('Hz'), s / 2)
with ureg.context('ab'):
self.assertEqual(q.to('Hz'), s / 2)
self.assertEqual(q.to('Hz'), s / 2)
with ureg.context('ab'):
self.assertRaises(ValueError, q.to, 'Hz')
with ureg.context('sp', n=2):
self.assertEqual(q.to('Hz'), s / 2)
self.assertRaises(ValueError, q.to, 'Hz')
def test_context_with_sharedarg_def(self):
ureg = UnitRegistry()
add_sharedargdef_ctxs(ureg)
q = 500 * ureg.meter
s = (ureg.speed_of_light / q).to('Hz')
u = (1 / 500) * ureg.ampere
with ureg.context('sp'):
self.assertEqual(q.to('Hz'), s)
with ureg.context('ab'):
self.assertEqual(q.to('ampere'), u)
with ureg.context('ab'):
self.assertEqual(q.to('ampere'), 0 * u)
with ureg.context('sp'):
self.assertRaises(ZeroDivisionError, ureg.Quantity.to, q, 'Hz')
with ureg.context('sp', n=2):
self.assertEqual(q.to('Hz'), s / 2)
with ureg.context('ab'):
self.assertEqual(q.to('ampere'), 2 * u)
with ureg.context('ab', n=3):
self.assertEqual(q.to('ampere'), 3 * u)
with ureg.context('sp'):
self.assertEqual(q.to('Hz'), s / 3)
with ureg.context('sp', n=2):
self.assertEqual(q.to('Hz'), s / 2)
with ureg.context('ab', n=4):
self.assertEqual(q.to('ampere'), 4 * u)
with ureg.context('ab', n=3):
self.assertEqual(q.to('ampere'), 3 * u)
with ureg.context('sp', n=6):
self.assertEqual(q.to('Hz'), s / 6)

View File

@@ -16,9 +16,10 @@ import copy
import math
import itertools
import functools
import weakref
import pkg_resources
from decimal import Decimal
from collections import defaultdict
from collections import defaultdict, MutableMapping
from contextlib import contextmanager
from io import open
@@ -174,9 +175,135 @@ def _is_dim(name):
def _freeze(d):
"""Return a hashable view of dict.
"""
if isinstance(d, frozenset):
return d
return frozenset(d.items())
class _Context(object):
"""A specialized container that defines transformation functions from
one dimension to another. Each Dimension are specified using a UnitsContainer.
Simple transformation are given with a function taking a single parameter.
>>> timedim = UnitsContainer({'[time]': 1})
>>> spacedim = UnitsContainer({'[length]': 1})
>>> def f(time):
... 'Time to length converter'
... return 3. * time
>>> c = _Context()
>>> c.add_transformation(timedim, spacedim, f)
>>> c.transform(timedim, spacedim, 2)
6
Conversion functions may take optional keyword arguments and the context can
have default values for these arguments.
>>> def f(time, n):
... 'Time to length converter, n is the index of refraction of the material'
... return 3. * time / n
>>> c = _Context(n=3)
>>> c.add_transformation(timedim, spacedim, f)
>>> c.transform(timedim, spacedim, 2)
2
"""
def __init__(self, **defaults):
#: Maps (src, dst) -> transformation function
self.funcs = {}
#: Maps defaults variable names to values
self.defaults = defaults
#: Maps (src, dst) -> self
#: Used as a convenience dictionary to be composed by _ContextChain
self.refs_to_self = weakref.WeakValueDictionary()
@classmethod
def from_context(cls, context, **defaults):
"""Creates a new context that shares the funcs dictionary with the original
context. The default values are copied from the original context and updated
with the new defaults.
If defaults is empty, return the same context.
"""
if defaults:
newdef = dict(context.defaults, **defaults)
c = cls(**newdef)
c.funcs = context.funcs
for edge in context.funcs.keys():
c.refs_to_self[edge] = c
return c
return context
def add_transformation(self, src, dst, func):
"""Add a transformation function to the context.
"""
_key = self.__keytransform__(src, dst)
self.funcs[_key] = func
self.refs_to_self[_key] = self
def __keytransform__(self, src, dst):
return _freeze(src), _freeze(dst)
def transform(self, src, dst, value):
"""Transform a value.
"""
_key = self.__keytransform__(src, dst)
return self.funcs[_key](value, **self.defaults)
class _ContextChain(ChainMap):
"""A specialized ChainMap for contexts that simplifies finding rules
to transform from one dimension to another.
"""
def __init__(self, *args, **kwargs):
super(_ContextChain, self).__init__(*args, **kwargs)
self._graph = None
def insert_contexts(self, *contexts):
"""Insert one or more contexts in reversed order the chained map.
(A rule in last context will take precedence)
To facilitate the identification of the context with the matching rule,
the *refs_to_self* dictionary of the context is used.
"""
self.maps = [ctx.refs_to_self for ctx in reversed(contexts)] + self.maps
self._graph = None
def remove_contexts(self, n):
"""Remove the last n inserted contexts from the chain.
"""
self.maps = self.maps[n:]
self._graph = None
@property
def defaults(self):
if self:
return list(self.maps[0].values())[0].defaults
return {}
@property
def graph(self):
"""The graph relating
"""
if self._graph is None:
self._graph = defaultdict(set)
for fr_, to_ in self:
self._graph[fr_].add(to_)
return self._graph
def transform(self, src, dst, value):
"""Transform the value, finding the rule in the chained context.
(A rule in last context will take precedence)
:raises: KeyError if the rule is not found.
"""
return self[(src, dst)].transform(src, dst, value)
class PrefixDefinition(Definition):
"""Definition of a prefix.
"""
@@ -418,20 +545,11 @@ class UnitRegistry(object):
#: Map suffix name (string) to canonical , and unit alias to canonical unit name
self._suffixes = {'': None, 's': ''}
# A context defines transformation rules between base dimensions (e.g. time and length).
# Transformations are stored in a dict with:
# - key: tuple with source and destination dimensions represented
# as set of UnitContainer.items()
# - value: conversion function taking a single value.
#: Map context name (string) or abbreviation to context.
self._contexts = {}
#: Stores active contexts.
self._active_ctx = ChainMap()
#: Store a graph representation of the context.
self._active_ctx_graph = None
self._active_ctx = _ContextChain()
#: When performing a multiplication of units, interpret
#: non-multiplicative units as their *delta* counterparts.
@@ -459,34 +577,54 @@ class UnitRegistry(object):
'convert', 'get_base_units']
@contextmanager
def context(self, *names):
def context(self, *names, **kwargs):
"""Used as a context manager, this function enables to activate a context
which is removed after usage.
:param names: name of the context.
:param kwargs: keyword arguments for the
Multiple contexts can be called in single call or nested::
Context are called by their name::
>>> with ureg.context('one', 'two'):
... pass
>>> with ureg.context('one'):
... with ureg.context('two'):
... pass
If the context has an argument, you can specify it's value as a keyword
argument::
>>> with ureg.context('one', n=1):
... pass
Multiple contexts can be entered in single call:
>>> with ureg.context('one', 'two', n=1):
... pass
or nested allowing you to give different values to the same keyword argument::
>>> with ureg.context('one', n=1):
... with ureg.context('two', n=2):
... pass
A nested context inherits the defaults from the containing context::
>>> with ureg.context('one', n=1):
... with ureg.context('two'): # Here n takes the value of the upper context
... pass
"""
# For each name, we first find the corresponding context.
ctxs = tuple(self._contexts[name] for name in names)
# If present, copy the defaults from the containing contexts
if self._active_ctx.defaults:
kwargs = dict(self._active_ctx.defaults, **kwargs)
# For each name, we first find the corresponding context
# and create a new one with the new defaults.
ctxs = tuple(_Context.from_context(self._contexts[name], **kwargs)
for name in names)
# And then add them to the active context.
for ctx in ctxs:
self._active_ctx = self._active_ctx.new_child(ctx)
# The graph representing connections between dimensions is rebuilt
# from the connections (edges) stored in the context.
self._active_ctx_graph = defaultdict(list)
for fr_, to_ in self._active_ctx.keys():
self._active_ctx_graph[fr_].append(to_)
self._active_ctx.insert_contexts(*ctxs)
try:
# After adding the context and rebuilding the graph, the registry
@@ -495,14 +633,7 @@ class UnitRegistry(object):
finally:
# Upon leaving the with statement,
# the added contexts are removed from the active one.
for _ in names:
self._active_ctx = self._active_ctx.parents
# The graph representing connections between dimensions is rebuilt
# from the connections (edges) remaining in the context.
self._active_ctx_graph = defaultdict(list)
for fr_, to_ in self._active_ctx.keys():
self._active_ctx_graph[fr_].append(to_)
self._active_ctx.remove_contexts(len(names))
def define(self, definition):
"""Add unit to the registry.
@@ -749,11 +880,11 @@ class UnitRegistry(object):
# destination dimensionality. If it exists, we transform the source value
# by applying sequentially each transformation of the path.
if self._active_ctx:
path = find_shortest_path(self._active_ctx_graph, _freeze(src_dim), _freeze(dst_dim))
path = find_shortest_path(self._active_ctx.graph, _freeze(src_dim), _freeze(dst_dim))
if path:
src = self.Quantity(value, src)
for a, b in zip(path[:-1], path[1:]):
src = self._active_ctx[(a, b)](src)
src = self._active_ctx.transform(a, b, src)
value, src = src.magnitude, src.units