Refactor get_query_entities
This commit is contained in:
@@ -362,14 +362,17 @@ def get_query_entities(query):
|
||||
Examples::
|
||||
|
||||
|
||||
from sqlalchemy_utils import get_query_entities
|
||||
|
||||
|
||||
query = session.query(Category)
|
||||
|
||||
query_entities(query) # [<Category>]
|
||||
get_query_entities(query) # [<Category>]
|
||||
|
||||
|
||||
query = session.query(Category.id)
|
||||
|
||||
query_entities(query) # [<Category>]
|
||||
get_query_entities(query) # [<Category>]
|
||||
|
||||
|
||||
This function also supports queries with joins.
|
||||
@@ -379,31 +382,32 @@ def get_query_entities(query):
|
||||
|
||||
query = session.query(Category).join(Article)
|
||||
|
||||
query_entities(query) # [<Category>, <Article>]
|
||||
get_query_entities(query) # [<Category>, <Article>]
|
||||
|
||||
.. versionchanged: 0.26.7
|
||||
This function now returns a list instead of generator
|
||||
|
||||
:param query: SQLAlchemy Query object
|
||||
"""
|
||||
def get_expr(mixed):
|
||||
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
|
||||
return mixed
|
||||
expr = mixed.expr
|
||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||
expr = expr.parent
|
||||
elif isinstance(expr, sa.Column):
|
||||
expr = expr.table
|
||||
elif isinstance(expr, sa.sql.expression.Label):
|
||||
if mixed.entity_zero:
|
||||
return mixed.entity_zero
|
||||
else:
|
||||
return expr
|
||||
return expr
|
||||
return [
|
||||
get_expr(entity) for entity in
|
||||
chain(query._entities, query._join_entities)
|
||||
]
|
||||
return list(
|
||||
map(get_selectable, chain(query._entities, query._join_entities))
|
||||
)
|
||||
|
||||
|
||||
def get_selectable(mixed):
|
||||
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
|
||||
return mixed
|
||||
expr = mixed.expr
|
||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||
return expr.parent
|
||||
elif isinstance(expr, sa.Column):
|
||||
return expr.table
|
||||
elif isinstance(expr, sa.sql.expression.Label):
|
||||
if mixed.entity_zero:
|
||||
return mixed.entity_zero
|
||||
else:
|
||||
return expr
|
||||
return expr
|
||||
|
||||
|
||||
def get_query_entity_by_alias(query, alias):
|
||||
|
@@ -41,30 +41,30 @@ class TestGetQueryEntities(TestCase):
|
||||
|
||||
def test_mapper(self):
|
||||
query = self.session.query(sa.inspect(self.TextItem))
|
||||
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
|
||||
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
|
||||
|
||||
def test_entity(self):
|
||||
query = self.session.query(self.TextItem)
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
assert list(get_query_entities(query)) == [self.TextItem]
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
query = self.session.query(self.TextItem.id)
|
||||
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
|
||||
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
|
||||
|
||||
def test_column(self):
|
||||
query = self.session.query(self.TextItem.__table__.c.id)
|
||||
assert get_query_entities(query) == [self.TextItem.__table__]
|
||||
assert list(get_query_entities(query)) == [self.TextItem.__table__]
|
||||
|
||||
def test_aliased_selectable(self):
|
||||
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||
query = self.session.query(selectable)
|
||||
assert get_query_entities(query) == [selectable]
|
||||
assert list(get_query_entities(query)) == [selectable]
|
||||
|
||||
def test_joined_entity(self):
|
||||
query = self.session.query(self.TextItem).join(
|
||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||
)
|
||||
assert get_query_entities(query) == [
|
||||
assert list(get_query_entities(query)) == [
|
||||
self.TextItem, sa.inspect(self.BlogPost)
|
||||
]
|
||||
|
||||
@@ -74,13 +74,13 @@ class TestGetQueryEntities(TestCase):
|
||||
query = self.session.query(self.TextItem).join(
|
||||
alias, alias.id == self.TextItem.id
|
||||
)
|
||||
assert get_query_entities(query) == [
|
||||
assert list(get_query_entities(query)) == [
|
||||
self.TextItem, sa.inspect(alias)
|
||||
]
|
||||
|
||||
def test_column_entity_with_label(self):
|
||||
query = self.session.query(self.Article.id.label('id'))
|
||||
assert get_query_entities(query) == [sa.inspect(self.Article)]
|
||||
assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
|
||||
|
||||
def test_with_subquery(self):
|
||||
number_of_articles = (
|
||||
@@ -93,9 +93,9 @@ class TestGetQueryEntities(TestCase):
|
||||
).label('number_of_articles')
|
||||
|
||||
query = self.session.query(self.Article, number_of_articles)
|
||||
assert get_query_entities(query) == [self.Article, number_of_articles]
|
||||
assert list(get_query_entities(query)) == [self.Article, number_of_articles]
|
||||
|
||||
def test_aliased_entity(self):
|
||||
alias = sa.orm.aliased(self.Article)
|
||||
query = self.session.query(alias)
|
||||
assert get_query_entities(query) == [alias]
|
||||
assert list(get_query_entities(query)) == [alias]
|
||||
|
Reference in New Issue
Block a user