Make aggregates fully support column aliases

This commit is contained in:
Konsta Vesterinen
2014-12-16 15:41:46 +02:00
parent be21ffc1d5
commit dd3758e40a
7 changed files with 194 additions and 101 deletions

View File

@@ -375,7 +375,8 @@ except ImportError:
# SQLAlchemy 0.8
from sqlalchemy.sql.expression import _FunctionGenerator
from .relationships import chained_join
from .functions.orm import get_column_key
from .relationships import chained_join, select_aggregate
aggregated_attrs = WeakKeyDictionary(defaultdict(list))
@@ -404,64 +405,16 @@ class AggregatedAttribute(declared_attr):
return desc.column
def aggregate_select(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
The returned aggregate query can be used when updating denormalized column
value with query such as:
UPDATE table SET column = {aggregate_query}
WHERE {condition}
:param agg_expr:
an expression to be selected, for example sa.func.count('1')
:param relationships:
Sequence of relationships to be used for building the aggregate
query.
"""
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)
def local_condition(prop, objects):
pairs = prop.local_remote_pairs
if prop.secondary is not None:
column = pairs[1][0]
key = pairs[1][0].key
parent_column = pairs[1][0]
fetched_column = pairs[1][0]
else:
column = pairs[0][0]
key = pairs[0][1].key
parent_column = pairs[0][0]
fetched_column = pairs[0][1]
key = get_column_key(prop.mapper, fetched_column)
values = []
for obj in objects:
@@ -471,7 +424,7 @@ def local_condition(prop, objects):
pass
if values:
return column.in_(values)
return parent_column.in_(values)
def aggregate_expression(expr, class_):
@@ -492,7 +445,7 @@ class AggregatedValue(object):
@property
def aggregate_query(self):
query = aggregate_select(self.expr, self.relationships)
query = select_aggregate(self.expr, self.relationships)
return query.correlate(self.class_).as_scalar()

View File

@@ -1 +1,2 @@
from .chained_join import chained_join
from .select_aggregate import select_aggregate

View File

@@ -1,4 +1,7 @@
def chained_join(*relationships):
"""
Return a chained Join object for given relationships.
"""
property_ = relationships[0].property
if property_.secondary is not None:

View File

@@ -0,0 +1,51 @@
import sqlalchemy as sa
def select_aggregate(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
The returned aggregate query can be used when updating denormalized column
value with query such as:
UPDATE table SET column = {aggregate_query}
WHERE {condition}
:param agg_expr:
an expression to be selected, for example sa.func.count('1')
:param relationships:
Sequence of relationships to be used for building the aggregate
query.
"""
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)

View File

@@ -5,7 +5,7 @@ class ThreeLevelDeepOneToOne(object):
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
category = sa.orm.relationship(
'Category',
uselist=False,
@@ -14,8 +14,12 @@ class ThreeLevelDeepOneToOne(object):
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column(
'_catalog_id',
sa.Integer,
sa.ForeignKey('catalog._id')
)
sub_category = sa.orm.relationship(
'SubCategory',
@@ -25,8 +29,12 @@ class ThreeLevelDeepOneToOne(object):
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column(
'_category_id',
sa.Integer,
sa.ForeignKey('category._id')
)
product = sa.orm.relationship(
'Product',
uselist=False,
@@ -35,11 +43,13 @@ class ThreeLevelDeepOneToOne(object):
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
'_sub_category_id',
sa.Integer,
sa.ForeignKey('sub_category._id')
)
self.Catalog = Catalog
@@ -52,14 +62,18 @@ class ThreeLevelDeepOneToMany(object):
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column(
'_catalog_id',
sa.Integer,
sa.ForeignKey('catalog._id')
)
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
@@ -67,8 +81,12 @@ class ThreeLevelDeepOneToMany(object):
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column(
'_category_id',
sa.Integer,
sa.ForeignKey('category._id')
)
products = sa.orm.relationship(
'Product',
backref='sub_category'
@@ -76,11 +94,13 @@ class ThreeLevelDeepOneToMany(object):
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
'_sub_category_id',
sa.Integer,
sa.ForeignKey('sub_category._id')
)
def __repr__(self):
@@ -97,8 +117,8 @@ class ThreeLevelDeepManyToMany(object):
catalog_category = sa.Table(
'catalog_category',
self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id'))
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id'))
)
category_subcategory = sa.Table(
@@ -107,12 +127,12 @@ class ThreeLevelDeepManyToMany(object):
sa.Column(
'category_id',
sa.Integer,
sa.ForeignKey('category.id')
sa.ForeignKey('category._id')
),
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
sa.ForeignKey('sub_category._id')
)
)
@@ -122,18 +142,18 @@ class ThreeLevelDeepManyToMany(object):
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
sa.ForeignKey('sub_category._id')
),
sa.Column(
'product_id',
sa.Integer,
sa.ForeignKey('product.id')
sa.ForeignKey('product._id')
)
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship(
'Category',
@@ -143,7 +163,7 @@ class ThreeLevelDeepManyToMany(object):
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
'SubCategory',
@@ -153,7 +173,7 @@ class ThreeLevelDeepManyToMany(object):
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
products = sa.orm.relationship(
'Product',
backref='sub_categories',
@@ -162,7 +182,7 @@ class ThreeLevelDeepManyToMany(object):
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
self.Catalog = Catalog

View File

@@ -14,7 +14,7 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
def test_simple_join(self):
assert str(chained_join(self.Catalog.categories)) == (
'catalog_category JOIN category ON '
'category.id = catalog_category.category_id'
'category._id = catalog_category.category_id'
)
def test_two_relations(self):
@@ -23,10 +23,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
self.Category.sub_categories
)
assert str(sql) == (
'catalog_category JOIN category ON category.id = '
'catalog_category JOIN category ON category._id = '
'catalog_category.category_id JOIN category_subcategory ON '
'category.id = category_subcategory.category_id JOIN sub_category '
'ON sub_category.id = category_subcategory.subcategory_id'
'category._id = category_subcategory.category_id JOIN '
'sub_category ON sub_category._id = '
'category_subcategory.subcategory_id'
)
def test_three_relations(self):
@@ -36,13 +37,13 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
self.SubCategory.products
)
assert str(sql) == (
'catalog_category JOIN category ON category.id = '
'catalog_category JOIN category ON category._id = '
'catalog_category.category_id JOIN category_subcategory ON '
'category.id = category_subcategory.category_id JOIN sub_category '
'ON sub_category.id = category_subcategory.subcategory_id JOIN '
'subcategory_product ON sub_category.id = '
'subcategory_product.subcategory_id JOIN product ON product.id = '
'subcategory_product.product_id'
'category._id = category_subcategory.category_id JOIN sub_category'
' ON sub_category._id = category_subcategory.subcategory_id JOIN '
'subcategory_product ON sub_category._id = '
'subcategory_product.subcategory_id JOIN product ON product._id ='
' subcategory_product.product_id'
)
@@ -59,8 +60,8 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
self.Category.sub_categories
)
assert str(sql) == (
'category JOIN sub_category ON category.id = '
'sub_category.category_id'
'category JOIN sub_category ON category._id = '
'sub_category._category_id'
)
def test_three_relations(self):
@@ -70,9 +71,9 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
self.SubCategory.products
)
assert str(sql) == (
'category JOIN sub_category ON category.id = '
'sub_category.category_id JOIN product ON sub_category.id = '
'product.sub_category_id'
'category JOIN sub_category ON category._id = '
'sub_category._category_id JOIN product ON sub_category._id = '
'product._sub_category_id'
)
@@ -89,8 +90,8 @@ class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase):
self.Category.sub_category
)
assert str(sql) == (
'category JOIN sub_category ON category.id = '
'sub_category.category_id'
'category JOIN sub_category ON category._id = '
'sub_category._category_id'
)
def test_three_relations(self):
@@ -100,7 +101,7 @@ class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase):
self.SubCategory.product
)
assert str(sql) == (
'category JOIN sub_category ON category.id = '
'sub_category.category_id JOIN product ON sub_category.id = '
'product.sub_category_id'
'category JOIN sub_category ON category._id = '
'sub_category._category_id JOIN product ON sub_category._id = '
'product._sub_category_id'
)

View File

@@ -0,0 +1,64 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import select_aggregate
from tests import TestCase
from tests.mixins import (
ThreeLevelDeepManyToMany,
ThreeLevelDeepOneToMany,
ThreeLevelDeepOneToOne,
)
def normalize(sql):
return ' '.join(sql.replace('\n', '').split())
class TestAggregateQueryForDeepToManyToMany(
ThreeLevelDeepManyToMany,
TestCase
):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
create_tables = False
def assert_sql(self, construct, sql):
assert normalize(str(construct)) == normalize(sql)
def build_update(self, *relationships):
expr = sa.func.count(sa.text('1'))
return (
self.Catalog.__table__.update().values(
_id=select_aggregate(
expr,
relationships
).correlate(self.Catalog)
)
)
def test_simple_join(self):
self.assert_sql(
self.build_update(self.Catalog.categories),
(
'''UPDATE catalog SET _id=(SELECT count(1) AS count_1
FROM category JOIN catalog_category ON category._id =
catalog_category.category_id WHERE catalog._id =
catalog_category.catalog_id)'''
)
)
def test_two_relations(self):
self.assert_sql(
self.build_update(
self.Category.sub_categories,
self.Catalog.categories,
),
(
'''UPDATE catalog SET _id=(SELECT count(1) AS count_1
FROM sub_category
JOIN category_subcategory
ON sub_category._id = category_subcategory.subcategory_id
JOIN category
ON category._id = category_subcategory.category_id
JOIN catalog_category
ON category._id = catalog_category.category_id
WHERE catalog._id = catalog_category.catalog_id)'''
)
)