Fix observer notification on column properties, fixes #138

This commit is contained in:
Konsta Vesterinen
2015-10-26 14:16:58 +02:00
parent 5207b40e67
commit ae14380e21
3 changed files with 48 additions and 15 deletions

View File

@@ -93,4 +93,4 @@ from .types import ( # noqa
WeekDaysType
)
__version__ = '0.31.0'
__version__ = '0.31.1'

View File

@@ -154,7 +154,7 @@ from collections import defaultdict, Iterable, namedtuple
import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr
from sqlalchemy_utils.functions import getdotattr, has_changes
from sqlalchemy_utils.path import AttrPath
from sqlalchemy_utils.utils import is_sequence
@@ -236,7 +236,6 @@ class PropertyObserver(object):
)
def gather_callback_args(self, obj, callbacks):
session = sa.orm.object_session(obj)
for callback in callbacks:
backref = callback.backref
@@ -246,19 +245,26 @@ class PropertyObserver(object):
root_objs = [root_objs]
for root_obj in root_objs:
objects = getdotattr(
root_obj,
callback.fullpath,
lambda obj: obj not in session.deleted
)
args = self.get_callback_args(root_obj, callback)
if args:
yield args
yield (
root_obj,
callback.func,
objects
)
def get_callback_args(self, root_obj, callback):
session = sa.orm.object_session(root_obj)
objects = getdotattr(
root_obj,
callback.fullpath,
lambda obj: obj not in session.deleted
)
path = str(callback.fullpath)
if '.' in path or has_changes(root_obj, path):
return (
root_obj,
callback.func,
objects
)
def changed_objects(self, session):
def iterate_objects_and_callbacks(self, session):
objs = itertools.chain(session.new, session.dirty, session.deleted)
for obj in objs:
for class_, callbacks in self.callback_map.items():
@@ -267,7 +273,7 @@ class PropertyObserver(object):
def invoke_callbacks(self, session, ctx, instances):
callback_args = defaultdict(lambda: defaultdict(set))
for obj, callbacks in self.changed_objects(session):
for obj, callbacks in self.iterate_objects_and_callbacks(session):
args = self.gather_callback_args(obj, callbacks)
for (root_obj, func, objects) in args:
if is_sequence(objects):

View File

@@ -1,4 +1,5 @@
import sqlalchemy as sa
from pytest import raises
from sqlalchemy_utils.observer import observes
from tests import TestCase
@@ -24,3 +25,29 @@ class TestObservesForColumn(TestCase):
self.session.add(product)
self.session.flush()
assert product.price == 200
class TestObservesForColumnWithoutActualChanges(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
@observes('price')
def product_price_observer(self, price):
raise Exception('Trying to change price')
self.Product = Product
def test_only_notifies_observer_on_actual_changes(self):
product = self.Product()
self.session.add(product)
self.session.flush()
with raises(Exception) as e:
product.price = 500
self.session.commit()
assert str(e.value) == 'Trying to change price'