Made sort_query support multiple columns
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.13.1 (2013-06-11)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Made sort_query function support multicolumn sorting.
|
||||
|
||||
|
||||
0.13.0 (2013-06-05)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
2
setup.py
2
setup.py
@@ -24,7 +24,7 @@ class PyTest(Command):
|
||||
|
||||
setup(
|
||||
name='SQLAlchemy-Utils',
|
||||
version='0.13.0',
|
||||
version='0.13.1',
|
||||
url='https://github.com/kvesteri/sqlalchemy-utils',
|
||||
license='BSD',
|
||||
author='Konsta Vesterinen',
|
||||
|
@@ -6,7 +6,78 @@ from sqlalchemy.orm.util import AliasedInsp
|
||||
from sqlalchemy.sql.expression import desc, asc
|
||||
|
||||
|
||||
def sort_query(query, sort):
|
||||
class QuerySorter(object):
|
||||
entities = []
|
||||
labels = []
|
||||
|
||||
def inspect_labels_and_entities(self):
|
||||
for entity in self.query._entities:
|
||||
# get all label names for queries such as:
|
||||
# db.session.query(
|
||||
# Category,
|
||||
# db.func.count(Article.id).label('articles')
|
||||
# )
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
self.labels.append(entity._label_name)
|
||||
else:
|
||||
self.entities.append(entity.entity_zero.class_)
|
||||
|
||||
for mapper in self.query._join_entities:
|
||||
if isinstance(mapper, Mapper):
|
||||
self.entities.append(mapper.class_)
|
||||
else:
|
||||
self.entities.append(mapper)
|
||||
|
||||
def assign_order_by(self, sort):
|
||||
if not sort:
|
||||
return self.query
|
||||
|
||||
if sort[0] == '-':
|
||||
func = desc
|
||||
sort = sort[1:]
|
||||
else:
|
||||
func = asc
|
||||
|
||||
component = None
|
||||
parts = sort.split('-')
|
||||
if len(parts) > 1:
|
||||
component = parts[0]
|
||||
sort = parts[1]
|
||||
if sort in self.labels:
|
||||
return self.query.order_by(func(sort))
|
||||
|
||||
for entity in self.entities:
|
||||
if isinstance(entity, AliasedInsp):
|
||||
if component and entity.name != component:
|
||||
continue
|
||||
|
||||
selectable = entity.selectable
|
||||
|
||||
if sort in selectable.c:
|
||||
attr = selectable.c[sort]
|
||||
return self.query.order_by(func(attr))
|
||||
else:
|
||||
table = entity.__table__
|
||||
if component and table.name != component:
|
||||
continue
|
||||
if sort in table.columns:
|
||||
try:
|
||||
attr = getattr(entity, sort)
|
||||
return self.query.order_by(func(attr))
|
||||
except AttributeError:
|
||||
pass
|
||||
break
|
||||
return self.query
|
||||
|
||||
def __call__(self, query, *args):
|
||||
self.query = query
|
||||
self.inspect_labels_and_entities()
|
||||
for sort in args:
|
||||
self.query = self.assign_order_by(sort)
|
||||
return self.query
|
||||
|
||||
|
||||
def sort_query(query, *args):
|
||||
"""
|
||||
Applies an sql ORDER BY for given query. This function can be easily used
|
||||
with user-defined sorting.
|
||||
@@ -71,64 +142,8 @@ def sort_query(query, sort):
|
||||
:param errors: whether or not to raise exceptions if unknown sort column
|
||||
is passed
|
||||
"""
|
||||
entities = []
|
||||
labels = []
|
||||
for entity in query._entities:
|
||||
# get all label names for queries such as:
|
||||
# db.session.query(
|
||||
# Category,
|
||||
# db.func.count(Article.id).label('articles')
|
||||
# )
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
labels.append(entity._label_name)
|
||||
else:
|
||||
entities.append(entity.entity_zero.class_)
|
||||
|
||||
for mapper in query._join_entities:
|
||||
if isinstance(mapper, Mapper):
|
||||
entities.append(mapper.class_)
|
||||
else:
|
||||
entities.append(mapper)
|
||||
|
||||
if not sort:
|
||||
return query
|
||||
|
||||
if sort[0] == '-':
|
||||
func = desc
|
||||
sort = sort[1:]
|
||||
else:
|
||||
func = asc
|
||||
|
||||
component = None
|
||||
parts = sort.split('-')
|
||||
if len(parts) > 1:
|
||||
component = parts[0]
|
||||
sort = parts[1]
|
||||
if sort in labels:
|
||||
return query.order_by(func(sort))
|
||||
|
||||
for entity in entities:
|
||||
if isinstance(entity, AliasedInsp):
|
||||
if component and entity.name != component:
|
||||
continue
|
||||
|
||||
selectable = entity.selectable
|
||||
|
||||
if sort in selectable.c:
|
||||
attr = selectable.c[sort]
|
||||
query = query.order_by(func(attr))
|
||||
else:
|
||||
table = entity.__table__
|
||||
if component and table.name != component:
|
||||
continue
|
||||
if sort in table.columns:
|
||||
try:
|
||||
attr = getattr(entity, sort)
|
||||
query = query.order_by(func(attr))
|
||||
except AttributeError:
|
||||
pass
|
||||
break
|
||||
return query
|
||||
return QuerySorter()(query, *args)
|
||||
|
||||
|
||||
def defer_except(query, columns):
|
||||
|
73
tests/test_sort_query.py
Normal file
73
tests/test_sort_query.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import sort_query
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestSortQuery(TestCase):
|
||||
def test_without_sort_param_returns_the_query_object_untouched(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_column_ascending(self):
|
||||
query = sort_query(self.session.query(self.Article), 'name')
|
||||
assert 'ORDER BY article.name ASC' in str(query)
|
||||
|
||||
def test_sort_by_column_descending(self):
|
||||
query = sort_query(self.session.query(self.Article), '-name')
|
||||
assert 'ORDER BY article.name DESC' in str(query)
|
||||
|
||||
def test_skips_unknown_columns(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '-unknown')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_calculated_value_ascending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, 'articles')
|
||||
assert 'ORDER BY articles ASC' in str(query)
|
||||
|
||||
def test_sort_by_calculated_value_descending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, '-articles')
|
||||
assert 'ORDER BY articles DESC' in str(query)
|
||||
|
||||
def test_sort_by_subqueried_scalar(self):
|
||||
article_count = (
|
||||
sa.sql.select(
|
||||
[sa.func.count(self.Article.id)],
|
||||
from_obj=[self.Article.__table__]
|
||||
)
|
||||
.where(self.Article.category_id == self.Category.id)
|
||||
.correlate(self.Category.__table__)
|
||||
)
|
||||
|
||||
query = self.session.query(
|
||||
self.Category, article_count.label('articles')
|
||||
)
|
||||
query = sort_query(query, '-articles')
|
||||
assert 'ORDER BY articles DESC' in str(query)
|
||||
|
||||
def test_sort_by_aliased_joined_entity(self):
|
||||
alias = sa.orm.aliased(self.Category, name='categories')
|
||||
query = self.session.query(
|
||||
self.Article
|
||||
).join(
|
||||
alias, self.Article.category
|
||||
)
|
||||
query = sort_query(query, '-categories-name')
|
||||
assert 'ORDER BY categories.name DESC' in str(query)
|
||||
|
||||
def test_sort_by_joined_table_column(self):
|
||||
query = self.session.query(self.Article).join(self.Article.category)
|
||||
sorted_query = sort_query(query, 'category-name')
|
||||
assert 'category.name ASC' in str(sorted_query)
|
||||
|
||||
def test_sort_by_multiple_columns(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, 'name', 'id')
|
||||
assert 'article.name ASC, article.id ASC' in str(sorted_query)
|
@@ -1,5 +1,4 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import escape_like, sort_query, defer_except
|
||||
from sqlalchemy_utils import escape_like, defer_except
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -13,68 +12,3 @@ class TestDeferExcept(TestCase):
|
||||
query = self.session.query(self.Article)
|
||||
query = defer_except(query, ['id'])
|
||||
assert str(query) == 'SELECT article.id AS article_id \nFROM article'
|
||||
|
||||
|
||||
class TestSortQuery(TestCase):
|
||||
def test_without_sort_param_returns_the_query_object_untouched(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_column_ascending(self):
|
||||
query = sort_query(self.session.query(self.Article), 'name')
|
||||
assert 'ORDER BY article.name ASC' in str(query)
|
||||
|
||||
def test_sort_by_column_descending(self):
|
||||
query = sort_query(self.session.query(self.Article), '-name')
|
||||
assert 'ORDER BY article.name DESC' in str(query)
|
||||
|
||||
def test_skips_unknown_columns(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '-unknown')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_calculated_value_ascending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, 'articles')
|
||||
assert 'ORDER BY articles ASC' in str(query)
|
||||
|
||||
def test_sort_by_calculated_value_descending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, '-articles')
|
||||
assert 'ORDER BY articles DESC' in str(query)
|
||||
|
||||
def test_sort_by_subqueried_scalar(self):
|
||||
article_count = (
|
||||
sa.sql.select(
|
||||
[sa.func.count(self.Article.id)],
|
||||
from_obj=[self.Article.__table__]
|
||||
)
|
||||
.where(self.Article.category_id == self.Category.id)
|
||||
.correlate(self.Category.__table__)
|
||||
)
|
||||
|
||||
query = self.session.query(
|
||||
self.Category, article_count.label('articles')
|
||||
)
|
||||
query = sort_query(query, '-articles')
|
||||
assert 'ORDER BY articles DESC' in str(query)
|
||||
|
||||
def test_sort_by_aliased_joined_entity(self):
|
||||
alias = sa.orm.aliased(self.Category, name='categories')
|
||||
query = self.session.query(
|
||||
self.Article
|
||||
).join(
|
||||
alias, self.Article.category
|
||||
)
|
||||
query = sort_query(query, '-categories-name')
|
||||
assert 'ORDER BY categories.name DESC' in str(query)
|
||||
|
||||
def test_sort_by_joined_table_column(self):
|
||||
query = self.session.query(self.Article).join(self.Article.category)
|
||||
sorted_query = sort_query(query, 'category-name')
|
||||
assert 'category.name ASC' in str(sorted_query)
|
||||
|
Reference in New Issue
Block a user