Fix sort_query regular property handling
This commit is contained in:
@@ -2,9 +2,10 @@ from functools import partial
|
||||
from operator import attrgetter
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
|
||||
|
||||
@@ -179,13 +180,19 @@ def get_query_entity_by_alias(query, alias):
|
||||
return entity
|
||||
|
||||
|
||||
def attrs(expr):
|
||||
def get_attrs(expr):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return expr.mapper.attrs
|
||||
else:
|
||||
return inspect(expr).attrs
|
||||
|
||||
|
||||
def get_hybrid_properties(class_):
|
||||
for prop in sa.inspect(class_).all_orm_descriptors:
|
||||
if isinstance(prop, hybrid_property):
|
||||
yield prop
|
||||
|
||||
|
||||
def get_expr_attr(expr, attr_name):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return getattr(expr.selectable.c, attr_name)
|
||||
|
@@ -1,12 +1,14 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.sql.expression import desc, asc, Label
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
from .orm import (
|
||||
attrs,
|
||||
query_labels,
|
||||
query_entities,
|
||||
get_attrs,
|
||||
get_expr_attr,
|
||||
get_hybrid_properties,
|
||||
get_query_entity_by_alias,
|
||||
get_expr_attr
|
||||
query_entities,
|
||||
query_labels,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,7 +41,7 @@ class QuerySorter(object):
|
||||
return self.query
|
||||
|
||||
def order_by_attr(self, entity, attr):
|
||||
properties = attrs(entity)
|
||||
properties = get_attrs(entity)
|
||||
if attr in properties:
|
||||
property_ = properties[attr]
|
||||
if isinstance(property_, ColumnProperty):
|
||||
@@ -51,10 +53,15 @@ class QuerySorter(object):
|
||||
else:
|
||||
return
|
||||
|
||||
if isinstance(entity, AliasedInsp):
|
||||
entity = entity.entity
|
||||
if hasattr(entity, attr):
|
||||
return getattr(entity, attr)
|
||||
mapper = sa.inspect(entity)
|
||||
|
||||
if isinstance(mapper, AliasedInsp):
|
||||
mapper = mapper.mapper
|
||||
entity = mapper.entity
|
||||
|
||||
for prop in get_hybrid_properties(mapper):
|
||||
if attr == prop.__name__:
|
||||
return getattr(entity, attr)
|
||||
|
||||
def parse_sort_arg(self, arg):
|
||||
if arg[0] == self.separator:
|
||||
@@ -136,7 +143,7 @@ def sort_query(query, *args, **kwargs):
|
||||
3. Applying sort to custom calculated label
|
||||
|
||||
>>> query = session.query(
|
||||
... Category, db.func.count(Article.id).label('articles')
|
||||
... Category, sa.func.count(Article.id).label('articles')
|
||||
... )
|
||||
>>> query = sort_query(query, 'articles')
|
||||
|
||||
|
@@ -3,7 +3,7 @@ import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.declarative import declarative_base, synonym_for
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
|
||||
from sqlalchemy_utils import (
|
||||
@@ -81,6 +81,15 @@ class TestCase(object):
|
||||
.label('article_count')
|
||||
)
|
||||
|
||||
@property
|
||||
def name_alias(self):
|
||||
return self.name
|
||||
|
||||
@synonym_for('name')
|
||||
@property
|
||||
def name_synonym(self):
|
||||
return self.name
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
@@ -114,6 +114,16 @@ class TestSortQuery(TestCase):
|
||||
query = sort_query(query, 'articles')
|
||||
assert 'ORDER BY' not in str(query)
|
||||
|
||||
def test_regular_property(self):
|
||||
query = self.session.query(self.Category)
|
||||
query = sort_query(query, 'name_alias')
|
||||
assert 'ORDER BY' not in str(query)
|
||||
|
||||
def test_synonym_property(self):
|
||||
query = self.session.query(self.Category)
|
||||
query = sort_query(query, 'name_synonym')
|
||||
assert 'ORDER BY name DESC'
|
||||
|
||||
def test_hybrid_property(self):
|
||||
query = self.session.query(self.Category)
|
||||
query = sort_query(query, 'articles_count')
|
||||
|
Reference in New Issue
Block a user