Refactor get_query_entities

This commit is contained in:
Konsta Vesterinen
2014-08-07 11:31:03 +03:00
parent a975591c2e
commit e0851f961f
2 changed files with 35 additions and 31 deletions

View File

@@ -362,14 +362,17 @@ def get_query_entities(query):
Examples:: Examples::
from sqlalchemy_utils import get_query_entities
query = session.query(Category) query = session.query(Category)
query_entities(query) # [<Category>] get_query_entities(query) # [<Category>]
query = session.query(Category.id) query = session.query(Category.id)
query_entities(query) # [<Category>] get_query_entities(query) # [<Category>]
This function also supports queries with joins. This function also supports queries with joins.
@@ -379,31 +382,32 @@ def get_query_entities(query):
query = session.query(Category).join(Article) query = session.query(Category).join(Article)
query_entities(query) # [<Category>, <Article>] get_query_entities(query) # [<Category>, <Article>]
.. versionchanged: 0.26.7 .. versionchanged: 0.26.7
This function now returns a list instead of generator This function now returns a list instead of generator
:param query: SQLAlchemy Query object :param query: SQLAlchemy Query object
""" """
def get_expr(mixed): return list(
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)): map(get_selectable, chain(query._entities, query._join_entities))
return mixed )
expr = mixed.expr
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
expr = expr.parent def get_selectable(mixed):
elif isinstance(expr, sa.Column): if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
expr = expr.table return mixed
elif isinstance(expr, sa.sql.expression.Label): expr = mixed.expr
if mixed.entity_zero: if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
return mixed.entity_zero return expr.parent
else: elif isinstance(expr, sa.Column):
return expr return expr.table
return expr elif isinstance(expr, sa.sql.expression.Label):
return [ if mixed.entity_zero:
get_expr(entity) for entity in return mixed.entity_zero
chain(query._entities, query._join_entities) else:
] return expr
return expr
def get_query_entity_by_alias(query, alias): def get_query_entity_by_alias(query, alias):

View File

@@ -41,30 +41,30 @@ class TestGetQueryEntities(TestCase):
def test_mapper(self): def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem)) query = self.session.query(sa.inspect(self.TextItem))
assert get_query_entities(query) == [sa.inspect(self.TextItem)] assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
def test_entity(self): def test_entity(self):
query = self.session.query(self.TextItem) query = self.session.query(self.TextItem)
assert get_query_entities(query) == [self.TextItem] assert list(get_query_entities(query)) == [self.TextItem]
def test_instrumented_attribute(self): def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id) query = self.session.query(self.TextItem.id)
assert get_query_entities(query) == [sa.inspect(self.TextItem)] assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
def test_column(self): def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id) query = self.session.query(self.TextItem.__table__.c.id)
assert get_query_entities(query) == [self.TextItem.__table__] assert list(get_query_entities(query)) == [self.TextItem.__table__]
def test_aliased_selectable(self): def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable) query = self.session.query(selectable)
assert get_query_entities(query) == [selectable] assert list(get_query_entities(query)) == [selectable]
def test_joined_entity(self): def test_joined_entity(self):
query = self.session.query(self.TextItem).join( query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id self.BlogPost, self.BlogPost.id == self.TextItem.id
) )
assert get_query_entities(query) == [ assert list(get_query_entities(query)) == [
self.TextItem, sa.inspect(self.BlogPost) self.TextItem, sa.inspect(self.BlogPost)
] ]
@@ -74,13 +74,13 @@ class TestGetQueryEntities(TestCase):
query = self.session.query(self.TextItem).join( query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id alias, alias.id == self.TextItem.id
) )
assert get_query_entities(query) == [ assert list(get_query_entities(query)) == [
self.TextItem, sa.inspect(alias) self.TextItem, sa.inspect(alias)
] ]
def test_column_entity_with_label(self): def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id')) query = self.session.query(self.Article.id.label('id'))
assert get_query_entities(query) == [sa.inspect(self.Article)] assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
def test_with_subquery(self): def test_with_subquery(self):
number_of_articles = ( number_of_articles = (
@@ -93,9 +93,9 @@ class TestGetQueryEntities(TestCase):
).label('number_of_articles') ).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles) query = self.session.query(self.Article, number_of_articles)
assert get_query_entities(query) == [self.Article, number_of_articles] assert list(get_query_entities(query)) == [self.Article, number_of_articles]
def test_aliased_entity(self): def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article) alias = sa.orm.aliased(self.Article)
query = self.session.query(alias) query = self.session.query(alias)
assert get_query_entities(query) == [alias] assert list(get_query_entities(query)) == [alias]