Add support for various polymorphic scenarios in sort_query
- Rename query_entities to get_query_entities - Add tests for get_query_entities
This commit is contained in:
		@@ -4,10 +4,11 @@ Changelog
 | 
			
		||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
0.26.9 (2014-08-xx)
 | 
			
		||||
0.26.9 (2014-08-06)
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
- Fixed PasswordType with Oracle dialect
 | 
			
		||||
- Added support for sort_query and attributes on mappers using with_polymorphic
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
0.26.8 (2014-07-30)
 | 
			
		||||
 
 | 
			
		||||
@@ -46,6 +46,12 @@ get_mapper
 | 
			
		||||
.. autofunction:: get_mapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
get_query_entities
 | 
			
		||||
^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
.. autofunction:: get_query_entities
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
get_primary_keys
 | 
			
		||||
^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
@@ -58,12 +64,6 @@ get_tables
 | 
			
		||||
.. autofunction:: get_tables
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
query_entities
 | 
			
		||||
^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
.. autofunction:: query_entities
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
has_changes
 | 
			
		||||
^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,7 @@ from .functions import (
 | 
			
		||||
    get_declarative_base,
 | 
			
		||||
    get_hybrid_properties,
 | 
			
		||||
    get_mapper,
 | 
			
		||||
    get_query_entities,
 | 
			
		||||
    get_primary_keys,
 | 
			
		||||
    get_referencing_foreign_keys,
 | 
			
		||||
    get_tables,
 | 
			
		||||
@@ -99,6 +100,7 @@ __all__ = (
 | 
			
		||||
    get_declarative_base,
 | 
			
		||||
    get_hybrid_properties,
 | 
			
		||||
    get_mapper,
 | 
			
		||||
    get_query_entities,
 | 
			
		||||
    get_primary_keys,
 | 
			
		||||
    get_referencing_foreign_keys,
 | 
			
		||||
    get_tables,
 | 
			
		||||
 
 | 
			
		||||
@@ -26,13 +26,13 @@ from .orm import (
 | 
			
		||||
    get_hybrid_properties,
 | 
			
		||||
    get_mapper,
 | 
			
		||||
    get_primary_keys,
 | 
			
		||||
    get_query_entities,
 | 
			
		||||
    get_tables,
 | 
			
		||||
    getdotattr,
 | 
			
		||||
    has_any_changes,
 | 
			
		||||
    has_changes,
 | 
			
		||||
    identity,
 | 
			
		||||
    naturally_equivalent,
 | 
			
		||||
    query_entities,
 | 
			
		||||
    table_name,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -49,6 +49,7 @@ __all__ = (
 | 
			
		||||
    'get_declarative_base',
 | 
			
		||||
    'get_hybrid_properties',
 | 
			
		||||
    'get_mapper',
 | 
			
		||||
    'get_query_entities',
 | 
			
		||||
    'get_primary_keys',
 | 
			
		||||
    'get_referencing_foreign_keys',
 | 
			
		||||
    'get_tables',
 | 
			
		||||
 
 | 
			
		||||
@@ -4,15 +4,14 @@ except ImportError:
 | 
			
		||||
    from ordereddict import OrderedDict
 | 
			
		||||
from functools import partial
 | 
			
		||||
from inspect import isclass
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from operator import attrgetter
 | 
			
		||||
import six
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from sqlalchemy import inspect
 | 
			
		||||
from sqlalchemy.ext.hybrid import hybrid_property
 | 
			
		||||
from sqlalchemy.orm import mapperlib
 | 
			
		||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
 | 
			
		||||
from sqlalchemy.orm.exc import UnmappedInstanceError
 | 
			
		||||
from sqlalchemy.orm.mapper import Mapper
 | 
			
		||||
from sqlalchemy.orm.query import _ColumnEntity
 | 
			
		||||
from sqlalchemy.orm.session import object_session
 | 
			
		||||
from sqlalchemy.orm.util import AliasedInsp
 | 
			
		||||
@@ -81,6 +80,8 @@ def get_mapper(mixed):
 | 
			
		||||
        return sa.inspect(mixed).mapper
 | 
			
		||||
    if isinstance(mixed, sa.sql.selectable.Alias):
 | 
			
		||||
        mixed = mixed.element
 | 
			
		||||
    if isinstance(mixed, AliasedInsp):
 | 
			
		||||
        return mixed.mapper
 | 
			
		||||
    if isinstance(mixed, sa.Table):
 | 
			
		||||
        mappers = [
 | 
			
		||||
            mapper for mapper in mapperlib._mapper_registry
 | 
			
		||||
@@ -344,31 +345,31 @@ def query_labels(query):
 | 
			
		||||
            db.func.count(Article.id).label('articles')
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        query_labels(query)  # ('articles', )
 | 
			
		||||
        query_labels(query)  # ['articles']
 | 
			
		||||
 | 
			
		||||
    :param query: SQLAlchemy Query object
 | 
			
		||||
    """
 | 
			
		||||
    for entity in query._entities:
 | 
			
		||||
        if isinstance(entity, _ColumnEntity) and entity._label_name:
 | 
			
		||||
            yield entity._label_name
 | 
			
		||||
    return [
 | 
			
		||||
        entity._label_name for entity in query._entities
 | 
			
		||||
        if isinstance(entity, _ColumnEntity) and entity._label_name
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def query_entities(query):
 | 
			
		||||
def get_query_entities(query):
 | 
			
		||||
    """
 | 
			
		||||
    Return a generator that iterates through all entities for given SQLAlchemy
 | 
			
		||||
    query object.
 | 
			
		||||
    Return a list of all entities present in given SQLAlchemy query object.
 | 
			
		||||
 | 
			
		||||
    Examples::
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        query = session.query(Category)
 | 
			
		||||
 | 
			
		||||
        query_entities(query)  # <Category>
 | 
			
		||||
        query_entities(query)  # [<Category>]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        query = session.query(Category.id)
 | 
			
		||||
 | 
			
		||||
        query_entities(query)  # <Category>
 | 
			
		||||
        query_entities(query)  # [<Category>]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    This function also supports queries with joins.
 | 
			
		||||
@@ -378,42 +379,66 @@ def query_entities(query):
 | 
			
		||||
 | 
			
		||||
        query = session.query(Category).join(Article)
 | 
			
		||||
 | 
			
		||||
        query_entities(query)  # (<Category>, <Article>)
 | 
			
		||||
        query_entities(query)  # [<Category>, <Article>]
 | 
			
		||||
 | 
			
		||||
    .. versionchanged: 0.26.7
 | 
			
		||||
        This function now returns a list instead of generator
 | 
			
		||||
 | 
			
		||||
    :param query: SQLAlchemy Query object
 | 
			
		||||
    """
 | 
			
		||||
    for entity in query._entities:
 | 
			
		||||
        if entity.entity_zero:
 | 
			
		||||
            yield entity.entity_zero.class_
 | 
			
		||||
 | 
			
		||||
    for entity in query._join_entities:
 | 
			
		||||
        if isinstance(entity, Mapper):
 | 
			
		||||
            yield entity.class_
 | 
			
		||||
        else:
 | 
			
		||||
            yield entity
 | 
			
		||||
    def get_expr(mixed):
 | 
			
		||||
        if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
 | 
			
		||||
            return mixed
 | 
			
		||||
        expr = mixed.expr
 | 
			
		||||
        if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
 | 
			
		||||
            expr = expr.parent
 | 
			
		||||
        elif isinstance(expr, sa.Column):
 | 
			
		||||
            expr = expr.table
 | 
			
		||||
        elif isinstance(expr, sa.sql.expression.Label):
 | 
			
		||||
            if mixed.entity_zero:
 | 
			
		||||
                return mixed.entity_zero
 | 
			
		||||
            else:
 | 
			
		||||
                return expr
 | 
			
		||||
        return expr
 | 
			
		||||
    return [
 | 
			
		||||
        get_expr(entity) for entity in
 | 
			
		||||
        chain(query._entities, query._join_entities)
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_query_entity_by_alias(query, alias):
 | 
			
		||||
    entities = query_entities(query)
 | 
			
		||||
    entities = get_query_entities(query)
 | 
			
		||||
    if not alias:
 | 
			
		||||
        return list(entities)[0]
 | 
			
		||||
        return entities[0]
 | 
			
		||||
 | 
			
		||||
    for entity in entities:
 | 
			
		||||
        if isinstance(entity, AliasedInsp):
 | 
			
		||||
            name = entity.name
 | 
			
		||||
        else:
 | 
			
		||||
            name = entity.__table__.name
 | 
			
		||||
            name = get_mapper(entity).tables[0].name
 | 
			
		||||
 | 
			
		||||
        if name == alias:
 | 
			
		||||
            return entity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_attrs(expr):
 | 
			
		||||
    if isinstance(expr, AliasedInsp):
 | 
			
		||||
        return expr.mapper.attrs
 | 
			
		||||
def get_polymorphic_mappers(mixed):
 | 
			
		||||
    if isinstance(mixed, AliasedInsp):
 | 
			
		||||
        return mixed.with_polymorphic_mappers
 | 
			
		||||
    else:
 | 
			
		||||
        return inspect(expr).attrs
 | 
			
		||||
        return mixed.polymorphic_map.values()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_attrs(expr):
 | 
			
		||||
    insp = sa.inspect(expr)
 | 
			
		||||
    mapper = get_mapper(expr)
 | 
			
		||||
    polymorphic_mappers = get_polymorphic_mappers(insp)
 | 
			
		||||
 | 
			
		||||
    if polymorphic_mappers:
 | 
			
		||||
        attrs = {}
 | 
			
		||||
        for submapper in polymorphic_mappers:
 | 
			
		||||
            attrs.update(submapper.attrs)
 | 
			
		||||
        return attrs
 | 
			
		||||
    return mapper.attrs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_hybrid_properties(model):
 | 
			
		||||
@@ -467,11 +492,11 @@ def get_hybrid_properties(model):
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_expr_attr(expr, attr_name):
 | 
			
		||||
def get_expr_attr(expr, prop):
 | 
			
		||||
    if isinstance(expr, AliasedInsp):
 | 
			
		||||
        return getattr(expr.selectable.c, attr_name)
 | 
			
		||||
        return getattr(expr.selectable.c, prop.key)
 | 
			
		||||
    else:
 | 
			
		||||
        return getattr(expr, attr_name)
 | 
			
		||||
        return getattr(prop.parent.class_, prop.key)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_declarative_base(model):
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ from .orm import (
 | 
			
		||||
    get_expr_attr,
 | 
			
		||||
    get_hybrid_properties,
 | 
			
		||||
    get_query_entity_by_alias,
 | 
			
		||||
    query_entities,
 | 
			
		||||
    get_query_entities,
 | 
			
		||||
    query_labels,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -18,7 +18,6 @@ class QuerySorterException(Exception):
 | 
			
		||||
 | 
			
		||||
class QuerySorter(object):
 | 
			
		||||
    def __init__(self, silent=True, separator='-'):
 | 
			
		||||
        self.entities = []
 | 
			
		||||
        self.labels = []
 | 
			
		||||
        self.separator = separator
 | 
			
		||||
        self.silent = silent
 | 
			
		||||
@@ -44,20 +43,21 @@ class QuerySorter(object):
 | 
			
		||||
        properties = get_attrs(entity)
 | 
			
		||||
        if attr in properties:
 | 
			
		||||
            property_ = properties[attr]
 | 
			
		||||
 | 
			
		||||
            if isinstance(property_, ColumnProperty):
 | 
			
		||||
                if isinstance(property_.columns[0], Label):
 | 
			
		||||
                    return getattr(entity, property_.key)
 | 
			
		||||
                else:
 | 
			
		||||
                    return get_expr_attr(entity, property_.key)
 | 
			
		||||
                    return get_expr_attr(entity, property_)
 | 
			
		||||
            elif isinstance(property_, SynonymProperty):
 | 
			
		||||
                return get_expr_attr(entity, property_.key)
 | 
			
		||||
                return get_expr_attr(entity, property_)
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        mapper = sa.inspect(entity)
 | 
			
		||||
        entity = mapper.entity
 | 
			
		||||
 | 
			
		||||
        if isinstance(mapper, AliasedInsp):
 | 
			
		||||
            mapper = mapper.mapper
 | 
			
		||||
            entity = mapper.entity
 | 
			
		||||
 | 
			
		||||
        for key in get_hybrid_properties(mapper).keys():
 | 
			
		||||
            if attr == key:
 | 
			
		||||
@@ -80,7 +80,6 @@ class QuerySorter(object):
 | 
			
		||||
    def __call__(self, query, *args):
 | 
			
		||||
        self.query = query
 | 
			
		||||
        self.labels = query_labels(query)
 | 
			
		||||
        self.entities = query_entities(query)
 | 
			
		||||
 | 
			
		||||
        for sort in args:
 | 
			
		||||
            if not sort:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										101
									
								
								tests/functions/test_get_query_entities.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								tests/functions/test_get_query_entities.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,101 @@
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from sqlalchemy_utils import get_query_entities
 | 
			
		||||
 | 
			
		||||
from tests import TestCase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestGetQueryEntities(TestCase):
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class TextItem(self.Base):
 | 
			
		||||
            __tablename__ = 'text_item'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
 | 
			
		||||
            type = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_on': type,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        class Article(TextItem):
 | 
			
		||||
            __tablename__ = 'article'
 | 
			
		||||
            id = sa.Column(
 | 
			
		||||
                sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
 | 
			
		||||
            )
 | 
			
		||||
            category = sa.Column(sa.Unicode(255))
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_identity': u'article'
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        class BlogPost(TextItem):
 | 
			
		||||
            __tablename__ = 'blog_post'
 | 
			
		||||
            id = sa.Column(
 | 
			
		||||
                sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
 | 
			
		||||
            )
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_identity': u'blog_post'
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        self.TextItem = TextItem
 | 
			
		||||
        self.Article = Article
 | 
			
		||||
        self.BlogPost = BlogPost
 | 
			
		||||
 | 
			
		||||
    def test_mapper(self):
 | 
			
		||||
        query = self.session.query(sa.inspect(self.TextItem))
 | 
			
		||||
        assert get_query_entities(query) == [sa.inspect(self.TextItem)]
 | 
			
		||||
 | 
			
		||||
    def test_entity(self):
 | 
			
		||||
        query = self.session.query(self.TextItem)
 | 
			
		||||
        assert get_query_entities(query) == [self.TextItem]
 | 
			
		||||
 | 
			
		||||
    def test_instrumented_attribute(self):
 | 
			
		||||
        query = self.session.query(self.TextItem.id)
 | 
			
		||||
        assert get_query_entities(query) == [sa.inspect(self.TextItem)]
 | 
			
		||||
 | 
			
		||||
    def test_column(self):
 | 
			
		||||
        query = self.session.query(self.TextItem.__table__.c.id)
 | 
			
		||||
        assert get_query_entities(query) == [self.TextItem.__table__]
 | 
			
		||||
 | 
			
		||||
    def test_aliased_selectable(self):
 | 
			
		||||
        selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
 | 
			
		||||
        query = self.session.query(selectable)
 | 
			
		||||
        assert get_query_entities(query) == [selectable]
 | 
			
		||||
 | 
			
		||||
    def test_joined_entity(self):
 | 
			
		||||
        query = self.session.query(self.TextItem).join(
 | 
			
		||||
            self.BlogPost, self.BlogPost.id == self.TextItem.id
 | 
			
		||||
        )
 | 
			
		||||
        assert get_query_entities(query) == [
 | 
			
		||||
            self.TextItem, sa.inspect(self.BlogPost)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def test_joined_aliased_entity(self):
 | 
			
		||||
        alias = sa.orm.aliased(self.BlogPost)
 | 
			
		||||
 | 
			
		||||
        query = self.session.query(self.TextItem).join(
 | 
			
		||||
            alias, alias.id == self.TextItem.id
 | 
			
		||||
        )
 | 
			
		||||
        assert get_query_entities(query) == [
 | 
			
		||||
            self.TextItem, sa.inspect(alias)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def test_column_entity_with_label(self):
 | 
			
		||||
        query = self.session.query(self.Article.id.label('id'))
 | 
			
		||||
        assert get_query_entities(query) == [sa.inspect(self.Article)]
 | 
			
		||||
 | 
			
		||||
    def test_with_subquery(self):
 | 
			
		||||
        number_of_articles = (
 | 
			
		||||
            sa.select(
 | 
			
		||||
                [sa.func.count(self.Article.id)],
 | 
			
		||||
            )
 | 
			
		||||
            .select_from(
 | 
			
		||||
                self.Article.__table__
 | 
			
		||||
            )
 | 
			
		||||
        ).label('number_of_articles')
 | 
			
		||||
 | 
			
		||||
        query = self.session.query(self.Article, number_of_articles)
 | 
			
		||||
        assert get_query_entities(query) == [self.Article, number_of_articles]
 | 
			
		||||
 | 
			
		||||
    def test_aliased_entity(self):
 | 
			
		||||
        alias = sa.orm.aliased(self.Article)
 | 
			
		||||
        query = self.session.query(alias)
 | 
			
		||||
        assert get_query_entities(query) == [alias]
 | 
			
		||||
@@ -1,74 +0,0 @@
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from tests import TestCase
 | 
			
		||||
from sqlalchemy_utils.functions import query_entities
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestQueryEntities(TestCase):
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class User(self.Base):
 | 
			
		||||
            __tablename__ = 'user'
 | 
			
		||||
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
        class Category(self.Base):
 | 
			
		||||
            __tablename__ = 'category'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
        class Article(self.Base):
 | 
			
		||||
            __tablename__ = 'article'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
            author_id = sa.Column(
 | 
			
		||||
                sa.Integer, sa.ForeignKey(User.id), index=True
 | 
			
		||||
            )
 | 
			
		||||
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
 | 
			
		||||
 | 
			
		||||
            category = sa.orm.relationship(
 | 
			
		||||
                Category,
 | 
			
		||||
                primaryjoin=category_id == Category.id,
 | 
			
		||||
                backref=sa.orm.backref(
 | 
			
		||||
                    'articles',
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.User = User
 | 
			
		||||
        self.Category = Category
 | 
			
		||||
        self.Article = Article
 | 
			
		||||
 | 
			
		||||
    def test_simple_query(self):
 | 
			
		||||
        query = self.session.query(self.User)
 | 
			
		||||
        assert list(query_entities(query)) == [self.User]
 | 
			
		||||
 | 
			
		||||
    def test_column_entity(self):
 | 
			
		||||
        query = self.session.query(self.User.id)
 | 
			
		||||
        assert list(query_entities(query)) == [self.User]
 | 
			
		||||
 | 
			
		||||
    def test_column_entity_with_label(self):
 | 
			
		||||
        query = self.session.query(self.User.id.label('id'))
 | 
			
		||||
        assert list(query_entities(query)) == [self.User]
 | 
			
		||||
 | 
			
		||||
    def test_with_subquery(self):
 | 
			
		||||
        number_of_sales = (
 | 
			
		||||
            sa.select(
 | 
			
		||||
                [sa.func.count(self.Article.id)],
 | 
			
		||||
            )
 | 
			
		||||
            .select_from(
 | 
			
		||||
                self.Article.__table__
 | 
			
		||||
            )
 | 
			
		||||
        ).label('number_of_articles')
 | 
			
		||||
 | 
			
		||||
        query = self.session.query(self.User, number_of_sales)
 | 
			
		||||
        assert list(query_entities(query)) == [self.User]
 | 
			
		||||
 | 
			
		||||
    def test_mapper(self):
 | 
			
		||||
        query = self.session.query(self.User.__mapper__)
 | 
			
		||||
        assert list(query_entities(query)) == [self.User]
 | 
			
		||||
 | 
			
		||||
    def test_joins(self):
 | 
			
		||||
        query = self.session.query(self.User.__mapper__).join(self.Article)
 | 
			
		||||
        assert list(query_entities(query)) == [self.User, self.Article]
 | 
			
		||||
 | 
			
		||||
    def test_aliased_entity(self):
 | 
			
		||||
        query = self.session.query(sa.orm.aliased(self.User))
 | 
			
		||||
        assert list(query_entities(query)) == [self.User]
 | 
			
		||||
@@ -210,6 +210,7 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
 | 
			
		||||
            id = sa.Column(
 | 
			
		||||
                sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
 | 
			
		||||
            )
 | 
			
		||||
            category = sa.Column(sa.Unicode(255))
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_identity': u'article'
 | 
			
		||||
            }
 | 
			
		||||
@@ -233,3 +234,69 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
 | 
			
		||||
            'item_count'
 | 
			
		||||
        )
 | 
			
		||||
        assert_contains('ORDER BY (SELECT count(:param_2) AS count_2', query)
 | 
			
		||||
 | 
			
		||||
    def test_child_class_attribute(self):
 | 
			
		||||
        query = sort_query(
 | 
			
		||||
            self.session.query(self.TextItem),
 | 
			
		||||
            'category'
 | 
			
		||||
        )
 | 
			
		||||
        assert_contains('ORDER BY article.category ASC', query)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSortQueryWithCustomPolymorphic(TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    Currently this doesn't work with SQLite
 | 
			
		||||
    """
 | 
			
		||||
    dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
 | 
			
		||||
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class TextItem(self.Base):
 | 
			
		||||
            __tablename__ = 'text_item'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
 | 
			
		||||
            type = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_on': type,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        class Article(TextItem):
 | 
			
		||||
            __tablename__ = 'article'
 | 
			
		||||
            id = sa.Column(
 | 
			
		||||
                sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
 | 
			
		||||
            )
 | 
			
		||||
            category = sa.Column(sa.Unicode(255))
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_identity': u'article'
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        class BlogPost(TextItem):
 | 
			
		||||
            __tablename__ = 'blog_post'
 | 
			
		||||
            id = sa.Column(
 | 
			
		||||
                sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
 | 
			
		||||
            )
 | 
			
		||||
            __mapper_args__ = {
 | 
			
		||||
                'polymorphic_identity': u'blog_post'
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        self.TextItem = TextItem
 | 
			
		||||
        self.Article = Article
 | 
			
		||||
        self.BlogPost = BlogPost
 | 
			
		||||
 | 
			
		||||
    def test_with_unknown_column(self):
 | 
			
		||||
        query = sort_query(
 | 
			
		||||
            self.session.query(
 | 
			
		||||
                sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
 | 
			
		||||
            ),
 | 
			
		||||
            'category'
 | 
			
		||||
        )
 | 
			
		||||
        assert 'ORDER BY' not in str(query)
 | 
			
		||||
 | 
			
		||||
    def test_with_existing_column(self):
 | 
			
		||||
        query = sort_query(
 | 
			
		||||
            self.session.query(
 | 
			
		||||
                sa.orm.with_polymorphic(self.TextItem, [self.Article])
 | 
			
		||||
            ),
 | 
			
		||||
            'category'
 | 
			
		||||
        )
 | 
			
		||||
        assert 'ORDER BY' in str(query)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user