From 2c22e69edcec6252a18f05ce9ff9ef3fc4f0a7e9 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 10:11:19 +0200 Subject: [PATCH 1/8] Add condition parameter for getdotattr --- sqlalchemy_utils/decorators.py | 1 - sqlalchemy_utils/functions/orm.py | 30 ++++++++++++++++++++++++------ sqlalchemy_utils/utils.py | 9 +++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index af957b5..9a8a071 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -3,7 +3,6 @@ import itertools import sqlalchemy as sa import six from .functions import getdotattr -from .path import AttrPath class AttributeValueGenerator(object): diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 86ba278..12cb752 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -16,6 +16,7 @@ from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp +from sqlalchemy_utils.utils import is_sequence def get_column_key(model, column): @@ -623,7 +624,7 @@ def get_declarative_base(model): return model -def getdotattr(obj_or_class, dot_path): +def getdotattr(obj_or_class, dot_path, condition=None): """ Allow dot-notated strings to be passed to `getattr`. @@ -638,22 +639,39 @@ def getdotattr(obj_or_class, dot_path): :param dot_path: Attribute path with dot mark as separator """ last = obj_or_class - # Coerce object style paths to strings. - path = str(dot_path) - for path in dot_path.split('.'): + for path in str(dot_path).split('.'): getter = attrgetter(path) - if isinstance(last, list): - last = sum((getter(el) for el in last), []) + + if is_sequence(last): + tmp = [] + for element in last: + value = getter(element) + if is_sequence(value): + tmp.extend(value) + else: + tmp.append(value) + last = tmp elif isinstance(last, InstrumentedAttribute): last = getter(last.property.mapper.class_) elif last is None: return None else: last = getter(last) + if condition is not None: + if is_sequence(last): + last = [v for v in last if condition(v)] + else: + if not condition(last): + return None + return last +def is_deleted(obj): + return obj in sa.orm.object_session(obj).deleted + + def has_changes(obj, attrs=None, exclude=None): """ Simple shortcut function for checking if given attributes of given diff --git a/sqlalchemy_utils/utils.py b/sqlalchemy_utils/utils.py index 4efed7b..8c5b6b4 100644 --- a/sqlalchemy_utils/utils.py +++ b/sqlalchemy_utils/utils.py @@ -1,4 +1,7 @@ import sys +from collections import Iterable + +import six def str_coercible(cls): @@ -11,3 +14,9 @@ def str_coercible(cls): cls.__str__ = __str__ return cls + + +def is_sequence(value): + return ( + isinstance(value, Iterable) and not isinstance(value, six.string_types) + ) From a30b42421da68ba599291cac0b83b798b0a85c93 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 10:13:48 +0200 Subject: [PATCH 2/8] Add direction and uselist properties --- sqlalchemy_utils/path.py | 16 +++++++++++++++- tests/test_path.py | 14 +++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sqlalchemy_utils/path.py b/sqlalchemy_utils/path.py index 96f274b..aaa812e 100644 --- a/sqlalchemy_utils/path.py +++ b/sqlalchemy_utils/path.py @@ -1,5 +1,6 @@ import sqlalchemy as sa from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.util.langhelpers import symbol from .utils import str_coercible @@ -105,9 +106,22 @@ class AttrPath(object): if el is element: return index + @property + def direction(self): + symbols = [part.property.direction for part in self.parts] + if symbol('MANYTOMANY') in symbols: + return symbol('MANYTOMANY') + elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols: + return symbol('MANYTOMANY') + return symbols[0] + + @property + def uselist(self): + return any(part.property.uselist for part in self.parts) + def __getitem__(self, slice): result = self.parts[slice] - if isinstance(result, list): + if isinstance(result, list) and result: if result[0] is self.parts[0]: class_ = self.class_ else: diff --git a/tests/test_path.py b/tests/test_path.py index bc77e77..3246fa1 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -1,6 +1,7 @@ import six -from pytest import mark import sqlalchemy as sa +from pytest import mark +from sqlalchemy.util.langhelpers import symbol from sqlalchemy_utils.path import Path, AttrPath from tests import TestCase @@ -41,6 +42,17 @@ class TestAttrPath(TestCase): self.Section = Section self.SubSection = SubSection + @mark.parametrize( + ('class_', 'path', 'direction'), + ( + ('SubSection', 'section', symbol('MANYTOONE')), + ) + ) + def test_direction(self, class_, path, direction): + assert ( + AttrPath(getattr(self, class_), path).direction == direction + ) + def test_invert(self): path = ~ AttrPath(self.SubSection, 'section.document') assert path.parts == [ From 1045402b896f5f767ccc233964d77d3f60ba4d7d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 10:18:55 +0200 Subject: [PATCH 3/8] Add observes decorator --- sqlalchemy_utils/observer.py | 139 +++++++++++++++++++++++++++++ tests/observes/__init__.py | 0 tests/observes/test_m2m_m2m_m2m.py | 137 ++++++++++++++++++++++++++++ tests/observes/test_o2m_o2m_o2m.py | 107 ++++++++++++++++++++++ tests/observes/test_o2m_o2o_o2m.py | 95 ++++++++++++++++++++ tests/observes/test_o2o_o2o_o2o.py | 83 +++++++++++++++++ 6 files changed, 561 insertions(+) create mode 100644 sqlalchemy_utils/observer.py create mode 100644 tests/observes/__init__.py create mode 100644 tests/observes/test_m2m_m2m_m2m.py create mode 100644 tests/observes/test_o2m_o2m_o2m.py create mode 100644 tests/observes/test_o2m_o2o_o2m.py create mode 100644 tests/observes/test_o2o_o2o_o2o.py diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py new file mode 100644 index 0000000..a6bf0b7 --- /dev/null +++ b/sqlalchemy_utils/observer.py @@ -0,0 +1,139 @@ +import sqlalchemy as sa + +from collections import defaultdict, namedtuple, Iterable +import itertools +from sqlalchemy_utils.functions import getdotattr +from sqlalchemy_utils.path import AttrPath +from sqlalchemy_utils.utils import is_sequence + + +Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) + + +class PropertyObserver(object): + def __init__(self): + self.listener_args = [ + ( + sa.orm.mapper, + 'mapper_configured', + self.update_generator_registry + ), + ( + sa.orm.mapper, + 'after_configured', + self.gather_paths + ), + ( + sa.orm.session.Session, + 'before_flush', + self.invoke_callbacks + ) + ] + self.callback_map = defaultdict(list) + # TODO: make the registry a WeakKey dict + self.generator_registry = defaultdict(list) + + def remove_listeners(self): + for args in self.listener_args: + sa.event.remove(*args) + + def register_listeners(self): + for args in self.listener_args: + if not sa.event.contains(*args): + sa.event.listen(*args) + + def update_generator_registry(self, mapper, class_): + """ + Adds generator functions to generator_registry. + """ + + for generator in class_.__dict__.values(): + if hasattr(generator, '__observes__'): + self.generator_registry[class_].append( + generator + ) + + def gather_paths(self): + for class_, callbacks in self.generator_registry.items(): + for callback in callbacks: + path = AttrPath(class_, callback.__observes__) + + self.callback_map[class_].append( + Callback( + func=callback, + path=path, + backref=None, + fullpath=path + ) + ) + + for index in range(len(path)): + i = index + 1 + prop_class = path[index].property.mapper.class_ + self.callback_map[prop_class].append( + Callback( + func=callback, + path=path[i:], + backref=~ (path[:i]), + fullpath=path + ) + ) + + def gather_callback_args(self, obj, callbacks): + session = sa.orm.object_session(obj) + for callback in callbacks: + backref = callback.backref + + root_objs = getdotattr(obj, backref) if backref else obj + if root_objs: + if not isinstance(root_objs, Iterable): + root_objs = [root_objs] + + for root_obj in root_objs: + objects = getdotattr( + root_obj, + callback.fullpath, + lambda obj: obj not in session.deleted + ) + + yield ( + root_obj, + callback.func, + objects + ) + + def changed_objects(self, session): + objs = itertools.chain(session.new, session.dirty, session.deleted) + for obj in objs: + for class_, callbacks in self.callback_map.items(): + if isinstance(obj, class_): + yield obj, callbacks + + def invoke_callbacks(self, session, ctx, instances): + callback_args = defaultdict(lambda: defaultdict(set)) + for obj, callbacks in self.changed_objects(session): + args = self.gather_callback_args(obj, callbacks) + for (root_obj, func, objects) in args: + if is_sequence(objects): + callback_args[root_obj][func] = ( + callback_args[root_obj][func] | set(objects) + ) + else: + callback_args[root_obj][func] = objects + + for root_obj, callback_objs in callback_args.items(): + for callback, objs in callback_objs.items(): + callback(root_obj, objs) + +observer = PropertyObserver() + + +def observes(path, observer=observer): + observer.register_listeners() + + def wraps(func): + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + wrapper.__observes__ = path + return wrapper + return wraps diff --git a/tests/observes/__init__.py b/tests/observes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/observes/test_m2m_m2m_m2m.py b/tests/observes/test_m2m_m2m_m2m.py new file mode 100644 index 0000000..3b416f2 --- /dev/null +++ b/tests/observes/test_m2m_m2m_m2m.py @@ -0,0 +1,137 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForManyToManyToManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + catalog_category = sa.Table( + 'catalog_category', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + ) + + category_subcategory = sa.Table( + 'category_subcategory', + self.Base.metadata, + sa.Column( + 'category_id', + sa.Integer, + sa.ForeignKey('category.id') + ), + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ) + ) + + subcategory_product = sa.Table( + 'subcategory_product', + self.Base.metadata, + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ), + sa.Column( + 'product_id', + sa.Integer, + sa.ForeignKey('product.id') + ) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship( + 'Category', + backref='catalogs', + secondary=catalog_category + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + 'SubCategory', + backref='categories', + secondary=category_subcategory + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + products = sa.orm.relationship( + 'Product', + backref='sub_categories', + secondary=subcategory_product + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + self.session.delete(product) + self.session.flush() + assert catalog.product_count == 1 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_categories[0]) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalogs=[catalog]) + product = self.Product() + category.sub_categories.append( + self.SubCategory(products=[product]) + ) + self.session.add( + self.SubCategory(categories=[category], products=[product]) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/tests/observes/test_o2m_o2m_o2m.py b/tests/observes/test_o2m_o2m_o2m.py new file mode 100644 index 0000000..9656141 --- /dev/null +++ b/tests/observes/test_o2m_o2m_o2m.py @@ -0,0 +1,107 @@ +import sqlalchemy as sa + +from tests import TestCase +from sqlalchemy_utils.observer import observes + + +class TestObservesFor3LevelDeepOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship( + 'Product', + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + def __repr__(self): + return '' % self.id + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.commit() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + self.session.delete(product) + self.session.commit() + assert catalog.product_count == 1 + self.session.delete( + catalog.categories[0].sub_categories[0].products[0] + ) + self.session.commit() + assert catalog.product_count == 0 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_categories[0]) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalog=catalog) + product = self.Product() + category.sub_categories.append( + self.SubCategory(products=[product]) + ) + self.session.add( + self.SubCategory(category=category, products=[product]) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/tests/observes/test_o2m_o2o_o2m.py b/tests/observes/test_o2m_o2o_o2m.py new file mode 100644 index 0000000..a08aa09 --- /dev/null +++ b/tests/observes/test_o2m_o2o_o2m.py @@ -0,0 +1,95 @@ +import sqlalchemy as sa +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForOneToManyToOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_category.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship('Product', backref='sub_category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_category=sub_category) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_category.products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_category.products.append(product) + self.session.flush() + self.session.delete(product) + self.session.flush() + assert catalog.product_count == 1 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_category) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalog=catalog) + product = self.Product() + category.sub_category = self.SubCategory(products=[product]) + self.session.add( + self.Category(catalog=catalog, sub_category=category.sub_category) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/tests/observes/test_o2o_o2o_o2o.py b/tests/observes/test_o2o_o2o_o2o.py new file mode 100644 index 0000000..a62feba --- /dev/null +++ b/tests/observes/test_o2o_o2o_o2o.py @@ -0,0 +1,83 @@ +import sqlalchemy as sa +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForOneToOneToOneToOne(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_price = sa.Column(sa.Integer) + + @observes('category.sub_category.product') + def product_observer(self, product): + self.product_price = product.price if product else None + + category = sa.orm.relationship( + 'Category', + uselist=False, + backref='catalog' + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + product = sa.orm.relationship( + 'Product', + uselist=False, + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(product=self.Product(price=123)) + category = self.Category(sub_category=sub_category) + catalog = self.Catalog(category=category) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_price == 123 + + def test_replace_leaf_object(self): + catalog = self.create_catalog() + product = self.Product(price=44) + catalog.category.sub_category.product = product + self.session.flush() + assert catalog.product_price == 44 + + def test_delete_leaf_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.category.sub_category.product) + self.session.flush() + assert catalog.product_price is None From e1de4836e50b743776db71690988eff65e0fed83 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 10:57:22 +0200 Subject: [PATCH 4/8] Add docs for observers --- docs/index.rst | 1 + docs/observers.rst | 6 ++ sqlalchemy_utils/decorators.py | 3 + sqlalchemy_utils/observer.py | 192 +++++++++++++++++++++++++++++++++ 4 files changed, 202 insertions(+) create mode 100644 docs/observers.rst diff --git a/docs/index.rst b/docs/index.rst index e43720d..97cc80a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,6 +13,7 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ data_types range_data_types aggregates + observers decorators generic_relationship database_helpers diff --git a/docs/observers.rst b/docs/observers.rst new file mode 100644 index 0000000..33a12a9 --- /dev/null +++ b/docs/observers.rst @@ -0,0 +1,6 @@ +Observers +========= + +.. automodule:: sqlalchemy_utils.observer + +.. autofunction:: observes diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index 9a8a071..8df1d4c 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -79,6 +79,9 @@ generator = AttributeValueGenerator() def generates(attr, source=None, generator=generator): """ + .. deprecated:: 0.28.0 + Use :meth:`.observer.observes` instead. + Decorator that marks given function as attribute value generator. Many times you may have generated property values. Usual cases include diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py index a6bf0b7..d5b51ac 100644 --- a/sqlalchemy_utils/observer.py +++ b/sqlalchemy_utils/observer.py @@ -1,3 +1,154 @@ +""" +This module provides a decorator function for observing changes in given +property. Internally the decorator is implemented using SQLAlchemy event +listeners. Both column properties and relationship properties can be observed. + +Property observers can be used for pre-calculating aggregates and automatic +real-time data denormalization. + +Simple observers +---------------- + +At the heart of the observer extension is the :func:`observes` decorator. You +mark some property path as being observed and the marked method will get +notified when any changes are made to given path. + +Consider the following model structure: + +:: + + class Director(Base): + __tablename__ = 'director' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + date_of_birth = sa.Column(sa.Date) + + class Movie(Base): + __tablename__ = 'movie' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id)) + director = sa.orm.relationship(Director, backref='movies') + + +Now consider we want to show movies in some listing ordered by director id +first and movie id secondly. If we have many movies then using joins and +ordering by Director.name will be very slow. Here is where denormalization +and :func:`observes` comes to rescue the day. Let's add a new column called +director_name to Movie which will get automatically copied from associated +Director. + + +:: + + from sqlalchemy_utils import observes + + + class Movie(Base): + # same as before.. + director_name = sa.Column(sa.String) + + @observes('director') + def director_observer(self, director): + self.director_name = director.name + +.. note:: + + This example could be done much more efficiently using a compound foreing + key from direcor_name, director_id to Director.name, Director.id but for + the sake of simplicity we added this as an example. + + +Observes vs aggregated +---------------------- + +:func:`observes` and :func:`.aggregates.aggregated` can be used for similar +things. However performance wise you should take the following things into +consideration: + +* :func:`observes` works always inside transaction and deals with objects. If + the relationship observer is observing has large number of objects its better + to use :func:`.aggregates.aggregated`. +* :func:`.aggregates.aggregated` always executes one additional query per + aggregate so in scenarios where the observed relationship has only handful of + objects its better to use :func:`observes` instead. + + +Example 1. Movie with many ratings + +Let's say we have a Movie object with potentially thousands of ratings. In this +case we should always use :func:`.aggregates.aggregated` since iterating +through thousands of objects is slow and very memory consuming. + +Example 2. Product with denormalized catalog name + +Each product belongs to one catalog. Here it is natural to use :func:`observes` +for data denormalization. + + +Deeply nested observing +----------------------- + +Consider the following model structure where Catalog has many Categories and +Category has many Products. + +:: + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + products = sa.orm.relationship('Product', backref='category') + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + + +:func:`observes` is smart enough to: + +* Notify catalog objects of any changes in associated Product objects +* Notify catalog objects of any changes in Category objects that affect + products (for example if Category gets deleted, or a new Category is added to + Catalog with any number of Products) + + +:: + + category = Category( + products=[Product(), Product()] + ) + category2 = Category( + product=[Product()] + ) + + catalog = Catalog( + categories=[category, category2] + ) + session.add(catalog) + session.commit() + catalog.product_count # 2 + + session.delete(category) + session.commit() + catalog.product_count # 1 + +""" import sqlalchemy as sa from collections import defaultdict, namedtuple, Iterable @@ -42,6 +193,9 @@ class PropertyObserver(object): if not sa.event.contains(*args): sa.event.listen(*args) + def __repr__(self): + return '' + def update_generator_registry(self, mapper, class_): """ Adds generator functions to generator_registry. @@ -129,6 +283,44 @@ observer = PropertyObserver() def observes(path, observer=observer): + """ + Mark method as property observer for given property path. Inside + transaction observer gathers all changes made in given property path and + feeds the changed objects to observer-marked method at the before flush + phase. + + :: + + from sqlalchemy_utils import observes + + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + category_count = sa.Column(sa.Integer, default=0) + + @observes('categories') + def category_observer(self, categories): + self.category_count = len(categories) + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + + catalog = Catalog(categories=[Category(), Category()]) + session.add(catalog) + session.commit() + + catalog.category_count # 2 + + + .. versionadded: 0.28.0 + + :param path: Dot-notated property path, eg. 'categories.products.price' + :param observer: :meth:`PropertyObserver` object + """ observer.register_listeners() def wraps(func): From 1e3d547f61fdd96aa6861d1738dfd8df6eb4ce64 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 10:57:46 +0200 Subject: [PATCH 5/8] Remove empty line --- tests/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/__init__.py b/tests/__init__.py index 6b885f7..0fedbd6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -117,4 +117,3 @@ def assert_contains(clause, query): # Test that query executes query.all() assert clause in str(query) - From a69d3c942f4dd9de3af5194335068b085de748f2 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 12:21:30 +0200 Subject: [PATCH 6/8] Fix function reference --- sqlalchemy_utils/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index 8df1d4c..64e00ab 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -80,7 +80,7 @@ generator = AttributeValueGenerator() def generates(attr, source=None, generator=generator): """ .. deprecated:: 0.28.0 - Use :meth:`.observer.observes` instead. + Use :func:`.observer.observes` instead. Decorator that marks given function as attribute value generator. From f16a4dc3211412873d7804a78955a08ae89de2a7 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 13:03:06 +0200 Subject: [PATCH 7/8] Bump version --- CHANGES.rst | 3 ++- sqlalchemy_utils/__init__.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 65e0806..051d57c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,11 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. -0.27.12 (2014-12-xx) +0.28.0 (2014-12-12) ^^^^^^^^^^^^^^^^^^^^ - Fixed PhoneNumber string coercion (#93) +- Added observes decorator (generates decorator will be deprecated later) 0.27.11 (2014-12-06) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index c3ecb99..bf841d1 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -46,6 +46,7 @@ from .listeners import ( ) from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict +from .observer import observes from .query_chain import QueryChain from .types import ( ArrowType, @@ -80,7 +81,7 @@ from .types import ( from .models import Timestamp -__version__ = '0.27.11' +__version__ = '0.28.0' __all__ = ( From 74d332bdeb955f4e9334922f2bd96aef9a7ce405 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 19:26:29 +0200 Subject: [PATCH 8/8] Refactor naturally_equivalent --- sqlalchemy_utils/functions/orm.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 12cb752..a2948f5 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -843,13 +843,10 @@ def naturally_equivalent(obj, obj2): :param obj: SQLAlchemy declarative model object :param obj2: SQLAlchemy declarative model object to compare with `obj` """ - for prop in sa.inspect(obj.__class__).iterate_properties: - if not isinstance(prop, sa.orm.ColumnProperty): + for column_key, column in sa.inspect(obj.__class__).columns.items(): + if column.primary_key: continue - if prop.columns[0].primary_key: - continue - - if not (getattr(obj, prop.key) == getattr(obj2, prop.key)): + if not (getattr(obj, column_key) == getattr(obj2, column_key)): return False return True