Added support for relation hybrid property sorting for sort_query

This commit is contained in:
Konsta Vesterinen
2013-09-19 15:12:32 +03:00
parent 747ff6df99
commit af7aa08c64
2 changed files with 69 additions and 39 deletions

View File

@@ -1,3 +1,4 @@
from sqlalchemy import inspect
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import _ColumnEntity
@@ -5,6 +6,31 @@ from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy.sql.expression import desc, asc, Label
def attrs(expr):
if isinstance(expr, AliasedInsp):
return expr.mapper.attrs
else:
return inspect(expr).attrs
def sort_expression(expr, attr_name):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name)
else:
return getattr(expr, attr_name)
def matches_entity(alias, entity):
if not alias:
return True
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
return name == alias
class QuerySorter(object):
def __init__(self, separator='-'):
self.entities = []
@@ -38,43 +64,31 @@ class QuerySorter(object):
return self.query.order_by(sort['func'](sort['attr']))
for entity in self.entities:
if isinstance(entity, AliasedInsp):
if sort['entity'] and entity.name != sort['entity']:
continue
if not matches_entity(sort['entity'], entity):
continue
selectable = entity.selectable
if sort['attr'] in selectable.c:
attr = selectable.c[sort['attr']]
return self.query.order_by(sort['func'](attr))
else:
table = entity.__table__
if sort['entity'] and table.name != sort['entity']:
continue
return self.assign_entity_attr_order_by(entity, sort)
return self.assign_entity_attr_order_by(entity, sort)
return self.query
def assign_entity_attr_order_by(self, entity, sort):
if sort['attr'] in entity.__mapper__.class_manager.keys():
try:
attr = getattr(entity, sort['attr'])
except AttributeError:
pass
else:
property_ = attr.property
if isinstance(property_, ColumnProperty):
if isinstance(attr.property.columns[0], Label):
attr = attr.property.columns[0].name
return self.query.order_by(sort['func'](attr))
properties = attrs(entity)
if sort['attr'] in properties:
property_ = properties[sort['attr']]
if isinstance(property_, ColumnProperty):
if isinstance(property_.columns[0], Label):
expr = property_.columns[0].name
else:
expr = sort_expression(entity, property_.key)
return self.query.order_by(sort['func'](
expr
))
# Check hybrid properties.
if hasattr(entity, sort['attr']):
return self.query.order_by(
sort['func'](getattr(entity, sort['attr']))
)
return self.query
def parse_sort_arg(self, arg):

View File

@@ -9,11 +9,11 @@ class TestSortQuery(TestCase):
sorted_query = sort_query(query, '')
assert query == sorted_query
def test_sort_by_column_ascending(self):
def test_column_ascending(self):
query = sort_query(self.session.query(self.Article), 'name')
assert 'ORDER BY article.name ASC' in str(query)
def test_sort_by_column_descending(self):
def test_column_descending(self):
query = sort_query(self.session.query(self.Article), '-name')
assert 'ORDER BY article.name DESC' in str(query)
@@ -22,21 +22,21 @@ class TestSortQuery(TestCase):
sorted_query = sort_query(query, '-unknown')
assert query == sorted_query
def test_sort_by_calculated_value_ascending(self):
def test_calculated_value_ascending(self):
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, 'articles')
assert 'ORDER BY articles ASC' in str(query)
def test_sort_by_calculated_value_descending(self):
def test_calculated_value_descending(self):
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, '-articles')
assert 'ORDER BY articles DESC' in str(query)
def test_sort_by_subqueried_scalar(self):
def test_subqueried_scalar(self):
article_count = (
sa.sql.select(
[sa.func.count(self.Article.id)],
@@ -52,7 +52,7 @@ class TestSortQuery(TestCase):
query = sort_query(query, '-articles')
assert 'ORDER BY articles DESC' in str(query)
def test_sort_by_aliased_joined_entity(self):
def test_aliased_joined_entity(self):
alias = sa.orm.aliased(self.Category, name='categories')
query = self.session.query(
self.Article
@@ -62,17 +62,17 @@ class TestSortQuery(TestCase):
query = sort_query(query, '-categories-name')
assert 'ORDER BY categories.name DESC' in str(query)
def test_sort_by_joined_table_column(self):
def test_joined_table_column(self):
query = self.session.query(self.Article).join(self.Article.category)
sorted_query = sort_query(query, 'category-name')
assert 'category.name ASC' in str(sorted_query)
def test_sort_by_multiple_columns(self):
def test_multiple_columns(self):
query = self.session.query(self.Article)
sorted_query = sort_query(query, 'name', 'id')
assert 'article.name ASC, article.id ASC' in str(sorted_query)
def test_sort_by_column_property(self):
def test_column_property(self):
self.Category.article_count = sa.orm.column_property(
sa.select([sa.func.count(self.Article.id)])
.where(self.Article.category_id == self.Category.id)
@@ -83,7 +83,7 @@ class TestSortQuery(TestCase):
sorted_query = sort_query(query, 'article_count')
assert 'article_count ASC' in str(sorted_query)
def test_sort_by_column_property_descending(self):
def test_column_property_descending(self):
self.Category.article_count = sa.orm.column_property(
sa.select([sa.func.count(self.Article.id)])
.where(self.Article.category_id == self.Category.id)
@@ -94,12 +94,12 @@ class TestSortQuery(TestCase):
sorted_query = sort_query(query, '-article_count')
assert 'article_count DESC' in str(sorted_query)
def test_sort_by_hybrid_property(self):
def test_hybrid_property(self):
query = self.session.query(self.Category)
query = sort_query(query, 'articles_count')
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
def test_sort_by_hybrid_property_descending(self):
def test_hybrid_property_descending(self):
query = self.session.query(self.Category)
query = sort_query(query, '-articles_count')
assert (
@@ -107,7 +107,7 @@ class TestSortQuery(TestCase):
) in str(query)
assert ' DESC' in str(query)
def test_sort_by_related_hybrid_property(self):
def test_relation_hybrid_property(self):
query = (
self.session.query(self.Article)
.join(self.Article.category)
@@ -115,3 +115,19 @@ class TestSortQuery(TestCase):
)
query = sort_query(query, '-category-articles_count')
assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)
def test_aliased_relation_hybrid_property(self):
alias = sa.orm.aliased(
self.Category,
name='category'
)
query = (
self.session.query(self.Article)
.outerjoin(alias, self.Article.category)
.options(
sa.orm.contains_eager(self.Article.category, alias=alias)
)
)
query = sort_query(query, '-category-articles_count')
print query
#assert 'ORDER BY (SELECT count(article.id) AS count_1' in str(query)