diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 92705db..3c13058 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -12,6 +12,7 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import mapperlib from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.exc import UnmappedInstanceError +from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp @@ -197,13 +198,7 @@ def get_tables(mixed): SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping any of these objects. """ - if isinstance(mixed, sa.orm.util.AliasedClass): - mapper = sa.inspect(mixed).mapper - else: - if not isclass(mixed): - mixed = mixed.__class__ - mapper = sa.inspect(mixed) - return mapper.tables + return get_mapper(mixed).tables def get_columns(mixed): @@ -414,17 +409,19 @@ def get_query_entities(query): :param query: SQLAlchemy Query object """ - return list( - map(get_selectable, chain(query._entities, query._join_entities)) - ) + return [ + get_query_entity(entity) for entity in + chain(query._entities, query._join_entities) + ] -def get_selectable(mixed): - if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)): - return mixed - expr = mixed.expr +def get_query_entity(mixed): + if hasattr(mixed, 'expr'): + expr = mixed.expr + else: + expr = mixed if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): - return expr.parent + return expr.parent.class_ elif isinstance(expr, sa.Column): return expr.table elif isinstance(expr, sa.sql.expression.Label): @@ -432,17 +429,22 @@ def get_selectable(mixed): return mixed.entity_zero else: return expr + elif isinstance(expr, sa.orm.Mapper): + return expr.class_ + elif isinstance(expr, AliasedInsp): + return expr.entity return expr def get_query_entity_by_alias(query, alias): entities = get_query_entities(query) + if not alias: return entities[0] for entity in entities: - if isinstance(entity, AliasedInsp): - name = entity.name + if isinstance(entity, sa.orm.util.AliasedClass): + name = sa.inspect(entity).name else: name = get_mapper(entity).tables[0].name @@ -457,17 +459,61 @@ def get_polymorphic_mappers(mixed): return mixed.polymorphic_map.values() -def get_attrs(expr): - insp = sa.inspect(expr) - mapper = get_mapper(expr) - polymorphic_mappers = get_polymorphic_mappers(insp) +def get_query_descriptor(query, entity, attr): + if attr in query_labels(query): + return attr + else: + entity = get_query_entity_by_alias(query, entity) + if entity: + descriptor = get_descriptor(entity, attr) + if ( + hasattr(descriptor, 'property') and + isinstance(descriptor.property, sa.orm.RelationshipProperty) + ): + return + return descriptor + +def get_descriptor(entity, attr): + mapper = sa.inspect(entity) + + for key, descriptor in get_all_descriptors(mapper).items(): + if attr == key: + prop = ( + descriptor.property + if hasattr(descriptor, 'property') + else None + ) + if isinstance(prop, ColumnProperty): + if isinstance(entity, sa.orm.util.AliasedClass): + for c in mapper.selectable.c: + if c.key == attr: + return c + else: + # If the property belongs to a class that uses + # polymorphic inheritance we have to take into account + # situations where the attribute exists in child class + # but not in parent class. + return getattr(prop.parent.class_, attr) + else: + # Handle synonyms, relationship proeprties and hybrid + # properties + try: + return getattr(entity, attr) + except AttributeError: + pass + + +def get_all_descriptors(expr): + insp = sa.inspect(expr) + polymorphic_mappers = get_polymorphic_mappers(insp) if polymorphic_mappers: + attrs = {} for submapper in polymorphic_mappers: - attrs.update(submapper.attrs) + attrs.update(submapper.all_orm_descriptors) return attrs - return mapper.attrs + return get_mapper(expr).all_orm_descriptors def get_hybrid_properties(model): @@ -521,13 +567,6 @@ def get_hybrid_properties(model): ) -def get_expr_attr(expr, prop): - if isinstance(expr, AliasedInsp): - return getattr(expr.selectable.c, prop.key) - else: - return getattr(prop.parent.class_, prop.key) - - def get_declarative_base(model): """ Returns the declarative base for given model class. diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index bd2656a..1fe25e6 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -1,15 +1,5 @@ -import sqlalchemy as sa -from sqlalchemy.orm.properties import ColumnProperty, SynonymProperty -from sqlalchemy.sql.expression import desc, asc, Label -from sqlalchemy.orm.util import AliasedInsp -from .orm import ( - get_attrs, - get_expr_attr, - get_hybrid_properties, - get_query_entity_by_alias, - get_query_entities, - query_labels, -) +from sqlalchemy.sql.expression import desc, asc +from .orm import get_query_descriptor class QuerySorterException(Exception): @@ -18,18 +8,11 @@ class QuerySorterException(Exception): class QuerySorter(object): def __init__(self, silent=True, separator='-'): - self.labels = [] self.separator = separator self.silent = silent def assign_order_by(self, entity, attr, func): - expr = None - if attr in self.labels: - expr = attr - else: - entity = get_query_entity_by_alias(self.query, entity) - if entity: - expr = self.order_by_attr(entity, attr) + expr = get_query_descriptor(self.query, entity, attr) if expr is not None: return self.query.order_by(func(expr)) @@ -39,30 +22,6 @@ class QuerySorter(object): ) return self.query - def order_by_attr(self, entity, attr): - properties = get_attrs(entity) - if attr in properties: - property_ = properties[attr] - - if isinstance(property_, ColumnProperty): - if isinstance(property_.columns[0], Label): - return getattr(entity, property_.key) - else: - return get_expr_attr(entity, property_) - elif isinstance(property_, SynonymProperty): - return get_expr_attr(entity, property_) - return - - mapper = sa.inspect(entity) - entity = mapper.entity - - if isinstance(mapper, AliasedInsp): - mapper = mapper.mapper - - for key in get_hybrid_properties(mapper).keys(): - if attr == key: - return getattr(entity, attr) - def parse_sort_arg(self, arg): if arg[0] == self.separator: func = desc @@ -79,7 +38,6 @@ class QuerySorter(object): def __call__(self, query, *args): self.query = query - self.labels = query_labels(query) for sort in args: if not sort: diff --git a/tests/functions/test_get_query_entities.py b/tests/functions/test_get_query_entities.py index 08dd5c4..398ee00 100644 --- a/tests/functions/test_get_query_entities.py +++ b/tests/functions/test_get_query_entities.py @@ -41,7 +41,7 @@ class TestGetQueryEntities(TestCase): def test_mapper(self): query = self.session.query(sa.inspect(self.TextItem)) - assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)] + assert list(get_query_entities(query)) == [self.TextItem] def test_entity(self): query = self.session.query(self.TextItem) @@ -49,7 +49,7 @@ class TestGetQueryEntities(TestCase): def test_instrumented_attribute(self): query = self.session.query(self.TextItem.id) - assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)] + assert list(get_query_entities(query)) == [self.TextItem] def test_column(self): query = self.session.query(self.TextItem.__table__.c.id) @@ -65,7 +65,7 @@ class TestGetQueryEntities(TestCase): self.BlogPost, self.BlogPost.id == self.TextItem.id ) assert list(get_query_entities(query)) == [ - self.TextItem, sa.inspect(self.BlogPost) + self.TextItem, self.BlogPost ] def test_joined_aliased_entity(self): @@ -74,9 +74,7 @@ class TestGetQueryEntities(TestCase): query = self.session.query(self.TextItem).join( alias, alias.id == self.TextItem.id ) - assert list(get_query_entities(query)) == [ - self.TextItem, sa.inspect(alias) - ] + assert list(get_query_entities(query)) == [self.TextItem, alias] def test_column_entity_with_label(self): query = self.session.query(self.Article.id.label('id')) diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index 0ca2c8a..1b8154a 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -156,6 +156,37 @@ class TestSortQuery(TestCase): query = sort_query(query, 'some_hybrid') assert_contains('ORDER BY article.name ASC', query) + def test_with_mapper_and_column_property(self): + class Apple(self.Base): + __tablename__ = 'apple' + id = sa.Column(sa.Integer, primary_key=True) + article_id = sa.Column(sa.Integer, sa.ForeignKey(self.Article.id)) + + self.Article.apples = sa.orm.relationship(Apple) + + self.Article.apple_count = sa.orm.column_property( + sa.select([sa.func.count(Apple.id)]) + .where(Apple.article_id == self.Article.id) + .correlate(self.Article.__table__) + .label('apple_count'), + deferred=True + ) + query = ( + self.session.query(sa.inspect(self.Article)) + .outerjoin(self.Article.apples) + .options( + sa.orm.undefer(self.Article.apple_count) + ) + .options(sa.orm.contains_eager(self.Article.apples)) + ) + query = sort_query(query, 'apple_count') + assert 'ORDER BY apple_count' in str(query) + + def test_table(self): + query = self.session.query(self.Article.__table__) + query = sort_query(query, 'name') + assert_contains('ORDER BY name', query) + class TestSortQueryRelationshipCounts(TestCase): """