Fixed an issue with sort_query + joins, made defer_except not accept relationships
This commit is contained in:
@@ -2,6 +2,8 @@ import six
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import defer
|
||||
from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
|
||||
|
||||
def property_names(properties):
|
||||
@@ -10,30 +12,46 @@ def property_names(properties):
|
||||
if isinstance(property_, six.string_types):
|
||||
names.append(property_)
|
||||
else:
|
||||
names.append(property_.key)
|
||||
names.append(
|
||||
'%s.%s' % (
|
||||
property_.class_.__name__,
|
||||
property_.key
|
||||
)
|
||||
)
|
||||
return names
|
||||
|
||||
|
||||
def defer_except(query, properties):
|
||||
def query_entities(query):
|
||||
entities = []
|
||||
for entity in query._entities:
|
||||
if not isinstance(entity, _ColumnEntity):
|
||||
entities.append(entity.entity_zero.class_)
|
||||
|
||||
for entity in query._join_entities:
|
||||
if isinstance(entity, Mapper):
|
||||
entities.append(entity.class_)
|
||||
else:
|
||||
entities.append(entity)
|
||||
return entities
|
||||
|
||||
|
||||
def defer_except(query, columns):
|
||||
"""
|
||||
Deferred loads all properties in given query, except the ones given.
|
||||
Deferred loads all columns in given query, except the ones given.
|
||||
|
||||
This function is very useful when working with models with myriad of
|
||||
properties and you want to deferred load many properties.
|
||||
columns and you want to deferred load many columns.
|
||||
|
||||
>>> from sqlalchemy_utils import defer_except
|
||||
>>> query = session.query(Article)
|
||||
>>> query = defer_except(Article, [Article.id, Article.name])
|
||||
|
||||
:param query: SQLAlchemy Query object to apply the deferred loading to
|
||||
:param properties: properties not to deferred load
|
||||
:param columns: columns not to deferred load
|
||||
"""
|
||||
|
||||
allowed_names = property_names(properties)
|
||||
|
||||
model = query._entities[0].entity_zero.class_
|
||||
for property_ in inspect(model).attrs:
|
||||
if isinstance(property_, ColumnProperty):
|
||||
if property_.key not in allowed_names:
|
||||
column = property_.columns[0]
|
||||
if column.name not in columns:
|
||||
query = query.options(defer(property_.key))
|
||||
return query
|
||||
|
@@ -20,17 +20,6 @@ def sort_expression(expr, attr_name):
|
||||
return getattr(expr, attr_name)
|
||||
|
||||
|
||||
def matches_entity(alias, entity):
|
||||
if not alias:
|
||||
return True
|
||||
if isinstance(entity, AliasedInsp):
|
||||
name = entity.name
|
||||
else:
|
||||
name = entity.__table__.name
|
||||
|
||||
return name == alias
|
||||
|
||||
|
||||
class QuerySorterException(Exception):
|
||||
pass
|
||||
|
||||
@@ -60,6 +49,19 @@ class QuerySorter(object):
|
||||
else:
|
||||
self.entities.append(mapper)
|
||||
|
||||
def get_entity_by_alias(self, alias):
|
||||
if not alias:
|
||||
return self.entities[0]
|
||||
|
||||
for entity in self.entities:
|
||||
if isinstance(entity, AliasedInsp):
|
||||
name = entity.name
|
||||
else:
|
||||
name = entity.__table__.name
|
||||
|
||||
if name == alias:
|
||||
return entity
|
||||
|
||||
def assign_order_by(self, sort):
|
||||
if not sort:
|
||||
return self.query
|
||||
@@ -69,10 +71,8 @@ class QuerySorter(object):
|
||||
if sort['attr'] in self.labels:
|
||||
expr = sort['attr']
|
||||
else:
|
||||
for entity in self.entities:
|
||||
if not matches_entity(sort['entity'], entity):
|
||||
continue
|
||||
|
||||
entity = self.get_entity_by_alias(sort['entity'])
|
||||
if entity:
|
||||
expr = self.order_by_attr(entity, sort['attr'])
|
||||
|
||||
if expr is not None:
|
||||
|
@@ -7,8 +7,3 @@ class TestDeferExcept(TestCase):
|
||||
query = self.session.query(self.Article)
|
||||
query = defer_except(query, ['id'])
|
||||
assert str(query) == 'SELECT article.id AS article_id \nFROM article'
|
||||
|
||||
def test_supports_properties_as_class_attributes(self):
|
||||
query = self.session.query(self.Article)
|
||||
query = defer_except(query, [self.Article.id])
|
||||
assert str(query) == 'SELECT article.id AS article_id \nFROM article'
|
||||
|
@@ -29,6 +29,14 @@ class TestSortQuery(TestCase):
|
||||
with raises(QuerySorterException):
|
||||
sort_query(query, '-unknown', silent=False)
|
||||
|
||||
def test_join(self):
|
||||
query = (
|
||||
self.session.query(self.Article)
|
||||
.join(self.Article.category)
|
||||
)
|
||||
query = sort_query(query, 'name', silent=False)
|
||||
assert 'ORDER BY article.name ASC' in str(query)
|
||||
|
||||
def test_calculated_value_ascending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
|
Reference in New Issue
Block a user