Fixed an issue with sort_query + joins, made defer_except not accept relationships

This commit is contained in:
Konsta Vesterinen
2013-09-26 11:19:47 +03:00
parent cd92d2b1ff
commit 5fb0c97f0d
4 changed files with 51 additions and 30 deletions

View File

@@ -2,6 +2,8 @@ import six
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.orm import defer from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper
def property_names(properties): def property_names(properties):
@@ -10,30 +12,46 @@ def property_names(properties):
if isinstance(property_, six.string_types): if isinstance(property_, six.string_types):
names.append(property_) names.append(property_)
else: else:
names.append(property_.key) names.append(
'%s.%s' % (
property_.class_.__name__,
property_.key
)
)
return names return names
def defer_except(query, properties): def query_entities(query):
entities = []
for entity in query._entities:
if not isinstance(entity, _ColumnEntity):
entities.append(entity.entity_zero.class_)
for entity in query._join_entities:
if isinstance(entity, Mapper):
entities.append(entity.class_)
else:
entities.append(entity)
return entities
def defer_except(query, columns):
""" """
Deferred loads all properties in given query, except the ones given. Deferred loads all columns in given query, except the ones given.
This function is very useful when working with models with myriad of This function is very useful when working with models with myriad of
properties and you want to deferred load many properties. columns and you want to deferred load many columns.
>>> from sqlalchemy_utils import defer_except >>> from sqlalchemy_utils import defer_except
>>> query = session.query(Article) >>> query = session.query(Article)
>>> query = defer_except(Article, [Article.id, Article.name]) >>> query = defer_except(Article, [Article.id, Article.name])
:param query: SQLAlchemy Query object to apply the deferred loading to :param columns: columns not to deferred load
:param properties: properties not to deferred load
""" """
allowed_names = property_names(properties)
model = query._entities[0].entity_zero.class_ model = query._entities[0].entity_zero.class_
for property_ in inspect(model).attrs: for property_ in inspect(model).attrs:
if isinstance(property_, ColumnProperty): if isinstance(property_, ColumnProperty):
if property_.key not in allowed_names: column = property_.columns[0]
if column.name not in columns:
query = query.options(defer(property_.key)) query = query.options(defer(property_.key))
return query return query

View File

@@ -20,17 +20,6 @@ def sort_expression(expr, attr_name):
return getattr(expr, attr_name) return getattr(expr, attr_name)
def matches_entity(alias, entity):
if not alias:
return True
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
return name == alias
class QuerySorterException(Exception): class QuerySorterException(Exception):
pass pass
@@ -60,6 +49,19 @@ class QuerySorter(object):
else: else:
self.entities.append(mapper) self.entities.append(mapper)
def get_entity_by_alias(self, alias):
if not alias:
return self.entities[0]
for entity in self.entities:
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
if name == alias:
return entity
def assign_order_by(self, sort): def assign_order_by(self, sort):
if not sort: if not sort:
return self.query return self.query
@@ -69,10 +71,8 @@ class QuerySorter(object):
if sort['attr'] in self.labels: if sort['attr'] in self.labels:
expr = sort['attr'] expr = sort['attr']
else: else:
for entity in self.entities: entity = self.get_entity_by_alias(sort['entity'])
if not matches_entity(sort['entity'], entity): if entity:
continue
expr = self.order_by_attr(entity, sort['attr']) expr = self.order_by_attr(entity, sort['attr'])
if expr is not None: if expr is not None:

View File

@@ -7,8 +7,3 @@ class TestDeferExcept(TestCase):
query = self.session.query(self.Article) query = self.session.query(self.Article)
query = defer_except(query, ['id']) query = defer_except(query, ['id'])
assert str(query) == 'SELECT article.id AS article_id \nFROM article' assert str(query) == 'SELECT article.id AS article_id \nFROM article'
def test_supports_properties_as_class_attributes(self):
query = self.session.query(self.Article)
query = defer_except(query, [self.Article.id])
assert str(query) == 'SELECT article.id AS article_id \nFROM article'

View File

@@ -29,6 +29,14 @@ class TestSortQuery(TestCase):
with raises(QuerySorterException): with raises(QuerySorterException):
sort_query(query, '-unknown', silent=False) sort_query(query, '-unknown', silent=False)
def test_join(self):
query = (
self.session.query(self.Article)
.join(self.Article.category)
)
query = sort_query(query, 'name', silent=False)
assert 'ORDER BY article.name ASC' in str(query)
def test_calculated_value_ascending(self): def test_calculated_value_ascending(self):
query = self.session.query( query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles') self.Category, sa.func.count(self.Article.id).label('articles')