Fix observer notification on column properties, fixes #138
This commit is contained in:
@@ -93,4 +93,4 @@ from .types import ( # noqa
|
|||||||
WeekDaysType
|
WeekDaysType
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = '0.31.0'
|
__version__ = '0.31.1'
|
||||||
|
@@ -154,7 +154,7 @@ from collections import defaultdict, Iterable, namedtuple
|
|||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from sqlalchemy_utils.functions import getdotattr
|
from sqlalchemy_utils.functions import getdotattr, has_changes
|
||||||
from sqlalchemy_utils.path import AttrPath
|
from sqlalchemy_utils.path import AttrPath
|
||||||
from sqlalchemy_utils.utils import is_sequence
|
from sqlalchemy_utils.utils import is_sequence
|
||||||
|
|
||||||
@@ -236,7 +236,6 @@ class PropertyObserver(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def gather_callback_args(self, obj, callbacks):
|
def gather_callback_args(self, obj, callbacks):
|
||||||
session = sa.orm.object_session(obj)
|
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
backref = callback.backref
|
backref = callback.backref
|
||||||
|
|
||||||
@@ -246,19 +245,26 @@ class PropertyObserver(object):
|
|||||||
root_objs = [root_objs]
|
root_objs = [root_objs]
|
||||||
|
|
||||||
for root_obj in root_objs:
|
for root_obj in root_objs:
|
||||||
objects = getdotattr(
|
args = self.get_callback_args(root_obj, callback)
|
||||||
root_obj,
|
if args:
|
||||||
callback.fullpath,
|
yield args
|
||||||
lambda obj: obj not in session.deleted
|
|
||||||
)
|
|
||||||
|
|
||||||
yield (
|
def get_callback_args(self, root_obj, callback):
|
||||||
root_obj,
|
session = sa.orm.object_session(root_obj)
|
||||||
callback.func,
|
objects = getdotattr(
|
||||||
objects
|
root_obj,
|
||||||
)
|
callback.fullpath,
|
||||||
|
lambda obj: obj not in session.deleted
|
||||||
|
)
|
||||||
|
path = str(callback.fullpath)
|
||||||
|
if '.' in path or has_changes(root_obj, path):
|
||||||
|
return (
|
||||||
|
root_obj,
|
||||||
|
callback.func,
|
||||||
|
objects
|
||||||
|
)
|
||||||
|
|
||||||
def changed_objects(self, session):
|
def iterate_objects_and_callbacks(self, session):
|
||||||
objs = itertools.chain(session.new, session.dirty, session.deleted)
|
objs = itertools.chain(session.new, session.dirty, session.deleted)
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
for class_, callbacks in self.callback_map.items():
|
for class_, callbacks in self.callback_map.items():
|
||||||
@@ -267,7 +273,7 @@ class PropertyObserver(object):
|
|||||||
|
|
||||||
def invoke_callbacks(self, session, ctx, instances):
|
def invoke_callbacks(self, session, ctx, instances):
|
||||||
callback_args = defaultdict(lambda: defaultdict(set))
|
callback_args = defaultdict(lambda: defaultdict(set))
|
||||||
for obj, callbacks in self.changed_objects(session):
|
for obj, callbacks in self.iterate_objects_and_callbacks(session):
|
||||||
args = self.gather_callback_args(obj, callbacks)
|
args = self.gather_callback_args(obj, callbacks)
|
||||||
for (root_obj, func, objects) in args:
|
for (root_obj, func, objects) in args:
|
||||||
if is_sequence(objects):
|
if is_sequence(objects):
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from pytest import raises
|
||||||
|
|
||||||
from sqlalchemy_utils.observer import observes
|
from sqlalchemy_utils.observer import observes
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
@@ -24,3 +25,29 @@ class TestObservesForColumn(TestCase):
|
|||||||
self.session.add(product)
|
self.session.add(product)
|
||||||
self.session.flush()
|
self.session.flush()
|
||||||
assert product.price == 200
|
assert product.price == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestObservesForColumnWithoutActualChanges(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
price = sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
@observes('price')
|
||||||
|
def product_price_observer(self, price):
|
||||||
|
raise Exception('Trying to change price')
|
||||||
|
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_only_notifies_observer_on_actual_changes(self):
|
||||||
|
product = self.Product()
|
||||||
|
self.session.add(product)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
with raises(Exception) as e:
|
||||||
|
product.price = 500
|
||||||
|
self.session.commit()
|
||||||
|
assert str(e.value) == 'Trying to change price'
|
||||||
|
Reference in New Issue
Block a user