Make aggregates fully support column aliases

This commit is contained in:
Konsta Vesterinen
2014-12-16 15:41:46 +02:00
parent be21ffc1d5
commit dd3758e40a
7 changed files with 194 additions and 101 deletions

View File

@@ -375,7 +375,8 @@ except ImportError:
# SQLAlchemy 0.8 # SQLAlchemy 0.8
from sqlalchemy.sql.expression import _FunctionGenerator from sqlalchemy.sql.expression import _FunctionGenerator
from .relationships import chained_join from .functions.orm import get_column_key
from .relationships import chained_join, select_aggregate
aggregated_attrs = WeakKeyDictionary(defaultdict(list)) aggregated_attrs = WeakKeyDictionary(defaultdict(list))
@@ -404,64 +405,16 @@ class AggregatedAttribute(declared_attr):
return desc.column return desc.column
def aggregate_select(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
The returned aggregate query can be used when updating denormalized column
value with query such as:
UPDATE table SET column = {aggregate_query}
WHERE {condition}
:param agg_expr:
an expression to be selected, for example sa.func.count('1')
:param relationships:
Sequence of relationships to be used for building the aggregate
query.
"""
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)
def local_condition(prop, objects): def local_condition(prop, objects):
pairs = prop.local_remote_pairs pairs = prop.local_remote_pairs
if prop.secondary is not None: if prop.secondary is not None:
column = pairs[1][0] parent_column = pairs[1][0]
key = pairs[1][0].key fetched_column = pairs[1][0]
else: else:
column = pairs[0][0] parent_column = pairs[0][0]
key = pairs[0][1].key fetched_column = pairs[0][1]
key = get_column_key(prop.mapper, fetched_column)
values = [] values = []
for obj in objects: for obj in objects:
@@ -471,7 +424,7 @@ def local_condition(prop, objects):
pass pass
if values: if values:
return column.in_(values) return parent_column.in_(values)
def aggregate_expression(expr, class_): def aggregate_expression(expr, class_):
@@ -492,7 +445,7 @@ class AggregatedValue(object):
@property @property
def aggregate_query(self): def aggregate_query(self):
query = aggregate_select(self.expr, self.relationships) query = select_aggregate(self.expr, self.relationships)
return query.correlate(self.class_).as_scalar() return query.correlate(self.class_).as_scalar()

View File

@@ -1 +1,2 @@
from .chained_join import chained_join from .chained_join import chained_join
from .select_aggregate import select_aggregate

View File

@@ -1,4 +1,7 @@
def chained_join(*relationships): def chained_join(*relationships):
"""
Return a chained Join object for given relationships.
"""
property_ = relationships[0].property property_ = relationships[0].property
if property_.secondary is not None: if property_.secondary is not None:

View File

@@ -0,0 +1,51 @@
import sqlalchemy as sa
def select_aggregate(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
The returned aggregate query can be used when updating denormalized column
value with query such as:
UPDATE table SET column = {aggregate_query}
WHERE {condition}
:param agg_expr:
an expression to be selected, for example sa.func.count('1')
:param relationships:
Sequence of relationships to be used for building the aggregate
query.
"""
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)

View File

@@ -5,7 +5,7 @@ class ThreeLevelDeepOneToOne(object):
def create_models(self): def create_models(self):
class Catalog(self.Base): class Catalog(self.Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
category = sa.orm.relationship( category = sa.orm.relationship(
'Category', 'Category',
uselist=False, uselist=False,
@@ -14,8 +14,12 @@ class ThreeLevelDeepOneToOne(object):
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) catalog_id = sa.Column(
'_catalog_id',
sa.Integer,
sa.ForeignKey('catalog._id')
)
sub_category = sa.orm.relationship( sub_category = sa.orm.relationship(
'SubCategory', 'SubCategory',
@@ -25,8 +29,12 @@ class ThreeLevelDeepOneToOne(object):
class SubCategory(self.Base): class SubCategory(self.Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) category_id = sa.Column(
'_category_id',
sa.Integer,
sa.ForeignKey('category._id')
)
product = sa.orm.relationship( product = sa.orm.relationship(
'Product', 'Product',
uselist=False, uselist=False,
@@ -35,11 +43,13 @@ class ThreeLevelDeepOneToOne(object):
class Product(self.Base): class Product(self.Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Integer) price = sa.Column(sa.Integer)
sub_category_id = sa.Column( sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id') '_sub_category_id',
sa.Integer,
sa.ForeignKey('sub_category._id')
) )
self.Catalog = Catalog self.Catalog = Catalog
@@ -52,14 +62,18 @@ class ThreeLevelDeepOneToMany(object):
def create_models(self): def create_models(self):
class Catalog(self.Base): class Catalog(self.Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship('Category', backref='catalog') categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) catalog_id = sa.Column(
'_catalog_id',
sa.Integer,
sa.ForeignKey('catalog._id')
)
sub_categories = sa.orm.relationship( sub_categories = sa.orm.relationship(
'SubCategory', backref='category' 'SubCategory', backref='category'
@@ -67,8 +81,12 @@ class ThreeLevelDeepOneToMany(object):
class SubCategory(self.Base): class SubCategory(self.Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) category_id = sa.Column(
'_category_id',
sa.Integer,
sa.ForeignKey('category._id')
)
products = sa.orm.relationship( products = sa.orm.relationship(
'Product', 'Product',
backref='sub_category' backref='sub_category'
@@ -76,11 +94,13 @@ class ThreeLevelDeepOneToMany(object):
class Product(self.Base): class Product(self.Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric) price = sa.Column(sa.Numeric)
sub_category_id = sa.Column( sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id') '_sub_category_id',
sa.Integer,
sa.ForeignKey('sub_category._id')
) )
def __repr__(self): def __repr__(self):
@@ -97,8 +117,8 @@ class ThreeLevelDeepManyToMany(object):
catalog_category = sa.Table( catalog_category = sa.Table(
'catalog_category', 'catalog_category',
self.Base.metadata, self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id'))
) )
category_subcategory = sa.Table( category_subcategory = sa.Table(
@@ -107,12 +127,12 @@ class ThreeLevelDeepManyToMany(object):
sa.Column( sa.Column(
'category_id', 'category_id',
sa.Integer, sa.Integer,
sa.ForeignKey('category.id') sa.ForeignKey('category._id')
), ),
sa.Column( sa.Column(
'subcategory_id', 'subcategory_id',
sa.Integer, sa.Integer,
sa.ForeignKey('sub_category.id') sa.ForeignKey('sub_category._id')
) )
) )
@@ -122,18 +142,18 @@ class ThreeLevelDeepManyToMany(object):
sa.Column( sa.Column(
'subcategory_id', 'subcategory_id',
sa.Integer, sa.Integer,
sa.ForeignKey('sub_category.id') sa.ForeignKey('sub_category._id')
), ),
sa.Column( sa.Column(
'product_id', 'product_id',
sa.Integer, sa.Integer,
sa.ForeignKey('product.id') sa.ForeignKey('product._id')
) )
) )
class Catalog(self.Base): class Catalog(self.Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship( categories = sa.orm.relationship(
'Category', 'Category',
@@ -143,7 +163,7 @@ class ThreeLevelDeepManyToMany(object):
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship( sub_categories = sa.orm.relationship(
'SubCategory', 'SubCategory',
@@ -153,7 +173,7 @@ class ThreeLevelDeepManyToMany(object):
class SubCategory(self.Base): class SubCategory(self.Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
products = sa.orm.relationship( products = sa.orm.relationship(
'Product', 'Product',
backref='sub_categories', backref='sub_categories',
@@ -162,7 +182,7 @@ class ThreeLevelDeepManyToMany(object):
class Product(self.Base): class Product(self.Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric) price = sa.Column(sa.Numeric)
self.Catalog = Catalog self.Catalog = Catalog

View File

@@ -14,7 +14,7 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
def test_simple_join(self): def test_simple_join(self):
assert str(chained_join(self.Catalog.categories)) == ( assert str(chained_join(self.Catalog.categories)) == (
'catalog_category JOIN category ON ' 'catalog_category JOIN category ON '
'category.id = catalog_category.category_id' 'category._id = catalog_category.category_id'
) )
def test_two_relations(self): def test_two_relations(self):
@@ -23,10 +23,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
self.Category.sub_categories self.Category.sub_categories
) )
assert str(sql) == ( assert str(sql) == (
'catalog_category JOIN category ON category.id = ' 'catalog_category JOIN category ON category._id = '
'catalog_category.category_id JOIN category_subcategory ON ' 'catalog_category.category_id JOIN category_subcategory ON '
'category.id = category_subcategory.category_id JOIN sub_category ' 'category._id = category_subcategory.category_id JOIN '
'ON sub_category.id = category_subcategory.subcategory_id' 'sub_category ON sub_category._id = '
'category_subcategory.subcategory_id'
) )
def test_three_relations(self): def test_three_relations(self):
@@ -36,13 +37,13 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
self.SubCategory.products self.SubCategory.products
) )
assert str(sql) == ( assert str(sql) == (
'catalog_category JOIN category ON category.id = ' 'catalog_category JOIN category ON category._id = '
'catalog_category.category_id JOIN category_subcategory ON ' 'catalog_category.category_id JOIN category_subcategory ON '
'category.id = category_subcategory.category_id JOIN sub_category ' 'category._id = category_subcategory.category_id JOIN sub_category'
'ON sub_category.id = category_subcategory.subcategory_id JOIN ' ' ON sub_category._id = category_subcategory.subcategory_id JOIN '
'subcategory_product ON sub_category.id = ' 'subcategory_product ON sub_category._id = '
'subcategory_product.subcategory_id JOIN product ON product.id = ' 'subcategory_product.subcategory_id JOIN product ON product._id ='
'subcategory_product.product_id' ' subcategory_product.product_id'
) )
@@ -59,8 +60,8 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
self.Category.sub_categories self.Category.sub_categories
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category.id = ' 'category JOIN sub_category ON category._id = '
'sub_category.category_id' 'sub_category._category_id'
) )
def test_three_relations(self): def test_three_relations(self):
@@ -70,9 +71,9 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
self.SubCategory.products self.SubCategory.products
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category.id = ' 'category JOIN sub_category ON category._id = '
'sub_category.category_id JOIN product ON sub_category.id = ' 'sub_category._category_id JOIN product ON sub_category._id = '
'product.sub_category_id' 'product._sub_category_id'
) )
@@ -89,8 +90,8 @@ class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase):
self.Category.sub_category self.Category.sub_category
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category.id = ' 'category JOIN sub_category ON category._id = '
'sub_category.category_id' 'sub_category._category_id'
) )
def test_three_relations(self): def test_three_relations(self):
@@ -100,7 +101,7 @@ class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase):
self.SubCategory.product self.SubCategory.product
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category.id = ' 'category JOIN sub_category ON category._id = '
'sub_category.category_id JOIN product ON sub_category.id = ' 'sub_category._category_id JOIN product ON sub_category._id = '
'product.sub_category_id' 'product._sub_category_id'
) )

View File

@@ -0,0 +1,64 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import select_aggregate
from tests import TestCase
from tests.mixins import (
ThreeLevelDeepManyToMany,
ThreeLevelDeepOneToMany,
ThreeLevelDeepOneToOne,
)
def normalize(sql):
return ' '.join(sql.replace('\n', '').split())
class TestAggregateQueryForDeepToManyToMany(
ThreeLevelDeepManyToMany,
TestCase
):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
create_tables = False
def assert_sql(self, construct, sql):
assert normalize(str(construct)) == normalize(sql)
def build_update(self, *relationships):
expr = sa.func.count(sa.text('1'))
return (
self.Catalog.__table__.update().values(
_id=select_aggregate(
expr,
relationships
).correlate(self.Catalog)
)
)
def test_simple_join(self):
self.assert_sql(
self.build_update(self.Catalog.categories),
(
'''UPDATE catalog SET _id=(SELECT count(1) AS count_1
FROM category JOIN catalog_category ON category._id =
catalog_category.category_id WHERE catalog._id =
catalog_category.catalog_id)'''
)
)
def test_two_relations(self):
self.assert_sql(
self.build_update(
self.Category.sub_categories,
self.Catalog.categories,
),
(
'''UPDATE catalog SET _id=(SELECT count(1) AS count_1
FROM sub_category
JOIN category_subcategory
ON sub_category._id = category_subcategory.subcategory_id
JOIN category
ON category._id = category_subcategory.category_id
JOIN catalog_category
ON category._id = catalog_category.category_id
WHERE catalog._id = catalog_category.catalog_id)'''
)
)