Refactor aggregates

* Make aggregates use select_correlated_expression
* Remove select_aggregate
This commit is contained in:
Konsta Vesterinen
2015-07-16 12:11:23 +03:00
parent 74392872e6
commit 1368207b74
5 changed files with 9 additions and 117 deletions

View File

@@ -374,7 +374,7 @@ from .functions.orm import get_column_key
from .relationships import (
chained_join,
path_to_relationships,
select_aggregate
select_correlated_expression
)
@@ -447,9 +447,14 @@ class AggregatedValue(object):
@property
def aggregate_query(self):
query = select_aggregate(self.expr, self.relationships)
query = select_correlated_expression(
self.class_,
self.expr,
self.path,
self.relationships[0].mapper.class_
)
return query.correlate(self.class_).as_scalar()
return query.as_scalar()
def update_query(self, objects):
table = self.class_.__table__

View File

@@ -23,7 +23,7 @@ def coercion_listener(mapper, class_):
def instant_defaults_listener(target, args, kwargs):
for key, column in sa.inspect(target.__class__).columns.items():
if column.default is not None:
if hasattr(column, 'default') and column.default is not None:
if callable(column.default.arg):
setattr(target, key, column.default.arg(target))
else:

View File

@@ -2,7 +2,6 @@ import sqlalchemy as sa
from sqlalchemy.sql.util import ClauseAdapter
from .chained_join import chained_join # noqa
from .select_aggregate import select_aggregate # noqa
def path_to_relationships(path, cls):

View File

@@ -1,51 +0,0 @@
import sqlalchemy as sa
def select_aggregate(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
The returned aggregate query can be used when updating denormalized column
value with query such as:
UPDATE table SET column = {aggregate_query}
WHERE {condition}
:param agg_expr:
an expression to be selected, for example sa.func.count('1')
:param relationships:
Sequence of relationships to be used for building the aggregate
query.
"""
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)

View File

@@ -1,61 +0,0 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import select_aggregate
from tests import TestCase
from tests.mixins import ThreeLevelDeepManyToMany
def normalize(sql):
return ' '.join(sql.replace('\n', '').split())
class TestAggregateQueryForDeepToManyToMany(
ThreeLevelDeepManyToMany,
TestCase
):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
create_tables = False
def assert_sql(self, construct, sql):
assert normalize(str(construct)) == normalize(sql)
def build_update(self, *relationships):
expr = sa.func.count(sa.text('1'))
return (
self.Catalog.__table__.update().values(
_id=select_aggregate(
expr,
relationships
).correlate(self.Catalog)
)
)
def test_simple_join(self):
self.assert_sql(
self.build_update(self.Catalog.categories),
(
'''UPDATE catalog SET _id=(SELECT count(1) AS count_1
FROM category JOIN catalog_category ON category._id =
catalog_category.category_id WHERE catalog._id =
catalog_category.catalog_id)'''
)
)
def test_two_relations(self):
self.assert_sql(
self.build_update(
self.Category.sub_categories,
self.Catalog.categories,
),
(
'''UPDATE catalog SET _id=(SELECT count(1) AS count_1
FROM sub_category
JOIN category_subcategory
ON sub_category._id = category_subcategory.subcategory_id
JOIN category
ON category._id = category_subcategory.category_id
JOIN catalog_category
ON category._id = catalog_category.category_id
WHERE catalog._id = catalog_category.catalog_id)'''
)
)