diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index df83b92..7157815 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -362,14 +362,17 @@ def get_query_entities(query): Examples:: + from sqlalchemy_utils import get_query_entities + + query = session.query(Category) - query_entities(query) # [] + get_query_entities(query) # [] query = session.query(Category.id) - query_entities(query) # [] + get_query_entities(query) # [] This function also supports queries with joins. @@ -379,31 +382,32 @@ def get_query_entities(query): query = session.query(Category).join(Article) - query_entities(query) # [,
] + get_query_entities(query) # [,
] .. versionchanged: 0.26.7 This function now returns a list instead of generator :param query: SQLAlchemy Query object """ - def get_expr(mixed): - if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)): - return mixed - expr = mixed.expr - if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): - expr = expr.parent - elif isinstance(expr, sa.Column): - expr = expr.table - elif isinstance(expr, sa.sql.expression.Label): - if mixed.entity_zero: - return mixed.entity_zero - else: - return expr - return expr - return [ - get_expr(entity) for entity in - chain(query._entities, query._join_entities) - ] + return list( + map(get_selectable, chain(query._entities, query._join_entities)) + ) + + +def get_selectable(mixed): + if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)): + return mixed + expr = mixed.expr + if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): + return expr.parent + elif isinstance(expr, sa.Column): + return expr.table + elif isinstance(expr, sa.sql.expression.Label): + if mixed.entity_zero: + return mixed.entity_zero + else: + return expr + return expr def get_query_entity_by_alias(query, alias): diff --git a/tests/functions/test_get_query_entities.py b/tests/functions/test_get_query_entities.py index 45b9b4a..08dd5c4 100644 --- a/tests/functions/test_get_query_entities.py +++ b/tests/functions/test_get_query_entities.py @@ -41,30 +41,30 @@ class TestGetQueryEntities(TestCase): def test_mapper(self): query = self.session.query(sa.inspect(self.TextItem)) - assert get_query_entities(query) == [sa.inspect(self.TextItem)] + assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)] def test_entity(self): query = self.session.query(self.TextItem) - assert get_query_entities(query) == [self.TextItem] + assert list(get_query_entities(query)) == [self.TextItem] def test_instrumented_attribute(self): query = self.session.query(self.TextItem.id) - assert get_query_entities(query) == [sa.inspect(self.TextItem)] + assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)] def test_column(self): query = self.session.query(self.TextItem.__table__.c.id) - assert get_query_entities(query) == [self.TextItem.__table__] + assert list(get_query_entities(query)) == [self.TextItem.__table__] def test_aliased_selectable(self): selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) query = self.session.query(selectable) - assert get_query_entities(query) == [selectable] + assert list(get_query_entities(query)) == [selectable] def test_joined_entity(self): query = self.session.query(self.TextItem).join( self.BlogPost, self.BlogPost.id == self.TextItem.id ) - assert get_query_entities(query) == [ + assert list(get_query_entities(query)) == [ self.TextItem, sa.inspect(self.BlogPost) ] @@ -74,13 +74,13 @@ class TestGetQueryEntities(TestCase): query = self.session.query(self.TextItem).join( alias, alias.id == self.TextItem.id ) - assert get_query_entities(query) == [ + assert list(get_query_entities(query)) == [ self.TextItem, sa.inspect(alias) ] def test_column_entity_with_label(self): query = self.session.query(self.Article.id.label('id')) - assert get_query_entities(query) == [sa.inspect(self.Article)] + assert list(get_query_entities(query)) == [sa.inspect(self.Article)] def test_with_subquery(self): number_of_articles = ( @@ -93,9 +93,9 @@ class TestGetQueryEntities(TestCase): ).label('number_of_articles') query = self.session.query(self.Article, number_of_articles) - assert get_query_entities(query) == [self.Article, number_of_articles] + assert list(get_query_entities(query)) == [self.Article, number_of_articles] def test_aliased_entity(self): alias = sa.orm.aliased(self.Article) query = self.session.query(alias) - assert get_query_entities(query) == [alias] + assert list(get_query_entities(query)) == [alias]