Fixed an issue with sort_query + joins, made defer_except not accept relationships
This commit is contained in:
		@@ -2,6 +2,8 @@ import six
 | 
			
		||||
from sqlalchemy import inspect
 | 
			
		||||
from sqlalchemy.orm import defer
 | 
			
		||||
from sqlalchemy.orm.properties import ColumnProperty
 | 
			
		||||
from sqlalchemy.orm.query import _ColumnEntity
 | 
			
		||||
from sqlalchemy.orm.mapper import Mapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def property_names(properties):
 | 
			
		||||
@@ -10,30 +12,46 @@ def property_names(properties):
 | 
			
		||||
        if isinstance(property_, six.string_types):
 | 
			
		||||
            names.append(property_)
 | 
			
		||||
        else:
 | 
			
		||||
            names.append(property_.key)
 | 
			
		||||
            names.append(
 | 
			
		||||
                '%s.%s' % (
 | 
			
		||||
                    property_.class_.__name__,
 | 
			
		||||
                    property_.key
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
    return names
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def defer_except(query, properties):
 | 
			
		||||
def query_entities(query):
 | 
			
		||||
    entities = []
 | 
			
		||||
    for entity in query._entities:
 | 
			
		||||
        if not isinstance(entity, _ColumnEntity):
 | 
			
		||||
            entities.append(entity.entity_zero.class_)
 | 
			
		||||
 | 
			
		||||
    for entity in query._join_entities:
 | 
			
		||||
        if isinstance(entity, Mapper):
 | 
			
		||||
            entities.append(entity.class_)
 | 
			
		||||
        else:
 | 
			
		||||
            entities.append(entity)
 | 
			
		||||
    return entities
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def defer_except(query, columns):
 | 
			
		||||
    """
 | 
			
		||||
    Deferred loads all properties in given query, except the ones given.
 | 
			
		||||
    Deferred loads all columns in given query, except the ones given.
 | 
			
		||||
 | 
			
		||||
    This function is very useful when working with models with myriad of
 | 
			
		||||
    properties and you want to deferred load many properties.
 | 
			
		||||
    columns and you want to deferred load many columns.
 | 
			
		||||
 | 
			
		||||
        >>> from sqlalchemy_utils import defer_except
 | 
			
		||||
        >>> query = session.query(Article)
 | 
			
		||||
        >>> query = defer_except(Article, [Article.id, Article.name])
 | 
			
		||||
 | 
			
		||||
    :param query: SQLAlchemy Query object to apply the deferred loading to
 | 
			
		||||
    :param properties: properties not to deferred load
 | 
			
		||||
    :param columns: columns not to deferred load
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    allowed_names = property_names(properties)
 | 
			
		||||
 | 
			
		||||
    model = query._entities[0].entity_zero.class_
 | 
			
		||||
    for property_ in inspect(model).attrs:
 | 
			
		||||
        if isinstance(property_, ColumnProperty):
 | 
			
		||||
            if property_.key not in allowed_names:
 | 
			
		||||
            column = property_.columns[0]
 | 
			
		||||
            if column.name not in columns:
 | 
			
		||||
                query = query.options(defer(property_.key))
 | 
			
		||||
    return query
 | 
			
		||||
 
 | 
			
		||||
@@ -20,17 +20,6 @@ def sort_expression(expr, attr_name):
 | 
			
		||||
        return getattr(expr, attr_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def matches_entity(alias, entity):
 | 
			
		||||
    if not alias:
 | 
			
		||||
        return True
 | 
			
		||||
    if isinstance(entity, AliasedInsp):
 | 
			
		||||
        name = entity.name
 | 
			
		||||
    else:
 | 
			
		||||
        name = entity.__table__.name
 | 
			
		||||
 | 
			
		||||
    return name == alias
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QuerySorterException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
@@ -60,6 +49,19 @@ class QuerySorter(object):
 | 
			
		||||
            else:
 | 
			
		||||
                self.entities.append(mapper)
 | 
			
		||||
 | 
			
		||||
    def get_entity_by_alias(self, alias):
 | 
			
		||||
        if not alias:
 | 
			
		||||
            return self.entities[0]
 | 
			
		||||
 | 
			
		||||
        for entity in self.entities:
 | 
			
		||||
            if isinstance(entity, AliasedInsp):
 | 
			
		||||
                name = entity.name
 | 
			
		||||
            else:
 | 
			
		||||
                name = entity.__table__.name
 | 
			
		||||
 | 
			
		||||
            if name == alias:
 | 
			
		||||
                return entity
 | 
			
		||||
 | 
			
		||||
    def assign_order_by(self, sort):
 | 
			
		||||
        if not sort:
 | 
			
		||||
            return self.query
 | 
			
		||||
@@ -69,10 +71,8 @@ class QuerySorter(object):
 | 
			
		||||
        if sort['attr'] in self.labels:
 | 
			
		||||
            expr = sort['attr']
 | 
			
		||||
        else:
 | 
			
		||||
            for entity in self.entities:
 | 
			
		||||
                if not matches_entity(sort['entity'], entity):
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
            entity = self.get_entity_by_alias(sort['entity'])
 | 
			
		||||
            if entity:
 | 
			
		||||
                expr = self.order_by_attr(entity, sort['attr'])
 | 
			
		||||
 | 
			
		||||
        if expr is not None:
 | 
			
		||||
 
 | 
			
		||||
@@ -7,8 +7,3 @@ class TestDeferExcept(TestCase):
 | 
			
		||||
        query = self.session.query(self.Article)
 | 
			
		||||
        query = defer_except(query, ['id'])
 | 
			
		||||
        assert str(query) == 'SELECT article.id AS article_id \nFROM article'
 | 
			
		||||
 | 
			
		||||
    def test_supports_properties_as_class_attributes(self):
 | 
			
		||||
        query = self.session.query(self.Article)
 | 
			
		||||
        query = defer_except(query, [self.Article.id])
 | 
			
		||||
        assert str(query) == 'SELECT article.id AS article_id \nFROM article'
 | 
			
		||||
 
 | 
			
		||||
@@ -29,6 +29,14 @@ class TestSortQuery(TestCase):
 | 
			
		||||
        with raises(QuerySorterException):
 | 
			
		||||
            sort_query(query, '-unknown', silent=False)
 | 
			
		||||
 | 
			
		||||
    def test_join(self):
 | 
			
		||||
        query = (
 | 
			
		||||
            self.session.query(self.Article)
 | 
			
		||||
            .join(self.Article.category)
 | 
			
		||||
        )
 | 
			
		||||
        query = sort_query(query, 'name', silent=False)
 | 
			
		||||
        assert 'ORDER BY article.name ASC' in str(query)
 | 
			
		||||
 | 
			
		||||
    def test_calculated_value_ascending(self):
 | 
			
		||||
        query = self.session.query(
 | 
			
		||||
            self.Category, sa.func.count(self.Article.id).label('articles')
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user