From af7aa08c64a40a1a104522bfc356db00ad0a5f2e Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 19 Sep 2013 15:12:32 +0300 Subject: [PATCH] Added support for relation hybrid property sorting for sort_query --- sqlalchemy_utils/functions/sort_query.py | 66 ++++++++++++++---------- tests/test_sort_query.py | 42 ++++++++++----- 2 files changed, 69 insertions(+), 39 deletions(-) diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index 5af1ed9..ac4e2b3 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -1,3 +1,4 @@ +from sqlalchemy import inspect from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import _ColumnEntity @@ -5,6 +6,31 @@ from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.sql.expression import desc, asc, Label +def attrs(expr): + if isinstance(expr, AliasedInsp): + return expr.mapper.attrs + else: + return inspect(expr).attrs + + +def sort_expression(expr, attr_name): + if isinstance(expr, AliasedInsp): + return getattr(expr.selectable.c, attr_name) + else: + return getattr(expr, attr_name) + + +def matches_entity(alias, entity): + if not alias: + return True + if isinstance(entity, AliasedInsp): + name = entity.name + else: + name = entity.__table__.name + + return name == alias + + class QuerySorter(object): def __init__(self, separator='-'): self.entities = [] @@ -38,43 +64,31 @@ class QuerySorter(object): return self.query.order_by(sort['func'](sort['attr'])) for entity in self.entities: - if isinstance(entity, AliasedInsp): - if sort['entity'] and entity.name != sort['entity']: - continue + if not matches_entity(sort['entity'], entity): + continue - selectable = entity.selectable - - if sort['attr'] in selectable.c: - attr = selectable.c[sort['attr']] - return self.query.order_by(sort['func'](attr)) - else: - table = entity.__table__ - if sort['entity'] and table.name != sort['entity']: - continue - return self.assign_entity_attr_order_by(entity, sort) + return self.assign_entity_attr_order_by(entity, sort) return self.query def assign_entity_attr_order_by(self, entity, sort): - if sort['attr'] in entity.__mapper__.class_manager.keys(): - try: - attr = getattr(entity, sort['attr']) - except AttributeError: - pass - else: - property_ = attr.property - if isinstance(property_, ColumnProperty): - if isinstance(attr.property.columns[0], Label): - attr = attr.property.columns[0].name - - return self.query.order_by(sort['func'](attr)) + properties = attrs(entity) + if sort['attr'] in properties: + property_ = properties[sort['attr']] + if isinstance(property_, ColumnProperty): + if isinstance(property_.columns[0], Label): + expr = property_.columns[0].name + else: + expr = sort_expression(entity, property_.key) + return self.query.order_by(sort['func']( + expr + )) # Check hybrid properties. if hasattr(entity, sort['attr']): return self.query.order_by( sort['func'](getattr(entity, sort['attr'])) ) - return self.query def parse_sort_arg(self, arg): diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index b9520df..7637a3f 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -9,11 +9,11 @@ class TestSortQuery(TestCase): sorted_query = sort_query(query, '') assert query == sorted_query - def test_sort_by_column_ascending(self): + def test_column_ascending(self): query = sort_query(self.session.query(self.Article), 'name') assert 'ORDER BY article.name ASC' in str(query) - def test_sort_by_column_descending(self): + def test_column_descending(self): query = sort_query(self.session.query(self.Article), '-name') assert 'ORDER BY article.name DESC' in str(query) @@ -22,21 +22,21 @@ class TestSortQuery(TestCase): sorted_query = sort_query(query, '-unknown') assert query == sorted_query - def test_sort_by_calculated_value_ascending(self): + def test_calculated_value_ascending(self): query = self.session.query( self.Category, sa.func.count(self.Article.id).label('articles') ) query = sort_query(query, 'articles') assert 'ORDER BY articles ASC' in str(query) - def test_sort_by_calculated_value_descending(self): + def test_calculated_value_descending(self): query = self.session.query( self.Category, sa.func.count(self.Article.id).label('articles') ) query = sort_query(query, '-articles') assert 'ORDER BY articles DESC' in str(query) - def test_sort_by_subqueried_scalar(self): + def test_subqueried_scalar(self): article_count = ( sa.sql.select( [sa.func.count(self.Article.id)], @@ -52,7 +52,7 @@ class TestSortQuery(TestCase): query = sort_query(query, '-articles') assert 'ORDER BY articles DESC' in str(query) - def test_sort_by_aliased_joined_entity(self): + def test_aliased_joined_entity(self): alias = sa.orm.aliased(self.Category, name='categories') query = self.session.query( self.Article @@ -62,17 +62,17 @@ class TestSortQuery(TestCase): query = sort_query(query, '-categories-name') assert 'ORDER BY categories.name DESC' in str(query) - def test_sort_by_joined_table_column(self): + def test_joined_table_column(self): query = self.session.query(self.Article).join(self.Article.category) sorted_query = sort_query(query, 'category-name') assert 'category.name ASC' in str(sorted_query) - def test_sort_by_multiple_columns(self): + def test_multiple_columns(self): query = self.session.query(self.Article) sorted_query = sort_query(query, 'name', 'id') assert 'article.name ASC, article.id ASC' in str(sorted_query) - def test_sort_by_column_property(self): + def test_column_property(self): self.Category.article_count = sa.orm.column_property( sa.select([sa.func.count(self.Article.id)]) .where(self.Article.category_id == self.Category.id) @@ -83,7 +83,7 @@ class TestSortQuery(TestCase): sorted_query = sort_query(query, 'article_count') assert 'article_count ASC' in str(sorted_query) - def test_sort_by_column_property_descending(self): + def test_column_property_descending(self): self.Category.article_count = sa.orm.column_property( sa.select([sa.func.count(self.Article.id)]) .where(self.Article.category_id == self.Category.id) @@ -94,12 +94,12 @@ class TestSortQuery(TestCase): sorted_query = sort_query(query, '-article_count') assert 'article_count DESC' in str(sorted_query) - def test_sort_by_hybrid_property(self): + def test_hybrid_property(self): query = self.session.query(self.Category) query = sort_query(query, 'articles_count') assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query) - def test_sort_by_hybrid_property_descending(self): + def test_hybrid_property_descending(self): query = self.session.query(self.Category) query = sort_query(query, '-articles_count') assert ( @@ -107,7 +107,7 @@ class TestSortQuery(TestCase): ) in str(query) assert ' DESC' in str(query) - def test_sort_by_related_hybrid_property(self): + def test_relation_hybrid_property(self): query = ( self.session.query(self.Article) .join(self.Article.category) @@ -115,3 +115,19 @@ class TestSortQuery(TestCase): ) query = sort_query(query, '-category-articles_count') assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query) + + def test_aliased_relation_hybrid_property(self): + alias = sa.orm.aliased( + self.Category, + name='category' + ) + query = ( + self.session.query(self.Article) + .outerjoin(alias, self.Article.category) + .options( + sa.orm.contains_eager(self.Article.category, alias=alias) + ) + ) + query = sort_query(query, '-category-articles_count') + print query + #assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)