sort_query now supports column property selects with labels

This commit is contained in:
Konsta Vesterinen
2013-07-02 15:11:45 +03:00
parent 5e7a79dfba
commit 2a226d1ed8
4 changed files with 60 additions and 28 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.14.1 (2013-07-02)
^^^^^^^^^^^^^^^^^^^
- Made sort_query support column_property selects with labels
0.14.0 (2013-07-02)
^^^^^^^^^^^^^^^^^^^

View File

@@ -24,7 +24,7 @@ class PyTest(Command):
setup(
name='SQLAlchemy-Utils',
version='0.14.0',
version='0.14.1',
url='https://github.com/kvesteri/sqlalchemy-utils',
license='BSD',
author='Konsta Vesterinen',

View File

@@ -3,13 +3,14 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy.sql.expression import desc, asc
from sqlalchemy.sql.expression import desc, asc, Label
class QuerySorter(object):
def __init__(self):
def __init__(self, separator='-'):
self.entities = []
self.labels = []
self.separator = separator
def inspect_labels_and_entities(self):
for entity in self.query._entities:
@@ -33,43 +34,57 @@ class QuerySorter(object):
if not sort:
return self.query
if sort[0] == '-':
func = desc
sort = sort[1:]
else:
func = asc
component = None
parts = sort.split('-')
if len(parts) > 1:
component = parts[0]
sort = parts[1]
if sort in self.labels:
return self.query.order_by(func(sort))
sort = self.parse_sort_arg(sort)
print sort
if sort['attr'] in self.labels:
return self.query.order_by(sort['func'](sort['attr']))
for entity in self.entities:
if isinstance(entity, AliasedInsp):
if component and entity.name != component:
if sort['entity'] and entity.name != sort['entity']:
continue
selectable = entity.selectable
if sort in selectable.c:
attr = selectable.c[sort]
return self.query.order_by(func(attr))
if sort['attr'] in selectable.c:
attr = selectable.c[sort['attr']]
return self.query.order_by(sort['func'](attr))
else:
table = entity.__table__
if component and table.name != component:
if sort['entity'] and table.name != sort['entity']:
continue
if sort in table.columns:
try:
attr = getattr(entity, sort)
return self.query.order_by(func(attr))
except AttributeError:
pass
break
return self.assign_entity_attr_order_by(entity, sort)
return self.query
def assign_entity_attr_order_by(self, entity, sort):
if sort['attr'] in entity.__mapper__.class_manager.keys():
try:
attr = getattr(entity, sort['attr'])
except AttributeError:
pass
else:
property_ = attr.property
if isinstance(property_, ColumnProperty):
if isinstance(attr.property.columns[0], Label):
attr = attr.property.columns[0].name
return self.query.order_by(sort['func'](attr))
return self.query
def parse_sort_arg(self, arg):
if arg[0] == self.separator:
func = desc
arg = arg[1:]
else:
func = asc
parts = arg.split(self.separator)
return {
'entity': parts[0] if len(parts) > 1 else None,
'attr': parts[1] if len(parts) > 1 else arg,
'func': func
}
def __call__(self, query, *args):
self.query = query
self.inspect_labels_and_entities()

View File

@@ -71,3 +71,14 @@ class TestSortQuery(TestCase):
query = self.session.query(self.Article)
sorted_query = sort_query(query, 'name', 'id')
assert 'article.name ASC, article.id ASC' in str(sorted_query)
def test_sort_by_column_property(self):
self.Category.article_count = sa.orm.column_property(
sa.select([sa.func.count(self.Article.id)])
.where(self.Article.category_id == self.Category.id)
.label('article_count')
)
query = self.session.query(self.Category)
sorted_query = sort_query(query, 'article_count')
assert 'article_count ASC' in str(sorted_query)