From 1b6955e0e2e4148848563a6748d654560e4dfc34 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 25 Mar 2014 15:03:02 +0200 Subject: [PATCH] Fix sort_query regular property handling --- sqlalchemy_utils/functions/orm.py | 11 ++++++++-- sqlalchemy_utils/functions/sort_query.py | 27 +++++++++++++++--------- tests/__init__.py | 11 +++++++++- tests/test_sort_query.py | 10 +++++++++ 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 9b925ba..8811dd3 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -2,9 +2,10 @@ from functools import partial from operator import attrgetter import sqlalchemy as sa from sqlalchemy import inspect +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.util import AliasedInsp @@ -179,13 +180,19 @@ def get_query_entity_by_alias(query, alias): return entity -def attrs(expr): +def get_attrs(expr): if isinstance(expr, AliasedInsp): return expr.mapper.attrs else: return inspect(expr).attrs +def get_hybrid_properties(class_): + for prop in sa.inspect(class_).all_orm_descriptors: + if isinstance(prop, hybrid_property): + yield prop + + def get_expr_attr(expr, attr_name): if isinstance(expr, AliasedInsp): return getattr(expr.selectable.c, attr_name) diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index 281153f..bd1e1bc 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -1,12 +1,14 @@ +import sqlalchemy as sa from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.sql.expression import desc, asc, Label from sqlalchemy.orm.util import AliasedInsp from .orm import ( - attrs, - query_labels, - query_entities, + get_attrs, + get_expr_attr, + get_hybrid_properties, get_query_entity_by_alias, - get_expr_attr + query_entities, + query_labels, ) @@ -39,7 +41,7 @@ class QuerySorter(object): return self.query def order_by_attr(self, entity, attr): - properties = attrs(entity) + properties = get_attrs(entity) if attr in properties: property_ = properties[attr] if isinstance(property_, ColumnProperty): @@ -51,10 +53,15 @@ class QuerySorter(object): else: return - if isinstance(entity, AliasedInsp): - entity = entity.entity - if hasattr(entity, attr): - return getattr(entity, attr) + mapper = sa.inspect(entity) + + if isinstance(mapper, AliasedInsp): + mapper = mapper.mapper + entity = mapper.entity + + for prop in get_hybrid_properties(mapper): + if attr == prop.__name__: + return getattr(entity, attr) def parse_sort_arg(self, arg): if arg[0] == self.separator: @@ -136,7 +143,7 @@ def sort_query(query, *args, **kwargs): 3. Applying sort to custom calculated label >>> query = session.query( - ... Category, db.func.count(Article.id).label('articles') + ... Category, sa.func.count(Article.id).label('articles') ... ) >>> query = sort_query(query, 'articles') diff --git a/tests/__init__.py b/tests/__init__.py index 71c47be..551cbdf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,7 +3,7 @@ import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import declarative_base, synonym_for from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import ( @@ -81,6 +81,15 @@ class TestCase(object): .label('article_count') ) + @property + def name_alias(self): + return self.name + + @synonym_for('name') + @property + def name_synonym(self): + return self.name + class Article(self.Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index 3e14b9e..516e7c7 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -114,6 +114,16 @@ class TestSortQuery(TestCase): query = sort_query(query, 'articles') assert 'ORDER BY' not in str(query) + def test_regular_property(self): + query = self.session.query(self.Category) + query = sort_query(query, 'name_alias') + assert 'ORDER BY' not in str(query) + + def test_synonym_property(self): + query = self.session.query(self.Category) + query = sort_query(query, 'name_synonym') + assert 'ORDER BY name DESC' + def test_hybrid_property(self): query = self.session.query(self.Category) query = sort_query(query, 'articles_count')