Merge pull request #231 from quantus/feature/multi-column-observer

Feature/multi column observer
This commit is contained in:
Konsta Vesterinen
2016-07-17 22:58:55 +03:00
committed by GitHub
2 changed files with 151 additions and 47 deletions

View File

@@ -148,6 +148,29 @@ Category has many Products.
session.commit() session.commit()
catalog.product_count # 1 catalog.product_count # 1
Observing multiple columns
-----------------------
You can also observe multiple columns by spesifying all the observable columns
in the decorator.
::
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
total_price = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.total_price = amount * unit_price
""" """
import itertools import itertools
from collections import defaultdict, Iterable, namedtuple from collections import defaultdict, Iterable, namedtuple
@@ -158,7 +181,7 @@ from .functions import getdotattr, has_changes
from .path import AttrPath from .path import AttrPath
from .utils import is_sequence from .utils import is_sequence
Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) Callback = namedtuple('Callback', ['func', 'backref', 'fullpath'])
class PropertyObserver(object): class PropertyObserver(object):
@@ -208,16 +231,18 @@ class PropertyObserver(object):
) )
def gather_paths(self): def gather_paths(self):
for class_, callbacks in self.generator_registry.items(): for class_, generators in self.generator_registry.items():
for callback in callbacks: for callback in generators:
path = AttrPath(class_, callback.__observes__) full_paths = []
for call_path in callback.__observes__:
full_paths.append(AttrPath(class_, call_path))
for path in full_paths:
self.callback_map[class_].append( self.callback_map[class_].append(
Callback( Callback(
func=callback, func=callback,
path=path,
backref=None, backref=None,
fullpath=path fullpath=full_paths
) )
) )
@@ -229,9 +254,8 @@ class PropertyObserver(object):
self.callback_map[prop_class].append( self.callback_map[prop_class].append(
Callback( Callback(
func=callback, func=callback,
path=path[i:],
backref=~ (path[:i]), backref=~ (path[:i]),
fullpath=path fullpath=full_paths
) )
) )
@@ -252,12 +276,13 @@ class PropertyObserver(object):
def get_callback_args(self, root_obj, callback): def get_callback_args(self, root_obj, callback):
session = sa.orm.object_session(root_obj) session = sa.orm.object_session(root_obj)
objects = getdotattr( objects = [getdotattr(
root_obj, root_obj,
callback.fullpath, path,
lambda obj: obj not in session.deleted lambda obj: obj not in session.deleted
) ) for path in callback.fullpath]
path = str(callback.fullpath) paths = [str(path) for path in callback.fullpath]
for path in paths:
if '.' in path or has_changes(root_obj, path): if '.' in path or has_changes(root_obj, path):
return ( return (
root_obj, root_obj,
@@ -277,21 +302,25 @@ class PropertyObserver(object):
for obj, callbacks in self.iterate_objects_and_callbacks(session): for obj, callbacks in self.iterate_objects_and_callbacks(session):
args = self.gather_callback_args(obj, callbacks) args = self.gather_callback_args(obj, callbacks)
for (root_obj, func, objects) in args: for (root_obj, func, objects) in args:
if is_sequence(objects): if not callback_args[root_obj][func]:
callback_args[root_obj][func] = ( callback_args[root_obj][func] = {}
callback_args[root_obj][func] | set(objects) for i, object_ in enumerate(objects):
if is_sequence(object_):
callback_args[root_obj][func][i] = (
callback_args[root_obj][func].get(i, set()) |
set(object_)
) )
else: else:
callback_args[root_obj][func] = objects callback_args[root_obj][func][i] = object_
for root_obj, callback_objs in callback_args.items(): for root_obj, callback_objs in callback_args.items():
for callback, objs in callback_objs.items(): for callback, objs in callback_objs.items():
callback(root_obj, objs) callback(root_obj, *[objs[i] for i in range(len(objs))])
observer = PropertyObserver() observer = PropertyObserver()
def observes(path, observer=observer): def observes(*paths, **observer_kw):
""" """
Mark method as property observer for the given property path. Inside Mark method as property observer for the given property path. Inside
transaction observer gathers all changes made in given property path and transaction observer gathers all changes made in given property path and
@@ -327,14 +356,17 @@ def observes(path, observer=observer):
.. versionadded: 0.28.0 .. versionadded: 0.28.0
:param path: Dot-notated property path, eg. 'categories.products.price' :param *paths: One or more dot-notated property paths, eg.
:param observer: :meth:`PropertyObserver` object 'categories.products.price'
:param **observer: A dictionary where value for key 'observer' contains
:meth:`PropertyObserver` object
""" """
observer.register_listeners() observer_ = observer_kw.pop('observer', observer)
observer_.register_listeners()
def wraps(func): def wraps(func):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
wrapper.__observes__ = path wrapper.__observes__ = paths
return wrapper return wrapper
return wraps return wraps

View File

@@ -58,3 +58,75 @@ class TestObservesForColumnWithoutActualChanges(object):
product.price = 500 product.price = 500
session.commit() session.commit()
assert str(e.value) == 'Trying to change price' assert str(e.value) == 'Trying to change price'
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForMultipleColumns(object):
@pytest.fixture
def Order(self, Base):
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
total_price = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.total_price = amount * unit_price
return Order
@pytest.fixture
def init_models(self, Order):
pass
def test_only_notifies_observer_on_actual_changes(self, session, Order):
order = Order()
order.amount = 2
order.unit_price = 10
session.add(order)
session.flush()
order.amount = 1
session.flush()
assert order.total_price == 10
order.unit_price = 100
session.flush()
assert order.total_price == 100
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForMultipleColumnsFiresOnlyOnce(object):
@pytest.fixture
def Order(self, Base):
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.call_count = self.call_count + 1
return Order
@pytest.fixture
def init_models(self, Order):
pass
def test_only_notifies_observer_on_actual_changes(self, session, Order):
order = Order()
order.amount = 2
order.unit_price = 10
order.call_count = 0
session.add(order)
session.flush()
assert order.call_count == 1
order.amount = 1
order.unit_price = 100
session.flush()
assert order.call_count == 2