Refactor get_query_entities
This commit is contained in:
@@ -362,14 +362,17 @@ def get_query_entities(query):
|
|||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy_utils import get_query_entities
|
||||||
|
|
||||||
|
|
||||||
query = session.query(Category)
|
query = session.query(Category)
|
||||||
|
|
||||||
query_entities(query) # [<Category>]
|
get_query_entities(query) # [<Category>]
|
||||||
|
|
||||||
|
|
||||||
query = session.query(Category.id)
|
query = session.query(Category.id)
|
||||||
|
|
||||||
query_entities(query) # [<Category>]
|
get_query_entities(query) # [<Category>]
|
||||||
|
|
||||||
|
|
||||||
This function also supports queries with joins.
|
This function also supports queries with joins.
|
||||||
@@ -379,31 +382,32 @@ def get_query_entities(query):
|
|||||||
|
|
||||||
query = session.query(Category).join(Article)
|
query = session.query(Category).join(Article)
|
||||||
|
|
||||||
query_entities(query) # [<Category>, <Article>]
|
get_query_entities(query) # [<Category>, <Article>]
|
||||||
|
|
||||||
.. versionchanged: 0.26.7
|
.. versionchanged: 0.26.7
|
||||||
This function now returns a list instead of generator
|
This function now returns a list instead of generator
|
||||||
|
|
||||||
:param query: SQLAlchemy Query object
|
:param query: SQLAlchemy Query object
|
||||||
"""
|
"""
|
||||||
def get_expr(mixed):
|
return list(
|
||||||
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
|
map(get_selectable, chain(query._entities, query._join_entities))
|
||||||
return mixed
|
)
|
||||||
expr = mixed.expr
|
|
||||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
|
||||||
expr = expr.parent
|
def get_selectable(mixed):
|
||||||
elif isinstance(expr, sa.Column):
|
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
|
||||||
expr = expr.table
|
return mixed
|
||||||
elif isinstance(expr, sa.sql.expression.Label):
|
expr = mixed.expr
|
||||||
if mixed.entity_zero:
|
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||||
return mixed.entity_zero
|
return expr.parent
|
||||||
else:
|
elif isinstance(expr, sa.Column):
|
||||||
return expr
|
return expr.table
|
||||||
return expr
|
elif isinstance(expr, sa.sql.expression.Label):
|
||||||
return [
|
if mixed.entity_zero:
|
||||||
get_expr(entity) for entity in
|
return mixed.entity_zero
|
||||||
chain(query._entities, query._join_entities)
|
else:
|
||||||
]
|
return expr
|
||||||
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def get_query_entity_by_alias(query, alias):
|
def get_query_entity_by_alias(query, alias):
|
||||||
|
@@ -41,30 +41,30 @@ class TestGetQueryEntities(TestCase):
|
|||||||
|
|
||||||
def test_mapper(self):
|
def test_mapper(self):
|
||||||
query = self.session.query(sa.inspect(self.TextItem))
|
query = self.session.query(sa.inspect(self.TextItem))
|
||||||
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
|
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
|
||||||
|
|
||||||
def test_entity(self):
|
def test_entity(self):
|
||||||
query = self.session.query(self.TextItem)
|
query = self.session.query(self.TextItem)
|
||||||
assert get_query_entities(query) == [self.TextItem]
|
assert list(get_query_entities(query)) == [self.TextItem]
|
||||||
|
|
||||||
def test_instrumented_attribute(self):
|
def test_instrumented_attribute(self):
|
||||||
query = self.session.query(self.TextItem.id)
|
query = self.session.query(self.TextItem.id)
|
||||||
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
|
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
|
||||||
|
|
||||||
def test_column(self):
|
def test_column(self):
|
||||||
query = self.session.query(self.TextItem.__table__.c.id)
|
query = self.session.query(self.TextItem.__table__.c.id)
|
||||||
assert get_query_entities(query) == [self.TextItem.__table__]
|
assert list(get_query_entities(query)) == [self.TextItem.__table__]
|
||||||
|
|
||||||
def test_aliased_selectable(self):
|
def test_aliased_selectable(self):
|
||||||
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||||
query = self.session.query(selectable)
|
query = self.session.query(selectable)
|
||||||
assert get_query_entities(query) == [selectable]
|
assert list(get_query_entities(query)) == [selectable]
|
||||||
|
|
||||||
def test_joined_entity(self):
|
def test_joined_entity(self):
|
||||||
query = self.session.query(self.TextItem).join(
|
query = self.session.query(self.TextItem).join(
|
||||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||||
)
|
)
|
||||||
assert get_query_entities(query) == [
|
assert list(get_query_entities(query)) == [
|
||||||
self.TextItem, sa.inspect(self.BlogPost)
|
self.TextItem, sa.inspect(self.BlogPost)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -74,13 +74,13 @@ class TestGetQueryEntities(TestCase):
|
|||||||
query = self.session.query(self.TextItem).join(
|
query = self.session.query(self.TextItem).join(
|
||||||
alias, alias.id == self.TextItem.id
|
alias, alias.id == self.TextItem.id
|
||||||
)
|
)
|
||||||
assert get_query_entities(query) == [
|
assert list(get_query_entities(query)) == [
|
||||||
self.TextItem, sa.inspect(alias)
|
self.TextItem, sa.inspect(alias)
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_column_entity_with_label(self):
|
def test_column_entity_with_label(self):
|
||||||
query = self.session.query(self.Article.id.label('id'))
|
query = self.session.query(self.Article.id.label('id'))
|
||||||
assert get_query_entities(query) == [sa.inspect(self.Article)]
|
assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
|
||||||
|
|
||||||
def test_with_subquery(self):
|
def test_with_subquery(self):
|
||||||
number_of_articles = (
|
number_of_articles = (
|
||||||
@@ -93,9 +93,9 @@ class TestGetQueryEntities(TestCase):
|
|||||||
).label('number_of_articles')
|
).label('number_of_articles')
|
||||||
|
|
||||||
query = self.session.query(self.Article, number_of_articles)
|
query = self.session.query(self.Article, number_of_articles)
|
||||||
assert get_query_entities(query) == [self.Article, number_of_articles]
|
assert list(get_query_entities(query)) == [self.Article, number_of_articles]
|
||||||
|
|
||||||
def test_aliased_entity(self):
|
def test_aliased_entity(self):
|
||||||
alias = sa.orm.aliased(self.Article)
|
alias = sa.orm.aliased(self.Article)
|
||||||
query = self.session.query(alias)
|
query = self.session.query(alias)
|
||||||
assert get_query_entities(query) == [alias]
|
assert list(get_query_entities(query)) == [alias]
|
||||||
|
Reference in New Issue
Block a user