Fix sort_query regular property handling

This commit is contained in:
Konsta Vesterinen
2014-03-25 15:03:02 +02:00
parent d5ff56e988
commit 1b6955e0e2
4 changed files with 46 additions and 13 deletions

View File

@@ -2,9 +2,10 @@ from functools import partial
from operator import attrgetter from operator import attrgetter
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.orm.util import AliasedInsp
@@ -179,13 +180,19 @@ def get_query_entity_by_alias(query, alias):
return entity return entity
def attrs(expr): def get_attrs(expr):
if isinstance(expr, AliasedInsp): if isinstance(expr, AliasedInsp):
return expr.mapper.attrs return expr.mapper.attrs
else: else:
return inspect(expr).attrs return inspect(expr).attrs
def get_hybrid_properties(class_):
for prop in sa.inspect(class_).all_orm_descriptors:
if isinstance(prop, hybrid_property):
yield prop
def get_expr_attr(expr, attr_name): def get_expr_attr(expr, attr_name):
if isinstance(expr, AliasedInsp): if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name) return getattr(expr.selectable.c, attr_name)

View File

@@ -1,12 +1,14 @@
import sqlalchemy as sa
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql.expression import desc, asc, Label from sqlalchemy.sql.expression import desc, asc, Label
from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.orm.util import AliasedInsp
from .orm import ( from .orm import (
attrs, get_attrs,
query_labels, get_expr_attr,
query_entities, get_hybrid_properties,
get_query_entity_by_alias, get_query_entity_by_alias,
get_expr_attr query_entities,
query_labels,
) )
@@ -39,7 +41,7 @@ class QuerySorter(object):
return self.query return self.query
def order_by_attr(self, entity, attr): def order_by_attr(self, entity, attr):
properties = attrs(entity) properties = get_attrs(entity)
if attr in properties: if attr in properties:
property_ = properties[attr] property_ = properties[attr]
if isinstance(property_, ColumnProperty): if isinstance(property_, ColumnProperty):
@@ -51,9 +53,14 @@ class QuerySorter(object):
else: else:
return return
if isinstance(entity, AliasedInsp): mapper = sa.inspect(entity)
entity = entity.entity
if hasattr(entity, attr): if isinstance(mapper, AliasedInsp):
mapper = mapper.mapper
entity = mapper.entity
for prop in get_hybrid_properties(mapper):
if attr == prop.__name__:
return getattr(entity, attr) return getattr(entity, attr)
def parse_sort_arg(self, arg): def parse_sort_arg(self, arg):
@@ -136,7 +143,7 @@ def sort_query(query, *args, **kwargs):
3. Applying sort to custom calculated label 3. Applying sort to custom calculated label
>>> query = session.query( >>> query = session.query(
... Category, db.func.count(Article.id).label('articles') ... Category, sa.func.count(Article.id).label('articles')
... ) ... )
>>> query = sort_query(query, 'articles') >>> query = sort_query(query, 'articles')

View File

@@ -3,7 +3,7 @@ import sqlalchemy as sa
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import ( from sqlalchemy_utils import (
@@ -81,6 +81,15 @@ class TestCase(object):
.label('article_count') .label('article_count')
) )
@property
def name_alias(self):
return self.name
@synonym_for('name')
@property
def name_synonym(self):
return self.name
class Article(self.Base): class Article(self.Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)

View File

@@ -114,6 +114,16 @@ class TestSortQuery(TestCase):
query = sort_query(query, 'articles') query = sort_query(query, 'articles')
assert 'ORDER BY' not in str(query) assert 'ORDER BY' not in str(query)
def test_regular_property(self):
query = self.session.query(self.Category)
query = sort_query(query, 'name_alias')
assert 'ORDER BY' not in str(query)
def test_synonym_property(self):
query = self.session.query(self.Category)
query = sort_query(query, 'name_synonym')
assert 'ORDER BY name DESC'
def test_hybrid_property(self): def test_hybrid_property(self):
query = self.session.query(self.Category) query = self.session.query(self.Category)
query = sort_query(query, 'articles_count') query = sort_query(query, 'articles_count')