From 824ad2a6bb493def630d454c6439276ccde193fd Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Sat, 21 Sep 2013 15:18:53 +0300 Subject: [PATCH] Fixed sort query aliased hybrid property handling --- sqlalchemy_utils/functions/__init__.py | 3 +- sqlalchemy_utils/functions/sort_query.py | 62 ++++++++++++------------ tests/test_sort_query.py | 13 +++-- 3 files changed, 43 insertions(+), 35 deletions(-) diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 9a436be..9b191fc 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -7,7 +7,7 @@ from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import Query from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint from .batch_fetch import batch_fetch, with_backrefs, CompositePath -from .sort_query import sort_query +from .sort_query import sort_query, QuerySorterException __all__ = ( @@ -15,6 +15,7 @@ __all__ = ( sort_query, with_backrefs, CompositePath, + QuerySorterException ) diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index 8026c4b..6685b75 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -20,15 +20,6 @@ def sort_expression(expr, attr_name): return getattr(expr, attr_name) -def get_entity(expr): - if isinstance(expr, AliasedInsp): - return expr.mapper.class_ - elif isinstance(expr, Mapper): - return expr.class_ - else: - return expr - - def matches_entity(alias, entity): if not alias: return True @@ -40,11 +31,16 @@ def matches_entity(alias, entity): return name == alias +class QuerySorterException(Exception): + pass + + class QuerySorter(object): - def __init__(self, separator='-'): + def __init__(self, silent=True, separator='-'): self.entities = [] self.labels = [] self.separator = separator + self.silent = silent def inspect_labels_and_entities(self): for entity in self.query._entities: @@ -69,37 +65,41 @@ class QuerySorter(object): return self.query sort = self.parse_sort_arg(sort) + expr = None if sort['attr'] in self.labels: - return self.query.order_by(sort['func'](sort['attr'])) + expr = sort['attr'] + else: + for entity in self.entities: + if not matches_entity(sort['entity'], entity): + continue - for entity in self.entities: - if not matches_entity(sort['entity'], entity): - continue - - return self.assign_entity_attr_order_by(entity, sort) + expr = self.order_by_attr(entity, sort['attr']) + if expr is not None: + return self.query.order_by( + sort['func'](expr) + ) + if not self.silent: + raise QuerySorterException( + "Could not sort query with expression '%s'" % sort['attr'] + ) return self.query - def assign_entity_attr_order_by(self, entity, sort): + def order_by_attr(self, entity, attr): properties = attrs(entity) - if sort['attr'] in properties: - property_ = properties[sort['attr']] + if attr in properties: + property_ = properties[attr] if isinstance(property_, ColumnProperty): if isinstance(property_.columns[0], Label): expr = property_.columns[0].name else: expr = sort_expression(entity, property_.key) - return self.query.order_by(sort['func']( - expr - )) + return expr - # Check hybrid properties. - entity = get_entity(entity) - if hasattr(entity, sort['attr']): - return self.query.order_by( - sort['func'](getattr(entity, sort['attr'])) - ) - return self.query + if isinstance(entity, AliasedInsp): + entity = entity.entity + if hasattr(entity, attr): + return getattr(entity, attr) def parse_sort_arg(self, arg): if arg[0] == self.separator: @@ -123,7 +123,7 @@ class QuerySorter(object): return self.query -def sort_query(query, *args): +def sort_query(query, *args, **kwargs): """ Applies an sql ORDER BY for given query. This function can be easily used with user-defined sorting. @@ -192,4 +192,4 @@ def sort_query(query, *args): is passed """ - return QuerySorter()(query, *args) + return QuerySorter(**kwargs)(query, *args) diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index b12d061..ec92d04 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -1,5 +1,7 @@ +from pytest import raises import sqlalchemy as sa from sqlalchemy_utils import sort_query +from sqlalchemy_utils.functions import QuerySorterException from tests import TestCase @@ -22,6 +24,11 @@ class TestSortQuery(TestCase): sorted_query = sort_query(query, '-unknown') assert query == sorted_query + def test_non_silent_mode(self): + query = self.session.query(self.Article) + with raises(QuerySorterException): + sort_query(query, '-unknown', silent=False) + def test_calculated_value_ascending(self): query = self.session.query( self.Category, sa.func.count(self.Article.id).label('articles') @@ -116,10 +123,10 @@ class TestSortQuery(TestCase): query = sort_query(query, '-category-articles_count') assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query) - def test_aliased_relation_hybrid_property(self): + def test_aliased_hybrid_property(self): alias = sa.orm.aliased( self.Category, - name='category' + name='categories' ) query = ( self.session.query(self.Article) @@ -128,5 +135,5 @@ class TestSortQuery(TestCase): sa.orm.contains_eager(self.Article.category, alias=alias) ) ) - query = sort_query(query, '-category-articles_count') + query = sort_query(query, '-categories-articles_count') assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)