From 5fb0c97f0da732639d0e84141416dde2ae1ce450 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 26 Sep 2013 11:19:47 +0300 Subject: [PATCH] Fixed an issue with sort_query + joins, made defer_except not accept relationships --- sqlalchemy_utils/functions/defer_except.py | 38 ++++++++++++++++------ sqlalchemy_utils/functions/sort_query.py | 30 ++++++++--------- tests/test_defer_except.py | 5 --- tests/test_sort_query.py | 8 +++++ 4 files changed, 51 insertions(+), 30 deletions(-) diff --git a/sqlalchemy_utils/functions/defer_except.py b/sqlalchemy_utils/functions/defer_except.py index 3d89dd2..4de18e4 100644 --- a/sqlalchemy_utils/functions/defer_except.py +++ b/sqlalchemy_utils/functions/defer_except.py @@ -2,6 +2,8 @@ import six from sqlalchemy import inspect from sqlalchemy.orm import defer from sqlalchemy.orm.properties import ColumnProperty +from sqlalchemy.orm.query import _ColumnEntity +from sqlalchemy.orm.mapper import Mapper def property_names(properties): @@ -10,30 +12,46 @@ def property_names(properties): if isinstance(property_, six.string_types): names.append(property_) else: - names.append(property_.key) + names.append( + '%s.%s' % ( + property_.class_.__name__, + property_.key + ) + ) return names -def defer_except(query, properties): +def query_entities(query): + entities = [] + for entity in query._entities: + if not isinstance(entity, _ColumnEntity): + entities.append(entity.entity_zero.class_) + + for entity in query._join_entities: + if isinstance(entity, Mapper): + entities.append(entity.class_) + else: + entities.append(entity) + return entities + + +def defer_except(query, columns): """ - Deferred loads all properties in given query, except the ones given. + Deferred loads all columns in given query, except the ones given. This function is very useful when working with models with myriad of - properties and you want to deferred load many properties. + columns and you want to deferred load many columns. >>> from sqlalchemy_utils import defer_except >>> query = session.query(Article) >>> query = defer_except(Article, [Article.id, Article.name]) - :param query: SQLAlchemy Query object to apply the deferred loading to - :param properties: properties not to deferred load + :param columns: columns not to deferred load """ - - allowed_names = property_names(properties) - model = query._entities[0].entity_zero.class_ for property_ in inspect(model).attrs: if isinstance(property_, ColumnProperty): - if property_.key not in allowed_names: + column = property_.columns[0] + if column.name not in columns: query = query.options(defer(property_.key)) return query diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index 6685b75..2eaa2ba 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -20,17 +20,6 @@ def sort_expression(expr, attr_name): return getattr(expr, attr_name) -def matches_entity(alias, entity): - if not alias: - return True - if isinstance(entity, AliasedInsp): - name = entity.name - else: - name = entity.__table__.name - - return name == alias - - class QuerySorterException(Exception): pass @@ -60,6 +49,19 @@ class QuerySorter(object): else: self.entities.append(mapper) + def get_entity_by_alias(self, alias): + if not alias: + return self.entities[0] + + for entity in self.entities: + if isinstance(entity, AliasedInsp): + name = entity.name + else: + name = entity.__table__.name + + if name == alias: + return entity + def assign_order_by(self, sort): if not sort: return self.query @@ -69,10 +71,8 @@ class QuerySorter(object): if sort['attr'] in self.labels: expr = sort['attr'] else: - for entity in self.entities: - if not matches_entity(sort['entity'], entity): - continue - + entity = self.get_entity_by_alias(sort['entity']) + if entity: expr = self.order_by_attr(entity, sort['attr']) if expr is not None: diff --git a/tests/test_defer_except.py b/tests/test_defer_except.py index f04543b..e63e4c0 100644 --- a/tests/test_defer_except.py +++ b/tests/test_defer_except.py @@ -7,8 +7,3 @@ class TestDeferExcept(TestCase): query = self.session.query(self.Article) query = defer_except(query, ['id']) assert str(query) == 'SELECT article.id AS article_id \nFROM article' - - def test_supports_properties_as_class_attributes(self): - query = self.session.query(self.Article) - query = defer_except(query, [self.Article.id]) - assert str(query) == 'SELECT article.id AS article_id \nFROM article' diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index ec92d04..154e11e 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -29,6 +29,14 @@ class TestSortQuery(TestCase): with raises(QuerySorterException): sort_query(query, '-unknown', silent=False) + def test_join(self): + query = ( + self.session.query(self.Article) + .join(self.Article.category) + ) + query = sort_query(query, 'name', silent=False) + assert 'ORDER BY article.name ASC' in str(query) + def test_calculated_value_ascending(self): query = self.session.query( self.Category, sa.func.count(self.Article.id).label('articles')