From 1045402b896f5f767ccc233964d77d3f60ba4d7d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 12 Dec 2014 10:18:55 +0200 Subject: [PATCH] Add observes decorator --- sqlalchemy_utils/observer.py | 139 +++++++++++++++++++++++++++++ tests/observes/__init__.py | 0 tests/observes/test_m2m_m2m_m2m.py | 137 ++++++++++++++++++++++++++++ tests/observes/test_o2m_o2m_o2m.py | 107 ++++++++++++++++++++++ tests/observes/test_o2m_o2o_o2m.py | 95 ++++++++++++++++++++ tests/observes/test_o2o_o2o_o2o.py | 83 +++++++++++++++++ 6 files changed, 561 insertions(+) create mode 100644 sqlalchemy_utils/observer.py create mode 100644 tests/observes/__init__.py create mode 100644 tests/observes/test_m2m_m2m_m2m.py create mode 100644 tests/observes/test_o2m_o2m_o2m.py create mode 100644 tests/observes/test_o2m_o2o_o2m.py create mode 100644 tests/observes/test_o2o_o2o_o2o.py diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py new file mode 100644 index 0000000..a6bf0b7 --- /dev/null +++ b/sqlalchemy_utils/observer.py @@ -0,0 +1,139 @@ +import sqlalchemy as sa + +from collections import defaultdict, namedtuple, Iterable +import itertools +from sqlalchemy_utils.functions import getdotattr +from sqlalchemy_utils.path import AttrPath +from sqlalchemy_utils.utils import is_sequence + + +Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) + + +class PropertyObserver(object): + def __init__(self): + self.listener_args = [ + ( + sa.orm.mapper, + 'mapper_configured', + self.update_generator_registry + ), + ( + sa.orm.mapper, + 'after_configured', + self.gather_paths + ), + ( + sa.orm.session.Session, + 'before_flush', + self.invoke_callbacks + ) + ] + self.callback_map = defaultdict(list) + # TODO: make the registry a WeakKey dict + self.generator_registry = defaultdict(list) + + def remove_listeners(self): + for args in self.listener_args: + sa.event.remove(*args) + + def register_listeners(self): + for args in self.listener_args: + if not sa.event.contains(*args): + sa.event.listen(*args) + + def update_generator_registry(self, mapper, class_): + """ + Adds generator functions to generator_registry. + """ + + for generator in class_.__dict__.values(): + if hasattr(generator, '__observes__'): + self.generator_registry[class_].append( + generator + ) + + def gather_paths(self): + for class_, callbacks in self.generator_registry.items(): + for callback in callbacks: + path = AttrPath(class_, callback.__observes__) + + self.callback_map[class_].append( + Callback( + func=callback, + path=path, + backref=None, + fullpath=path + ) + ) + + for index in range(len(path)): + i = index + 1 + prop_class = path[index].property.mapper.class_ + self.callback_map[prop_class].append( + Callback( + func=callback, + path=path[i:], + backref=~ (path[:i]), + fullpath=path + ) + ) + + def gather_callback_args(self, obj, callbacks): + session = sa.orm.object_session(obj) + for callback in callbacks: + backref = callback.backref + + root_objs = getdotattr(obj, backref) if backref else obj + if root_objs: + if not isinstance(root_objs, Iterable): + root_objs = [root_objs] + + for root_obj in root_objs: + objects = getdotattr( + root_obj, + callback.fullpath, + lambda obj: obj not in session.deleted + ) + + yield ( + root_obj, + callback.func, + objects + ) + + def changed_objects(self, session): + objs = itertools.chain(session.new, session.dirty, session.deleted) + for obj in objs: + for class_, callbacks in self.callback_map.items(): + if isinstance(obj, class_): + yield obj, callbacks + + def invoke_callbacks(self, session, ctx, instances): + callback_args = defaultdict(lambda: defaultdict(set)) + for obj, callbacks in self.changed_objects(session): + args = self.gather_callback_args(obj, callbacks) + for (root_obj, func, objects) in args: + if is_sequence(objects): + callback_args[root_obj][func] = ( + callback_args[root_obj][func] | set(objects) + ) + else: + callback_args[root_obj][func] = objects + + for root_obj, callback_objs in callback_args.items(): + for callback, objs in callback_objs.items(): + callback(root_obj, objs) + +observer = PropertyObserver() + + +def observes(path, observer=observer): + observer.register_listeners() + + def wraps(func): + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + wrapper.__observes__ = path + return wrapper + return wraps diff --git a/tests/observes/__init__.py b/tests/observes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/observes/test_m2m_m2m_m2m.py b/tests/observes/test_m2m_m2m_m2m.py new file mode 100644 index 0000000..3b416f2 --- /dev/null +++ b/tests/observes/test_m2m_m2m_m2m.py @@ -0,0 +1,137 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForManyToManyToManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + catalog_category = sa.Table( + 'catalog_category', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + ) + + category_subcategory = sa.Table( + 'category_subcategory', + self.Base.metadata, + sa.Column( + 'category_id', + sa.Integer, + sa.ForeignKey('category.id') + ), + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ) + ) + + subcategory_product = sa.Table( + 'subcategory_product', + self.Base.metadata, + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ), + sa.Column( + 'product_id', + sa.Integer, + sa.ForeignKey('product.id') + ) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship( + 'Category', + backref='catalogs', + secondary=catalog_category + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + 'SubCategory', + backref='categories', + secondary=category_subcategory + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + products = sa.orm.relationship( + 'Product', + backref='sub_categories', + secondary=subcategory_product + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + self.session.delete(product) + self.session.flush() + assert catalog.product_count == 1 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_categories[0]) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalogs=[catalog]) + product = self.Product() + category.sub_categories.append( + self.SubCategory(products=[product]) + ) + self.session.add( + self.SubCategory(categories=[category], products=[product]) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/tests/observes/test_o2m_o2m_o2m.py b/tests/observes/test_o2m_o2m_o2m.py new file mode 100644 index 0000000..9656141 --- /dev/null +++ b/tests/observes/test_o2m_o2m_o2m.py @@ -0,0 +1,107 @@ +import sqlalchemy as sa + +from tests import TestCase +from sqlalchemy_utils.observer import observes + + +class TestObservesFor3LevelDeepOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship( + 'Product', + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + def __repr__(self): + return '' % self.id + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.commit() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + self.session.delete(product) + self.session.commit() + assert catalog.product_count == 1 + self.session.delete( + catalog.categories[0].sub_categories[0].products[0] + ) + self.session.commit() + assert catalog.product_count == 0 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_categories[0]) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalog=catalog) + product = self.Product() + category.sub_categories.append( + self.SubCategory(products=[product]) + ) + self.session.add( + self.SubCategory(category=category, products=[product]) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/tests/observes/test_o2m_o2o_o2m.py b/tests/observes/test_o2m_o2o_o2m.py new file mode 100644 index 0000000..a08aa09 --- /dev/null +++ b/tests/observes/test_o2m_o2o_o2m.py @@ -0,0 +1,95 @@ +import sqlalchemy as sa +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForOneToManyToOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_category.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship('Product', backref='sub_category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_category=sub_category) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_category.products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_category.products.append(product) + self.session.flush() + self.session.delete(product) + self.session.flush() + assert catalog.product_count == 1 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_category) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalog=catalog) + product = self.Product() + category.sub_category = self.SubCategory(products=[product]) + self.session.add( + self.Category(catalog=catalog, sub_category=category.sub_category) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/tests/observes/test_o2o_o2o_o2o.py b/tests/observes/test_o2o_o2o_o2o.py new file mode 100644 index 0000000..a62feba --- /dev/null +++ b/tests/observes/test_o2o_o2o_o2o.py @@ -0,0 +1,83 @@ +import sqlalchemy as sa +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForOneToOneToOneToOne(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_price = sa.Column(sa.Integer) + + @observes('category.sub_category.product') + def product_observer(self, product): + self.product_price = product.price if product else None + + category = sa.orm.relationship( + 'Category', + uselist=False, + backref='catalog' + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + product = sa.orm.relationship( + 'Product', + uselist=False, + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(product=self.Product(price=123)) + category = self.Category(sub_category=sub_category) + catalog = self.Catalog(category=category) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_price == 123 + + def test_replace_leaf_object(self): + catalog = self.create_catalog() + product = self.Product(price=44) + catalog.category.sub_category.product = product + self.session.flush() + assert catalog.product_price == 44 + + def test_delete_leaf_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.category.sub_category.product) + self.session.flush() + assert catalog.product_price is None