diff --git a/CHANGES.rst b/CHANGES.rst index 79adb2f..6704fca 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.13.1 (2013-06-11) +^^^^^^^^^^^^^^^^^^^ + +- Made sort_query function support multicolumn sorting. + + 0.13.0 (2013-06-05) ^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 39e9fde..850ad22 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ class PyTest(Command): setup( name='SQLAlchemy-Utils', - version='0.13.0', + version='0.13.1', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/functions.py b/sqlalchemy_utils/functions.py index 2425b80..ea8c4ca 100644 --- a/sqlalchemy_utils/functions.py +++ b/sqlalchemy_utils/functions.py @@ -6,7 +6,78 @@ from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.sql.expression import desc, asc -def sort_query(query, sort): +class QuerySorter(object): + entities = [] + labels = [] + + def inspect_labels_and_entities(self): + for entity in self.query._entities: + # get all label names for queries such as: + # db.session.query( + # Category, + # db.func.count(Article.id).label('articles') + # ) + if isinstance(entity, _ColumnEntity) and entity._label_name: + self.labels.append(entity._label_name) + else: + self.entities.append(entity.entity_zero.class_) + + for mapper in self.query._join_entities: + if isinstance(mapper, Mapper): + self.entities.append(mapper.class_) + else: + self.entities.append(mapper) + + def assign_order_by(self, sort): + if not sort: + return self.query + + if sort[0] == '-': + func = desc + sort = sort[1:] + else: + func = asc + + component = None + parts = sort.split('-') + if len(parts) > 1: + component = parts[0] + sort = parts[1] + if sort in self.labels: + return self.query.order_by(func(sort)) + + for entity in self.entities: + if isinstance(entity, AliasedInsp): + if component and entity.name != component: + continue + + selectable = entity.selectable + + if sort in selectable.c: + attr = selectable.c[sort] + return self.query.order_by(func(attr)) + else: + table = entity.__table__ + if component and table.name != component: + continue + if sort in table.columns: + try: + attr = getattr(entity, sort) + return self.query.order_by(func(attr)) + except AttributeError: + pass + break + return self.query + + def __call__(self, query, *args): + self.query = query + self.inspect_labels_and_entities() + for sort in args: + self.query = self.assign_order_by(sort) + return self.query + + +def sort_query(query, *args): """ Applies an sql ORDER BY for given query. This function can be easily used with user-defined sorting. @@ -71,64 +142,8 @@ def sort_query(query, sort): :param errors: whether or not to raise exceptions if unknown sort column is passed """ - entities = [] - labels = [] - for entity in query._entities: - # get all label names for queries such as: - # db.session.query( - # Category, - # db.func.count(Article.id).label('articles') - # ) - if isinstance(entity, _ColumnEntity) and entity._label_name: - labels.append(entity._label_name) - else: - entities.append(entity.entity_zero.class_) - for mapper in query._join_entities: - if isinstance(mapper, Mapper): - entities.append(mapper.class_) - else: - entities.append(mapper) - - if not sort: - return query - - if sort[0] == '-': - func = desc - sort = sort[1:] - else: - func = asc - - component = None - parts = sort.split('-') - if len(parts) > 1: - component = parts[0] - sort = parts[1] - if sort in labels: - return query.order_by(func(sort)) - - for entity in entities: - if isinstance(entity, AliasedInsp): - if component and entity.name != component: - continue - - selectable = entity.selectable - - if sort in selectable.c: - attr = selectable.c[sort] - query = query.order_by(func(attr)) - else: - table = entity.__table__ - if component and table.name != component: - continue - if sort in table.columns: - try: - attr = getattr(entity, sort) - query = query.order_by(func(attr)) - except AttributeError: - pass - break - return query + return QuerySorter()(query, *args) def defer_except(query, columns): diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py new file mode 100644 index 0000000..844a7e0 --- /dev/null +++ b/tests/test_sort_query.py @@ -0,0 +1,73 @@ +import sqlalchemy as sa +from sqlalchemy_utils import sort_query +from tests import TestCase + + +class TestSortQuery(TestCase): + def test_without_sort_param_returns_the_query_object_untouched(self): + query = self.session.query(self.Article) + sorted_query = sort_query(query, '') + assert query == sorted_query + + def test_sort_by_column_ascending(self): + query = sort_query(self.session.query(self.Article), 'name') + assert 'ORDER BY article.name ASC' in str(query) + + def test_sort_by_column_descending(self): + query = sort_query(self.session.query(self.Article), '-name') + assert 'ORDER BY article.name DESC' in str(query) + + def test_skips_unknown_columns(self): + query = self.session.query(self.Article) + sorted_query = sort_query(query, '-unknown') + assert query == sorted_query + + def test_sort_by_calculated_value_ascending(self): + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') + ) + query = sort_query(query, 'articles') + assert 'ORDER BY articles ASC' in str(query) + + def test_sort_by_calculated_value_descending(self): + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') + ) + query = sort_query(query, '-articles') + assert 'ORDER BY articles DESC' in str(query) + + def test_sort_by_subqueried_scalar(self): + article_count = ( + sa.sql.select( + [sa.func.count(self.Article.id)], + from_obj=[self.Article.__table__] + ) + .where(self.Article.category_id == self.Category.id) + .correlate(self.Category.__table__) + ) + + query = self.session.query( + self.Category, article_count.label('articles') + ) + query = sort_query(query, '-articles') + assert 'ORDER BY articles DESC' in str(query) + + def test_sort_by_aliased_joined_entity(self): + alias = sa.orm.aliased(self.Category, name='categories') + query = self.session.query( + self.Article + ).join( + alias, self.Article.category + ) + query = sort_query(query, '-categories-name') + assert 'ORDER BY categories.name DESC' in str(query) + + def test_sort_by_joined_table_column(self): + query = self.session.query(self.Article).join(self.Article.category) + sorted_query = sort_query(query, 'category-name') + assert 'category.name ASC' in str(sorted_query) + + def test_sort_by_multiple_columns(self): + query = self.session.query(self.Article) + sorted_query = sort_query(query, 'name', 'id') + assert 'article.name ASC, article.id ASC' in str(sorted_query) diff --git a/tests/test_utility_functions.py b/tests/test_utility_functions.py index 1d2ac52..025aa97 100644 --- a/tests/test_utility_functions.py +++ b/tests/test_utility_functions.py @@ -1,5 +1,4 @@ -import sqlalchemy as sa -from sqlalchemy_utils import escape_like, sort_query, defer_except +from sqlalchemy_utils import escape_like, defer_except from tests import TestCase @@ -13,68 +12,3 @@ class TestDeferExcept(TestCase): query = self.session.query(self.Article) query = defer_except(query, ['id']) assert str(query) == 'SELECT article.id AS article_id \nFROM article' - - -class TestSortQuery(TestCase): - def test_without_sort_param_returns_the_query_object_untouched(self): - query = self.session.query(self.Article) - sorted_query = sort_query(query, '') - assert query == sorted_query - - def test_sort_by_column_ascending(self): - query = sort_query(self.session.query(self.Article), 'name') - assert 'ORDER BY article.name ASC' in str(query) - - def test_sort_by_column_descending(self): - query = sort_query(self.session.query(self.Article), '-name') - assert 'ORDER BY article.name DESC' in str(query) - - def test_skips_unknown_columns(self): - query = self.session.query(self.Article) - sorted_query = sort_query(query, '-unknown') - assert query == sorted_query - - def test_sort_by_calculated_value_ascending(self): - query = self.session.query( - self.Category, sa.func.count(self.Article.id).label('articles') - ) - query = sort_query(query, 'articles') - assert 'ORDER BY articles ASC' in str(query) - - def test_sort_by_calculated_value_descending(self): - query = self.session.query( - self.Category, sa.func.count(self.Article.id).label('articles') - ) - query = sort_query(query, '-articles') - assert 'ORDER BY articles DESC' in str(query) - - def test_sort_by_subqueried_scalar(self): - article_count = ( - sa.sql.select( - [sa.func.count(self.Article.id)], - from_obj=[self.Article.__table__] - ) - .where(self.Article.category_id == self.Category.id) - .correlate(self.Category.__table__) - ) - - query = self.session.query( - self.Category, article_count.label('articles') - ) - query = sort_query(query, '-articles') - assert 'ORDER BY articles DESC' in str(query) - - def test_sort_by_aliased_joined_entity(self): - alias = sa.orm.aliased(self.Category, name='categories') - query = self.session.query( - self.Article - ).join( - alias, self.Article.category - ) - query = sort_query(query, '-categories-name') - assert 'ORDER BY categories.name DESC' in str(query) - - def test_sort_by_joined_table_column(self): - query = self.session.query(self.Article).join(self.Article.category) - sorted_query = sort_query(query, 'category-name') - assert 'category.name ASC' in str(sorted_query)