diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 69e1087..a5f8caf 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -374,7 +374,7 @@ from .functions.orm import get_column_key from .relationships import ( chained_join, path_to_relationships, - select_aggregate + select_correlated_expression ) @@ -447,9 +447,14 @@ class AggregatedValue(object): @property def aggregate_query(self): - query = select_aggregate(self.expr, self.relationships) + query = select_correlated_expression( + self.class_, + self.expr, + self.path, + self.relationships[0].mapper.class_ + ) - return query.correlate(self.class_).as_scalar() + return query.as_scalar() def update_query(self, objects): table = self.class_.__table__ diff --git a/sqlalchemy_utils/listeners.py b/sqlalchemy_utils/listeners.py index ce6d774..691578b 100644 --- a/sqlalchemy_utils/listeners.py +++ b/sqlalchemy_utils/listeners.py @@ -23,7 +23,7 @@ def coercion_listener(mapper, class_): def instant_defaults_listener(target, args, kwargs): for key, column in sa.inspect(target.__class__).columns.items(): - if column.default is not None: + if hasattr(column, 'default') and column.default is not None: if callable(column.default.arg): setattr(target, key, column.default.arg(target)) else: diff --git a/sqlalchemy_utils/relationships/__init__.py b/sqlalchemy_utils/relationships/__init__.py index 8069c65..1bbb85d 100644 --- a/sqlalchemy_utils/relationships/__init__.py +++ b/sqlalchemy_utils/relationships/__init__.py @@ -2,7 +2,6 @@ import sqlalchemy as sa from sqlalchemy.sql.util import ClauseAdapter from .chained_join import chained_join # noqa -from .select_aggregate import select_aggregate # noqa def path_to_relationships(path, cls): diff --git a/sqlalchemy_utils/relationships/select_aggregate.py b/sqlalchemy_utils/relationships/select_aggregate.py deleted file mode 100644 index 379b33f..0000000 --- a/sqlalchemy_utils/relationships/select_aggregate.py +++ /dev/null @@ -1,51 +0,0 @@ -import sqlalchemy as sa - - -def select_aggregate(agg_expr, relationships): - """ - Return a subquery for fetching an aggregate value of given aggregate - expression and given sequence of relationships. - - The returned aggregate query can be used when updating denormalized column - value with query such as: - - UPDATE table SET column = {aggregate_query} - WHERE {condition} - - :param agg_expr: - an expression to be selected, for example sa.func.count('1') - :param relationships: - Sequence of relationships to be used for building the aggregate - query. - """ - from_ = relationships[0].mapper.class_.__table__ - for relationship in relationships[0:-1]: - property_ = relationship.property - if property_.secondary is not None: - from_ = from_.join( - property_.secondary, - property_.secondaryjoin - ) - - from_ = ( - from_ - .join( - property_.parent.class_, - property_.primaryjoin - ) - ) - - prop = relationships[-1].property - condition = prop.primaryjoin - if prop.secondary is not None: - from_ = from_.join( - prop.secondary, - prop.secondaryjoin - ) - - query = sa.select( - [agg_expr], - from_obj=[from_] - ) - - return query.where(condition) diff --git a/tests/relationships/test_select_aggregate.py b/tests/relationships/test_select_aggregate.py deleted file mode 100644 index f4fe687..0000000 --- a/tests/relationships/test_select_aggregate.py +++ /dev/null @@ -1,61 +0,0 @@ -import sqlalchemy as sa - -from sqlalchemy_utils.aggregates import select_aggregate -from tests import TestCase -from tests.mixins import ThreeLevelDeepManyToMany - - -def normalize(sql): - return ' '.join(sql.replace('\n', '').split()) - - -class TestAggregateQueryForDeepToManyToMany( - ThreeLevelDeepManyToMany, - TestCase -): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - create_tables = False - - def assert_sql(self, construct, sql): - assert normalize(str(construct)) == normalize(sql) - - def build_update(self, *relationships): - expr = sa.func.count(sa.text('1')) - return ( - self.Catalog.__table__.update().values( - _id=select_aggregate( - expr, - relationships - ).correlate(self.Catalog) - ) - ) - - def test_simple_join(self): - self.assert_sql( - self.build_update(self.Catalog.categories), - ( - '''UPDATE catalog SET _id=(SELECT count(1) AS count_1 - FROM category JOIN catalog_category ON category._id = - catalog_category.category_id WHERE catalog._id = - catalog_category.catalog_id)''' - ) - ) - - def test_two_relations(self): - self.assert_sql( - self.build_update( - self.Category.sub_categories, - self.Catalog.categories, - ), - ( - '''UPDATE catalog SET _id=(SELECT count(1) AS count_1 - FROM sub_category - JOIN category_subcategory - ON sub_category._id = category_subcategory.subcategory_id - JOIN category - ON category._id = category_subcategory.category_id - JOIN catalog_category - ON category._id = catalog_category.category_id - WHERE catalog._id = catalog_category.catalog_id)''' - ) - )