Fixed an issue with sort_query + joins, made defer_except not accept relationships
This commit is contained in:
@@ -2,6 +2,8 @@ import six
|
|||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
from sqlalchemy.orm import defer
|
from sqlalchemy.orm import defer
|
||||||
from sqlalchemy.orm.properties import ColumnProperty
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
|
|
||||||
|
|
||||||
def property_names(properties):
|
def property_names(properties):
|
||||||
@@ -10,30 +12,46 @@ def property_names(properties):
|
|||||||
if isinstance(property_, six.string_types):
|
if isinstance(property_, six.string_types):
|
||||||
names.append(property_)
|
names.append(property_)
|
||||||
else:
|
else:
|
||||||
names.append(property_.key)
|
names.append(
|
||||||
|
'%s.%s' % (
|
||||||
|
property_.class_.__name__,
|
||||||
|
property_.key
|
||||||
|
)
|
||||||
|
)
|
||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
def defer_except(query, properties):
|
def query_entities(query):
|
||||||
|
entities = []
|
||||||
|
for entity in query._entities:
|
||||||
|
if not isinstance(entity, _ColumnEntity):
|
||||||
|
entities.append(entity.entity_zero.class_)
|
||||||
|
|
||||||
|
for entity in query._join_entities:
|
||||||
|
if isinstance(entity, Mapper):
|
||||||
|
entities.append(entity.class_)
|
||||||
|
else:
|
||||||
|
entities.append(entity)
|
||||||
|
return entities
|
||||||
|
|
||||||
|
|
||||||
|
def defer_except(query, columns):
|
||||||
"""
|
"""
|
||||||
Deferred loads all properties in given query, except the ones given.
|
Deferred loads all columns in given query, except the ones given.
|
||||||
|
|
||||||
This function is very useful when working with models with myriad of
|
This function is very useful when working with models with myriad of
|
||||||
properties and you want to deferred load many properties.
|
columns and you want to deferred load many columns.
|
||||||
|
|
||||||
>>> from sqlalchemy_utils import defer_except
|
>>> from sqlalchemy_utils import defer_except
|
||||||
>>> query = session.query(Article)
|
>>> query = session.query(Article)
|
||||||
>>> query = defer_except(Article, [Article.id, Article.name])
|
>>> query = defer_except(Article, [Article.id, Article.name])
|
||||||
|
|
||||||
:param query: SQLAlchemy Query object to apply the deferred loading to
|
:param columns: columns not to deferred load
|
||||||
:param properties: properties not to deferred load
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
allowed_names = property_names(properties)
|
|
||||||
|
|
||||||
model = query._entities[0].entity_zero.class_
|
model = query._entities[0].entity_zero.class_
|
||||||
for property_ in inspect(model).attrs:
|
for property_ in inspect(model).attrs:
|
||||||
if isinstance(property_, ColumnProperty):
|
if isinstance(property_, ColumnProperty):
|
||||||
if property_.key not in allowed_names:
|
column = property_.columns[0]
|
||||||
|
if column.name not in columns:
|
||||||
query = query.options(defer(property_.key))
|
query = query.options(defer(property_.key))
|
||||||
return query
|
return query
|
||||||
|
@@ -20,17 +20,6 @@ def sort_expression(expr, attr_name):
|
|||||||
return getattr(expr, attr_name)
|
return getattr(expr, attr_name)
|
||||||
|
|
||||||
|
|
||||||
def matches_entity(alias, entity):
|
|
||||||
if not alias:
|
|
||||||
return True
|
|
||||||
if isinstance(entity, AliasedInsp):
|
|
||||||
name = entity.name
|
|
||||||
else:
|
|
||||||
name = entity.__table__.name
|
|
||||||
|
|
||||||
return name == alias
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySorterException(Exception):
|
class QuerySorterException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -60,6 +49,19 @@ class QuerySorter(object):
|
|||||||
else:
|
else:
|
||||||
self.entities.append(mapper)
|
self.entities.append(mapper)
|
||||||
|
|
||||||
|
def get_entity_by_alias(self, alias):
|
||||||
|
if not alias:
|
||||||
|
return self.entities[0]
|
||||||
|
|
||||||
|
for entity in self.entities:
|
||||||
|
if isinstance(entity, AliasedInsp):
|
||||||
|
name = entity.name
|
||||||
|
else:
|
||||||
|
name = entity.__table__.name
|
||||||
|
|
||||||
|
if name == alias:
|
||||||
|
return entity
|
||||||
|
|
||||||
def assign_order_by(self, sort):
|
def assign_order_by(self, sort):
|
||||||
if not sort:
|
if not sort:
|
||||||
return self.query
|
return self.query
|
||||||
@@ -69,10 +71,8 @@ class QuerySorter(object):
|
|||||||
if sort['attr'] in self.labels:
|
if sort['attr'] in self.labels:
|
||||||
expr = sort['attr']
|
expr = sort['attr']
|
||||||
else:
|
else:
|
||||||
for entity in self.entities:
|
entity = self.get_entity_by_alias(sort['entity'])
|
||||||
if not matches_entity(sort['entity'], entity):
|
if entity:
|
||||||
continue
|
|
||||||
|
|
||||||
expr = self.order_by_attr(entity, sort['attr'])
|
expr = self.order_by_attr(entity, sort['attr'])
|
||||||
|
|
||||||
if expr is not None:
|
if expr is not None:
|
||||||
|
@@ -7,8 +7,3 @@ class TestDeferExcept(TestCase):
|
|||||||
query = self.session.query(self.Article)
|
query = self.session.query(self.Article)
|
||||||
query = defer_except(query, ['id'])
|
query = defer_except(query, ['id'])
|
||||||
assert str(query) == 'SELECT article.id AS article_id \nFROM article'
|
assert str(query) == 'SELECT article.id AS article_id \nFROM article'
|
||||||
|
|
||||||
def test_supports_properties_as_class_attributes(self):
|
|
||||||
query = self.session.query(self.Article)
|
|
||||||
query = defer_except(query, [self.Article.id])
|
|
||||||
assert str(query) == 'SELECT article.id AS article_id \nFROM article'
|
|
||||||
|
@@ -29,6 +29,14 @@ class TestSortQuery(TestCase):
|
|||||||
with raises(QuerySorterException):
|
with raises(QuerySorterException):
|
||||||
sort_query(query, '-unknown', silent=False)
|
sort_query(query, '-unknown', silent=False)
|
||||||
|
|
||||||
|
def test_join(self):
|
||||||
|
query = (
|
||||||
|
self.session.query(self.Article)
|
||||||
|
.join(self.Article.category)
|
||||||
|
)
|
||||||
|
query = sort_query(query, 'name', silent=False)
|
||||||
|
assert 'ORDER BY article.name ASC' in str(query)
|
||||||
|
|
||||||
def test_calculated_value_ascending(self):
|
def test_calculated_value_ascending(self):
|
||||||
query = self.session.query(
|
query = self.session.query(
|
||||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||||
|
Reference in New Issue
Block a user