Fixed sort query aliased hybrid property handling
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy.orm.properties import ColumnProperty
|
|||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||||
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
|
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
|
||||||
from .sort_query import sort_query
|
from .sort_query import sort_query, QuerySorterException
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
@@ -15,6 +15,7 @@ __all__ = (
|
|||||||
sort_query,
|
sort_query,
|
||||||
with_backrefs,
|
with_backrefs,
|
||||||
CompositePath,
|
CompositePath,
|
||||||
|
QuerySorterException
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -20,15 +20,6 @@ def sort_expression(expr, attr_name):
|
|||||||
return getattr(expr, attr_name)
|
return getattr(expr, attr_name)
|
||||||
|
|
||||||
|
|
||||||
def get_entity(expr):
|
|
||||||
if isinstance(expr, AliasedInsp):
|
|
||||||
return expr.mapper.class_
|
|
||||||
elif isinstance(expr, Mapper):
|
|
||||||
return expr.class_
|
|
||||||
else:
|
|
||||||
return expr
|
|
||||||
|
|
||||||
|
|
||||||
def matches_entity(alias, entity):
|
def matches_entity(alias, entity):
|
||||||
if not alias:
|
if not alias:
|
||||||
return True
|
return True
|
||||||
@@ -40,11 +31,16 @@ def matches_entity(alias, entity):
|
|||||||
return name == alias
|
return name == alias
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySorterException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class QuerySorter(object):
|
class QuerySorter(object):
|
||||||
def __init__(self, separator='-'):
|
def __init__(self, silent=True, separator='-'):
|
||||||
self.entities = []
|
self.entities = []
|
||||||
self.labels = []
|
self.labels = []
|
||||||
self.separator = separator
|
self.separator = separator
|
||||||
|
self.silent = silent
|
||||||
|
|
||||||
def inspect_labels_and_entities(self):
|
def inspect_labels_and_entities(self):
|
||||||
for entity in self.query._entities:
|
for entity in self.query._entities:
|
||||||
@@ -69,37 +65,41 @@ class QuerySorter(object):
|
|||||||
return self.query
|
return self.query
|
||||||
|
|
||||||
sort = self.parse_sort_arg(sort)
|
sort = self.parse_sort_arg(sort)
|
||||||
|
expr = None
|
||||||
if sort['attr'] in self.labels:
|
if sort['attr'] in self.labels:
|
||||||
return self.query.order_by(sort['func'](sort['attr']))
|
expr = sort['attr']
|
||||||
|
else:
|
||||||
|
for entity in self.entities:
|
||||||
|
if not matches_entity(sort['entity'], entity):
|
||||||
|
continue
|
||||||
|
|
||||||
for entity in self.entities:
|
expr = self.order_by_attr(entity, sort['attr'])
|
||||||
if not matches_entity(sort['entity'], entity):
|
|
||||||
continue
|
|
||||||
|
|
||||||
return self.assign_entity_attr_order_by(entity, sort)
|
|
||||||
|
|
||||||
|
if expr is not None:
|
||||||
|
return self.query.order_by(
|
||||||
|
sort['func'](expr)
|
||||||
|
)
|
||||||
|
if not self.silent:
|
||||||
|
raise QuerySorterException(
|
||||||
|
"Could not sort query with expression '%s'" % sort['attr']
|
||||||
|
)
|
||||||
return self.query
|
return self.query
|
||||||
|
|
||||||
def assign_entity_attr_order_by(self, entity, sort):
|
def order_by_attr(self, entity, attr):
|
||||||
properties = attrs(entity)
|
properties = attrs(entity)
|
||||||
if sort['attr'] in properties:
|
if attr in properties:
|
||||||
property_ = properties[sort['attr']]
|
property_ = properties[attr]
|
||||||
if isinstance(property_, ColumnProperty):
|
if isinstance(property_, ColumnProperty):
|
||||||
if isinstance(property_.columns[0], Label):
|
if isinstance(property_.columns[0], Label):
|
||||||
expr = property_.columns[0].name
|
expr = property_.columns[0].name
|
||||||
else:
|
else:
|
||||||
expr = sort_expression(entity, property_.key)
|
expr = sort_expression(entity, property_.key)
|
||||||
return self.query.order_by(sort['func'](
|
return expr
|
||||||
expr
|
|
||||||
))
|
|
||||||
|
|
||||||
# Check hybrid properties.
|
if isinstance(entity, AliasedInsp):
|
||||||
entity = get_entity(entity)
|
entity = entity.entity
|
||||||
if hasattr(entity, sort['attr']):
|
if hasattr(entity, attr):
|
||||||
return self.query.order_by(
|
return getattr(entity, attr)
|
||||||
sort['func'](getattr(entity, sort['attr']))
|
|
||||||
)
|
|
||||||
return self.query
|
|
||||||
|
|
||||||
def parse_sort_arg(self, arg):
|
def parse_sort_arg(self, arg):
|
||||||
if arg[0] == self.separator:
|
if arg[0] == self.separator:
|
||||||
@@ -123,7 +123,7 @@ class QuerySorter(object):
|
|||||||
return self.query
|
return self.query
|
||||||
|
|
||||||
|
|
||||||
def sort_query(query, *args):
|
def sort_query(query, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Applies an sql ORDER BY for given query. This function can be easily used
|
Applies an sql ORDER BY for given query. This function can be easily used
|
||||||
with user-defined sorting.
|
with user-defined sorting.
|
||||||
@@ -192,4 +192,4 @@ def sort_query(query, *args):
|
|||||||
is passed
|
is passed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return QuerySorter()(query, *args)
|
return QuerySorter(**kwargs)(query, *args)
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
|
from pytest import raises
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils import sort_query
|
from sqlalchemy_utils import sort_query
|
||||||
|
from sqlalchemy_utils.functions import QuerySorterException
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -22,6 +24,11 @@ class TestSortQuery(TestCase):
|
|||||||
sorted_query = sort_query(query, '-unknown')
|
sorted_query = sort_query(query, '-unknown')
|
||||||
assert query == sorted_query
|
assert query == sorted_query
|
||||||
|
|
||||||
|
def test_non_silent_mode(self):
|
||||||
|
query = self.session.query(self.Article)
|
||||||
|
with raises(QuerySorterException):
|
||||||
|
sort_query(query, '-unknown', silent=False)
|
||||||
|
|
||||||
def test_calculated_value_ascending(self):
|
def test_calculated_value_ascending(self):
|
||||||
query = self.session.query(
|
query = self.session.query(
|
||||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||||
@@ -116,10 +123,10 @@ class TestSortQuery(TestCase):
|
|||||||
query = sort_query(query, '-category-articles_count')
|
query = sort_query(query, '-category-articles_count')
|
||||||
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
|
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
|
||||||
|
|
||||||
def test_aliased_relation_hybrid_property(self):
|
def test_aliased_hybrid_property(self):
|
||||||
alias = sa.orm.aliased(
|
alias = sa.orm.aliased(
|
||||||
self.Category,
|
self.Category,
|
||||||
name='category'
|
name='categories'
|
||||||
)
|
)
|
||||||
query = (
|
query = (
|
||||||
self.session.query(self.Article)
|
self.session.query(self.Article)
|
||||||
@@ -128,5 +135,5 @@ class TestSortQuery(TestCase):
|
|||||||
sa.orm.contains_eager(self.Article.category, alias=alias)
|
sa.orm.contains_eager(self.Article.category, alias=alias)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
query = sort_query(query, '-category-articles_count')
|
query = sort_query(query, '-categories-articles_count')
|
||||||
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
|
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
|
||||||
|
Reference in New Issue
Block a user