diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 958ed1b..07b3f0c 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -375,7 +375,8 @@ except ImportError: # SQLAlchemy 0.8 from sqlalchemy.sql.expression import _FunctionGenerator -from .relationships import chained_join +from .functions.orm import get_column_key +from .relationships import chained_join, select_aggregate aggregated_attrs = WeakKeyDictionary(defaultdict(list)) @@ -404,64 +405,16 @@ class AggregatedAttribute(declared_attr): return desc.column -def aggregate_select(agg_expr, relationships): - """ - Return a subquery for fetching an aggregate value of given aggregate - expression and given sequence of relationships. - - The returned aggregate query can be used when updating denormalized column - value with query such as: - - UPDATE table SET column = {aggregate_query} - WHERE {condition} - - :param agg_expr: - an expression to be selected, for example sa.func.count('1') - :param relationships: - Sequence of relationships to be used for building the aggregate - query. - """ - from_ = relationships[0].mapper.class_.__table__ - for relationship in relationships[0:-1]: - property_ = relationship.property - if property_.secondary is not None: - from_ = from_.join( - property_.secondary, - property_.secondaryjoin - ) - - from_ = ( - from_ - .join( - property_.parent.class_, - property_.primaryjoin - ) - ) - - prop = relationships[-1].property - condition = prop.primaryjoin - if prop.secondary is not None: - from_ = from_.join( - prop.secondary, - prop.secondaryjoin - ) - - query = sa.select( - [agg_expr], - from_obj=[from_] - ) - - return query.where(condition) - - def local_condition(prop, objects): pairs = prop.local_remote_pairs if prop.secondary is not None: - column = pairs[1][0] - key = pairs[1][0].key + parent_column = pairs[1][0] + fetched_column = pairs[1][0] else: - column = pairs[0][0] - key = pairs[0][1].key + parent_column = pairs[0][0] + fetched_column = pairs[0][1] + + key = get_column_key(prop.mapper, fetched_column) values = [] for obj in objects: @@ -471,7 +424,7 @@ def local_condition(prop, objects): pass if values: - return column.in_(values) + return parent_column.in_(values) def aggregate_expression(expr, class_): @@ -492,7 +445,7 @@ class AggregatedValue(object): @property def aggregate_query(self): - query = aggregate_select(self.expr, self.relationships) + query = select_aggregate(self.expr, self.relationships) return query.correlate(self.class_).as_scalar() diff --git a/sqlalchemy_utils/relationships/__init__.py b/sqlalchemy_utils/relationships/__init__.py index 81784bf..9f4a56f 100644 --- a/sqlalchemy_utils/relationships/__init__.py +++ b/sqlalchemy_utils/relationships/__init__.py @@ -1 +1,2 @@ from .chained_join import chained_join +from .select_aggregate import select_aggregate diff --git a/sqlalchemy_utils/relationships/chained_join.py b/sqlalchemy_utils/relationships/chained_join.py index 68803cf..2508d1a 100644 --- a/sqlalchemy_utils/relationships/chained_join.py +++ b/sqlalchemy_utils/relationships/chained_join.py @@ -1,4 +1,7 @@ def chained_join(*relationships): + """ + Return a chained Join object for given relationships. + """ property_ = relationships[0].property if property_.secondary is not None: diff --git a/sqlalchemy_utils/relationships/select_aggregate.py b/sqlalchemy_utils/relationships/select_aggregate.py new file mode 100644 index 0000000..379b33f --- /dev/null +++ b/sqlalchemy_utils/relationships/select_aggregate.py @@ -0,0 +1,51 @@ +import sqlalchemy as sa + + +def select_aggregate(agg_expr, relationships): + """ + Return a subquery for fetching an aggregate value of given aggregate + expression and given sequence of relationships. + + The returned aggregate query can be used when updating denormalized column + value with query such as: + + UPDATE table SET column = {aggregate_query} + WHERE {condition} + + :param agg_expr: + an expression to be selected, for example sa.func.count('1') + :param relationships: + Sequence of relationships to be used for building the aggregate + query. + """ + from_ = relationships[0].mapper.class_.__table__ + for relationship in relationships[0:-1]: + property_ = relationship.property + if property_.secondary is not None: + from_ = from_.join( + property_.secondary, + property_.secondaryjoin + ) + + from_ = ( + from_ + .join( + property_.parent.class_, + property_.primaryjoin + ) + ) + + prop = relationships[-1].property + condition = prop.primaryjoin + if prop.secondary is not None: + from_ = from_.join( + prop.secondary, + prop.secondaryjoin + ) + + query = sa.select( + [agg_expr], + from_obj=[from_] + ) + + return query.where(condition) diff --git a/tests/mixins.py b/tests/mixins.py index 9bb52d4..1224024 100644 --- a/tests/mixins.py +++ b/tests/mixins.py @@ -5,7 +5,7 @@ class ThreeLevelDeepOneToOne(object): def create_models(self): class Catalog(self.Base): __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) category = sa.orm.relationship( 'Category', uselist=False, @@ -14,8 +14,12 @@ class ThreeLevelDeepOneToOne(object): class Category(self.Base): __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + id = sa.Column('_id', sa.Integer, primary_key=True) + catalog_id = sa.Column( + '_catalog_id', + sa.Integer, + sa.ForeignKey('catalog._id') + ) sub_category = sa.orm.relationship( 'SubCategory', @@ -25,8 +29,12 @@ class ThreeLevelDeepOneToOne(object): class SubCategory(self.Base): __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + id = sa.Column('_id', sa.Integer, primary_key=True) + category_id = sa.Column( + '_category_id', + sa.Integer, + sa.ForeignKey('category._id') + ) product = sa.orm.relationship( 'Product', uselist=False, @@ -35,11 +43,13 @@ class ThreeLevelDeepOneToOne(object): class Product(self.Base): __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Integer) sub_category_id = sa.Column( - sa.Integer, sa.ForeignKey('sub_category.id') + '_sub_category_id', + sa.Integer, + sa.ForeignKey('sub_category._id') ) self.Catalog = Catalog @@ -52,14 +62,18 @@ class ThreeLevelDeepOneToMany(object): def create_models(self): class Catalog(self.Base): __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) categories = sa.orm.relationship('Category', backref='catalog') class Category(self.Base): __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + id = sa.Column('_id', sa.Integer, primary_key=True) + catalog_id = sa.Column( + '_catalog_id', + sa.Integer, + sa.ForeignKey('catalog._id') + ) sub_categories = sa.orm.relationship( 'SubCategory', backref='category' @@ -67,8 +81,12 @@ class ThreeLevelDeepOneToMany(object): class SubCategory(self.Base): __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + id = sa.Column('_id', sa.Integer, primary_key=True) + category_id = sa.Column( + '_category_id', + sa.Integer, + sa.ForeignKey('category._id') + ) products = sa.orm.relationship( 'Product', backref='sub_category' @@ -76,11 +94,13 @@ class ThreeLevelDeepOneToMany(object): class Product(self.Base): __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) sub_category_id = sa.Column( - sa.Integer, sa.ForeignKey('sub_category.id') + '_sub_category_id', + sa.Integer, + sa.ForeignKey('sub_category._id') ) def __repr__(self): @@ -97,8 +117,8 @@ class ThreeLevelDeepManyToMany(object): catalog_category = sa.Table( 'catalog_category', self.Base.metadata, - sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), - sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id')) ) category_subcategory = sa.Table( @@ -107,12 +127,12 @@ class ThreeLevelDeepManyToMany(object): sa.Column( 'category_id', sa.Integer, - sa.ForeignKey('category.id') + sa.ForeignKey('category._id') ), sa.Column( 'subcategory_id', sa.Integer, - sa.ForeignKey('sub_category.id') + sa.ForeignKey('sub_category._id') ) ) @@ -122,18 +142,18 @@ class ThreeLevelDeepManyToMany(object): sa.Column( 'subcategory_id', sa.Integer, - sa.ForeignKey('sub_category.id') + sa.ForeignKey('sub_category._id') ), sa.Column( 'product_id', sa.Integer, - sa.ForeignKey('product.id') + sa.ForeignKey('product._id') ) ) class Catalog(self.Base): __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) categories = sa.orm.relationship( 'Category', @@ -143,7 +163,7 @@ class ThreeLevelDeepManyToMany(object): class Category(self.Base): __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) sub_categories = sa.orm.relationship( 'SubCategory', @@ -153,7 +173,7 @@ class ThreeLevelDeepManyToMany(object): class SubCategory(self.Base): __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) products = sa.orm.relationship( 'Product', backref='sub_categories', @@ -162,7 +182,7 @@ class ThreeLevelDeepManyToMany(object): class Product(self.Base): __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) + id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) self.Catalog = Catalog diff --git a/tests/relationships/test_chained_join.py b/tests/relationships/test_chained_join.py index 3e748ca..a19ebcf 100644 --- a/tests/relationships/test_chained_join.py +++ b/tests/relationships/test_chained_join.py @@ -14,7 +14,7 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): def test_simple_join(self): assert str(chained_join(self.Catalog.categories)) == ( 'catalog_category JOIN category ON ' - 'category.id = catalog_category.category_id' + 'category._id = catalog_category.category_id' ) def test_two_relations(self): @@ -23,10 +23,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): self.Category.sub_categories ) assert str(sql) == ( - 'catalog_category JOIN category ON category.id = ' + 'catalog_category JOIN category ON category._id = ' 'catalog_category.category_id JOIN category_subcategory ON ' - 'category.id = category_subcategory.category_id JOIN sub_category ' - 'ON sub_category.id = category_subcategory.subcategory_id' + 'category._id = category_subcategory.category_id JOIN ' + 'sub_category ON sub_category._id = ' + 'category_subcategory.subcategory_id' ) def test_three_relations(self): @@ -36,13 +37,13 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): self.SubCategory.products ) assert str(sql) == ( - 'catalog_category JOIN category ON category.id = ' + 'catalog_category JOIN category ON category._id = ' 'catalog_category.category_id JOIN category_subcategory ON ' - 'category.id = category_subcategory.category_id JOIN sub_category ' - 'ON sub_category.id = category_subcategory.subcategory_id JOIN ' - 'subcategory_product ON sub_category.id = ' - 'subcategory_product.subcategory_id JOIN product ON product.id = ' - 'subcategory_product.product_id' + 'category._id = category_subcategory.category_id JOIN sub_category' + ' ON sub_category._id = category_subcategory.subcategory_id JOIN ' + 'subcategory_product ON sub_category._id = ' + 'subcategory_product.subcategory_id JOIN product ON product._id =' + ' subcategory_product.product_id' ) @@ -59,8 +60,8 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase): self.Category.sub_categories ) assert str(sql) == ( - 'category JOIN sub_category ON category.id = ' - 'sub_category.category_id' + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id' ) def test_three_relations(self): @@ -70,9 +71,9 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase): self.SubCategory.products ) assert str(sql) == ( - 'category JOIN sub_category ON category.id = ' - 'sub_category.category_id JOIN product ON sub_category.id = ' - 'product.sub_category_id' + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id JOIN product ON sub_category._id = ' + 'product._sub_category_id' ) @@ -89,8 +90,8 @@ class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase): self.Category.sub_category ) assert str(sql) == ( - 'category JOIN sub_category ON category.id = ' - 'sub_category.category_id' + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id' ) def test_three_relations(self): @@ -100,7 +101,7 @@ class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase): self.SubCategory.product ) assert str(sql) == ( - 'category JOIN sub_category ON category.id = ' - 'sub_category.category_id JOIN product ON sub_category.id = ' - 'product.sub_category_id' + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id JOIN product ON sub_category._id = ' + 'product._sub_category_id' ) diff --git a/tests/relationships/test_select_aggregate.py b/tests/relationships/test_select_aggregate.py new file mode 100644 index 0000000..b0f76ef --- /dev/null +++ b/tests/relationships/test_select_aggregate.py @@ -0,0 +1,64 @@ +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import select_aggregate +from tests import TestCase +from tests.mixins import ( + ThreeLevelDeepManyToMany, + ThreeLevelDeepOneToMany, + ThreeLevelDeepOneToOne, +) + + +def normalize(sql): + return ' '.join(sql.replace('\n', '').split()) + + +class TestAggregateQueryForDeepToManyToMany( + ThreeLevelDeepManyToMany, + TestCase +): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def assert_sql(self, construct, sql): + assert normalize(str(construct)) == normalize(sql) + + def build_update(self, *relationships): + expr = sa.func.count(sa.text('1')) + return ( + self.Catalog.__table__.update().values( + _id=select_aggregate( + expr, + relationships + ).correlate(self.Catalog) + ) + ) + + def test_simple_join(self): + self.assert_sql( + self.build_update(self.Catalog.categories), + ( + '''UPDATE catalog SET _id=(SELECT count(1) AS count_1 + FROM category JOIN catalog_category ON category._id = + catalog_category.category_id WHERE catalog._id = + catalog_category.catalog_id)''' + ) + ) + + def test_two_relations(self): + self.assert_sql( + self.build_update( + self.Category.sub_categories, + self.Catalog.categories, + ), + ( + '''UPDATE catalog SET _id=(SELECT count(1) AS count_1 + FROM sub_category + JOIN category_subcategory + ON sub_category._id = category_subcategory.subcategory_id + JOIN category + ON category._id = category_subcategory.category_id + JOIN catalog_category + ON category._id = catalog_category.category_id + WHERE catalog._id = catalog_category.catalog_id)''' + ) + )