diff --git a/CHANGES.rst b/CHANGES.rst index a00ec84..f955c82 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.14.1 (2013-07-02) +^^^^^^^^^^^^^^^^^^^ + +- Made sort_query support column_property selects with labels + + 0.14.0 (2013-07-02) ^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 77120a5..773fa8e 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ class PyTest(Command): setup( name='SQLAlchemy-Utils', - version='0.14.0', + version='0.14.1', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/functions.py b/sqlalchemy_utils/functions.py index 6fea131..0340bc0 100644 --- a/sqlalchemy_utils/functions.py +++ b/sqlalchemy_utils/functions.py @@ -3,13 +3,14 @@ from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.util import AliasedInsp -from sqlalchemy.sql.expression import desc, asc +from sqlalchemy.sql.expression import desc, asc, Label class QuerySorter(object): - def __init__(self): + def __init__(self, separator='-'): self.entities = [] self.labels = [] + self.separator = separator def inspect_labels_and_entities(self): for entity in self.query._entities: @@ -33,43 +34,57 @@ class QuerySorter(object): if not sort: return self.query - if sort[0] == '-': - func = desc - sort = sort[1:] - else: - func = asc - - component = None - parts = sort.split('-') - if len(parts) > 1: - component = parts[0] - sort = parts[1] - if sort in self.labels: - return self.query.order_by(func(sort)) + sort = self.parse_sort_arg(sort) + print sort + if sort['attr'] in self.labels: + return self.query.order_by(sort['func'](sort['attr'])) for entity in self.entities: if isinstance(entity, AliasedInsp): - if component and entity.name != component: + if sort['entity'] and entity.name != sort['entity']: continue selectable = entity.selectable - if sort in selectable.c: - attr = selectable.c[sort] - return self.query.order_by(func(attr)) + if sort['attr'] in selectable.c: + attr = selectable.c[sort['attr']] + return self.query.order_by(sort['func'](attr)) else: table = entity.__table__ - if component and table.name != component: + if sort['entity'] and table.name != sort['entity']: continue - if sort in table.columns: - try: - attr = getattr(entity, sort) - return self.query.order_by(func(attr)) - except AttributeError: - pass - break + return self.assign_entity_attr_order_by(entity, sort) return self.query + def assign_entity_attr_order_by(self, entity, sort): + if sort['attr'] in entity.__mapper__.class_manager.keys(): + try: + attr = getattr(entity, sort['attr']) + except AttributeError: + pass + else: + property_ = attr.property + if isinstance(property_, ColumnProperty): + if isinstance(attr.property.columns[0], Label): + attr = attr.property.columns[0].name + + return self.query.order_by(sort['func'](attr)) + return self.query + + def parse_sort_arg(self, arg): + if arg[0] == self.separator: + func = desc + arg = arg[1:] + else: + func = asc + + parts = arg.split(self.separator) + return { + 'entity': parts[0] if len(parts) > 1 else None, + 'attr': parts[1] if len(parts) > 1 else arg, + 'func': func + } + def __call__(self, query, *args): self.query = query self.inspect_labels_and_entities() diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index 844a7e0..ce5564e 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -71,3 +71,14 @@ class TestSortQuery(TestCase): query = self.session.query(self.Article) sorted_query = sort_query(query, 'name', 'id') assert 'article.name ASC, article.id ASC' in str(sorted_query) + + def test_sort_by_column_property(self): + self.Category.article_count = sa.orm.column_property( + sa.select([sa.func.count(self.Article.id)]) + .where(self.Article.category_id == self.Category.id) + .label('article_count') + ) + + query = self.session.query(self.Category) + sorted_query = sort_query(query, 'article_count') + assert 'article_count ASC' in str(sorted_query)