diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index cb8a1bb..ef8ac7d 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -471,28 +471,34 @@ def get_query_entities(query): :param query: SQLAlchemy Query object """ + exprs = [ + d['expr'] + if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column) + else d['entity'] + for d in query.column_descriptions + ] return [ - get_query_entity(entity) for entity in - chain(query._entities, query._join_entities) + get_query_entity(expr) for expr in exprs + ] + [ + get_query_entity(entity) for entity in query._join_entities ] -def get_query_entity(mixed): - if hasattr(mixed, 'expr'): - expr = mixed.expr - else: - expr = mixed +def is_labeled_query(expr): + return ( + isinstance(expr, sa.sql.elements.Label) and + isinstance( + list(expr.base_columns)[0], + (sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect) + ) + ) + + +def get_query_entity(expr): if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): return expr.parent.class_ elif isinstance(expr, sa.Column): return expr.table - elif isinstance(expr, sa.sql.expression.Label): - if mixed.entity_zero: - return mixed.entity_zero - else: - return expr - elif isinstance(expr, sa.orm.Mapper): - return expr.class_ elif isinstance(expr, AliasedInsp): return expr.entity return expr @@ -561,7 +567,7 @@ def get_descriptor(entity, attr): # Handle synonyms, relationship properties and hybrid # properties try: - return getattr(entity, attr) + return getattr(mapper.class_, attr) except AttributeError: pass diff --git a/tests/functions/test_get_query_entities.py b/tests/functions/test_get_query_entities.py index e9a37e1..f1ae7c9 100644 --- a/tests/functions/test_get_query_entities.py +++ b/tests/functions/test_get_query_entities.py @@ -41,31 +41,31 @@ class TestGetQueryEntities(TestCase): def test_mapper(self): query = self.session.query(sa.inspect(self.TextItem)) - assert list(get_query_entities(query)) == [self.TextItem] + assert get_query_entities(query) == [self.TextItem] def test_entity(self): query = self.session.query(self.TextItem) - assert list(get_query_entities(query)) == [self.TextItem] + assert get_query_entities(query) == [self.TextItem] def test_instrumented_attribute(self): query = self.session.query(self.TextItem.id) - assert list(get_query_entities(query)) == [self.TextItem] + assert get_query_entities(query) == [self.TextItem] def test_column(self): query = self.session.query(self.TextItem.__table__.c.id) - assert list(get_query_entities(query)) == [self.TextItem.__table__] + assert get_query_entities(query) == [self.TextItem.__table__] def test_aliased_selectable(self): selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) query = self.session.query(selectable) - assert list(get_query_entities(query)) == [selectable] + assert get_query_entities(query) == [selectable] def test_joined_entity(self): query = self.session.query(self.TextItem).join( self.BlogPost, self.BlogPost.id == self.TextItem.id ) - assert list(get_query_entities(query)) == [ - self.TextItem, self.BlogPost + assert get_query_entities(query) == [ + self.TextItem, sa.inspect(self.BlogPost) ] def test_joined_aliased_entity(self): @@ -74,11 +74,11 @@ class TestGetQueryEntities(TestCase): query = self.session.query(self.TextItem).join( alias, alias.id == self.TextItem.id ) - assert list(get_query_entities(query)) == [self.TextItem, alias] + assert get_query_entities(query) == [self.TextItem, alias] def test_column_entity_with_label(self): query = self.session.query(self.Article.id.label('id')) - assert list(get_query_entities(query)) == [sa.inspect(self.Article)] + assert get_query_entities(query) == [self.Article] def test_with_subquery(self): number_of_articles = ( @@ -91,7 +91,7 @@ class TestGetQueryEntities(TestCase): ).label('number_of_articles') query = self.session.query(self.Article, number_of_articles) - assert list(get_query_entities(query)) == [ + assert get_query_entities(query) == [ self.Article, number_of_articles ] @@ -99,4 +99,4 @@ class TestGetQueryEntities(TestCase): def test_aliased_entity(self): alias = sa.orm.aliased(self.Article) query = self.session.query(alias) - assert list(get_query_entities(query)) == [alias] + assert get_query_entities(query) == [alias]