From ae14380e21ce31b0b3408f357dc75c75a3e95be9 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 26 Oct 2015 14:16:58 +0200 Subject: [PATCH] Fix observer notification on column properties, fixes #138 --- sqlalchemy_utils/__init__.py | 2 +- sqlalchemy_utils/observer.py | 34 +++++++++++++++----------- tests/observes/test_column_property.py | 27 ++++++++++++++++++++ 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index de30ed3..9e2c582 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -93,4 +93,4 @@ from .types import ( # noqa WeekDaysType ) -__version__ = '0.31.0' +__version__ = '0.31.1' diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py index c109480..bd7c1de 100644 --- a/sqlalchemy_utils/observer.py +++ b/sqlalchemy_utils/observer.py @@ -154,7 +154,7 @@ from collections import defaultdict, Iterable, namedtuple import sqlalchemy as sa -from sqlalchemy_utils.functions import getdotattr +from sqlalchemy_utils.functions import getdotattr, has_changes from sqlalchemy_utils.path import AttrPath from sqlalchemy_utils.utils import is_sequence @@ -236,7 +236,6 @@ class PropertyObserver(object): ) def gather_callback_args(self, obj, callbacks): - session = sa.orm.object_session(obj) for callback in callbacks: backref = callback.backref @@ -246,19 +245,26 @@ class PropertyObserver(object): root_objs = [root_objs] for root_obj in root_objs: - objects = getdotattr( - root_obj, - callback.fullpath, - lambda obj: obj not in session.deleted - ) + args = self.get_callback_args(root_obj, callback) + if args: + yield args - yield ( - root_obj, - callback.func, - objects - ) + def get_callback_args(self, root_obj, callback): + session = sa.orm.object_session(root_obj) + objects = getdotattr( + root_obj, + callback.fullpath, + lambda obj: obj not in session.deleted + ) + path = str(callback.fullpath) + if '.' in path or has_changes(root_obj, path): + return ( + root_obj, + callback.func, + objects + ) - def changed_objects(self, session): + def iterate_objects_and_callbacks(self, session): objs = itertools.chain(session.new, session.dirty, session.deleted) for obj in objs: for class_, callbacks in self.callback_map.items(): @@ -267,7 +273,7 @@ class PropertyObserver(object): def invoke_callbacks(self, session, ctx, instances): callback_args = defaultdict(lambda: defaultdict(set)) - for obj, callbacks in self.changed_objects(session): + for obj, callbacks in self.iterate_objects_and_callbacks(session): args = self.gather_callback_args(obj, callbacks) for (root_obj, func, objects) in args: if is_sequence(objects): diff --git a/tests/observes/test_column_property.py b/tests/observes/test_column_property.py index 8339f4b..058383a 100644 --- a/tests/observes/test_column_property.py +++ b/tests/observes/test_column_property.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +from pytest import raises from sqlalchemy_utils.observer import observes from tests import TestCase @@ -24,3 +25,29 @@ class TestObservesForColumn(TestCase): self.session.add(product) self.session.flush() assert product.price == 200 + + +class TestObservesForColumnWithoutActualChanges(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + @observes('price') + def product_price_observer(self, price): + raise Exception('Trying to change price') + + self.Product = Product + + def test_only_notifies_observer_on_actual_changes(self): + product = self.Product() + self.session.add(product) + self.session.flush() + + with raises(Exception) as e: + product.price = 500 + self.session.commit() + assert str(e.value) == 'Trying to change price'