Fixed sort query aliased hybrid property handling
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
|
||||
from .sort_query import sort_query
|
||||
from .sort_query import sort_query, QuerySorterException
|
||||
|
||||
|
||||
__all__ = (
|
||||
@@ -15,6 +15,7 @@ __all__ = (
|
||||
sort_query,
|
||||
with_backrefs,
|
||||
CompositePath,
|
||||
QuerySorterException
|
||||
)
|
||||
|
||||
|
||||
|
@@ -20,15 +20,6 @@ def sort_expression(expr, attr_name):
|
||||
return getattr(expr, attr_name)
|
||||
|
||||
|
||||
def get_entity(expr):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return expr.mapper.class_
|
||||
elif isinstance(expr, Mapper):
|
||||
return expr.class_
|
||||
else:
|
||||
return expr
|
||||
|
||||
|
||||
def matches_entity(alias, entity):
|
||||
if not alias:
|
||||
return True
|
||||
@@ -40,11 +31,16 @@ def matches_entity(alias, entity):
|
||||
return name == alias
|
||||
|
||||
|
||||
class QuerySorterException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QuerySorter(object):
|
||||
def __init__(self, separator='-'):
|
||||
def __init__(self, silent=True, separator='-'):
|
||||
self.entities = []
|
||||
self.labels = []
|
||||
self.separator = separator
|
||||
self.silent = silent
|
||||
|
||||
def inspect_labels_and_entities(self):
|
||||
for entity in self.query._entities:
|
||||
@@ -69,37 +65,41 @@ class QuerySorter(object):
|
||||
return self.query
|
||||
|
||||
sort = self.parse_sort_arg(sort)
|
||||
expr = None
|
||||
if sort['attr'] in self.labels:
|
||||
return self.query.order_by(sort['func'](sort['attr']))
|
||||
expr = sort['attr']
|
||||
else:
|
||||
for entity in self.entities:
|
||||
if not matches_entity(sort['entity'], entity):
|
||||
continue
|
||||
|
||||
for entity in self.entities:
|
||||
if not matches_entity(sort['entity'], entity):
|
||||
continue
|
||||
|
||||
return self.assign_entity_attr_order_by(entity, sort)
|
||||
expr = self.order_by_attr(entity, sort['attr'])
|
||||
|
||||
if expr is not None:
|
||||
return self.query.order_by(
|
||||
sort['func'](expr)
|
||||
)
|
||||
if not self.silent:
|
||||
raise QuerySorterException(
|
||||
"Could not sort query with expression '%s'" % sort['attr']
|
||||
)
|
||||
return self.query
|
||||
|
||||
def assign_entity_attr_order_by(self, entity, sort):
|
||||
def order_by_attr(self, entity, attr):
|
||||
properties = attrs(entity)
|
||||
if sort['attr'] in properties:
|
||||
property_ = properties[sort['attr']]
|
||||
if attr in properties:
|
||||
property_ = properties[attr]
|
||||
if isinstance(property_, ColumnProperty):
|
||||
if isinstance(property_.columns[0], Label):
|
||||
expr = property_.columns[0].name
|
||||
else:
|
||||
expr = sort_expression(entity, property_.key)
|
||||
return self.query.order_by(sort['func'](
|
||||
expr
|
||||
))
|
||||
return expr
|
||||
|
||||
# Check hybrid properties.
|
||||
entity = get_entity(entity)
|
||||
if hasattr(entity, sort['attr']):
|
||||
return self.query.order_by(
|
||||
sort['func'](getattr(entity, sort['attr']))
|
||||
)
|
||||
return self.query
|
||||
if isinstance(entity, AliasedInsp):
|
||||
entity = entity.entity
|
||||
if hasattr(entity, attr):
|
||||
return getattr(entity, attr)
|
||||
|
||||
def parse_sort_arg(self, arg):
|
||||
if arg[0] == self.separator:
|
||||
@@ -123,7 +123,7 @@ class QuerySorter(object):
|
||||
return self.query
|
||||
|
||||
|
||||
def sort_query(query, *args):
|
||||
def sort_query(query, *args, **kwargs):
|
||||
"""
|
||||
Applies an sql ORDER BY for given query. This function can be easily used
|
||||
with user-defined sorting.
|
||||
@@ -192,4 +192,4 @@ def sort_query(query, *args):
|
||||
is passed
|
||||
"""
|
||||
|
||||
return QuerySorter()(query, *args)
|
||||
return QuerySorter(**kwargs)(query, *args)
|
||||
|
@@ -1,5 +1,7 @@
|
||||
from pytest import raises
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import sort_query
|
||||
from sqlalchemy_utils.functions import QuerySorterException
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -22,6 +24,11 @@ class TestSortQuery(TestCase):
|
||||
sorted_query = sort_query(query, '-unknown')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_non_silent_mode(self):
|
||||
query = self.session.query(self.Article)
|
||||
with raises(QuerySorterException):
|
||||
sort_query(query, '-unknown', silent=False)
|
||||
|
||||
def test_calculated_value_ascending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
@@ -116,10 +123,10 @@ class TestSortQuery(TestCase):
|
||||
query = sort_query(query, '-category-articles_count')
|
||||
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
|
||||
|
||||
def test_aliased_relation_hybrid_property(self):
|
||||
def test_aliased_hybrid_property(self):
|
||||
alias = sa.orm.aliased(
|
||||
self.Category,
|
||||
name='category'
|
||||
name='categories'
|
||||
)
|
||||
query = (
|
||||
self.session.query(self.Article)
|
||||
@@ -128,5 +135,5 @@ class TestSortQuery(TestCase):
|
||||
sa.orm.contains_eager(self.Article.category, alias=alias)
|
||||
)
|
||||
)
|
||||
query = sort_query(query, '-category-articles_count')
|
||||
query = sort_query(query, '-categories-articles_count')
|
||||
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
|
||||
|
Reference in New Issue
Block a user