Refactor get_query_entities

This commit is contained in:
Konsta Vesterinen
2014-08-07 11:31:03 +03:00
parent a975591c2e
commit e0851f961f
2 changed files with 35 additions and 31 deletions

View File

@@ -362,14 +362,17 @@ def get_query_entities(query):
Examples::
from sqlalchemy_utils import get_query_entities
query = session.query(Category)
query_entities(query) # [<Category>]
get_query_entities(query) # [<Category>]
query = session.query(Category.id)
query_entities(query) # [<Category>]
get_query_entities(query) # [<Category>]
This function also supports queries with joins.
@@ -379,31 +382,32 @@ def get_query_entities(query):
query = session.query(Category).join(Article)
query_entities(query) # [<Category>, <Article>]
get_query_entities(query) # [<Category>, <Article>]
.. versionchanged: 0.26.7
This function now returns a list instead of generator
:param query: SQLAlchemy Query object
"""
def get_expr(mixed):
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
return mixed
expr = mixed.expr
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
expr = expr.parent
elif isinstance(expr, sa.Column):
expr = expr.table
elif isinstance(expr, sa.sql.expression.Label):
if mixed.entity_zero:
return mixed.entity_zero
else:
return expr
return expr
return [
get_expr(entity) for entity in
chain(query._entities, query._join_entities)
]
return list(
map(get_selectable, chain(query._entities, query._join_entities))
)
def get_selectable(mixed):
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
return mixed
expr = mixed.expr
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
return expr.parent
elif isinstance(expr, sa.Column):
return expr.table
elif isinstance(expr, sa.sql.expression.Label):
if mixed.entity_zero:
return mixed.entity_zero
else:
return expr
return expr
def get_query_entity_by_alias(query, alias):

View File

@@ -41,30 +41,30 @@ class TestGetQueryEntities(TestCase):
def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem))
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
def test_entity(self):
query = self.session.query(self.TextItem)
assert get_query_entities(query) == [self.TextItem]
assert list(get_query_entities(query)) == [self.TextItem]
def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id)
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id)
assert get_query_entities(query) == [self.TextItem.__table__]
assert list(get_query_entities(query)) == [self.TextItem.__table__]
def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable)
assert get_query_entities(query) == [selectable]
assert list(get_query_entities(query)) == [selectable]
def test_joined_entity(self):
query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id
)
assert get_query_entities(query) == [
assert list(get_query_entities(query)) == [
self.TextItem, sa.inspect(self.BlogPost)
]
@@ -74,13 +74,13 @@ class TestGetQueryEntities(TestCase):
query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id
)
assert get_query_entities(query) == [
assert list(get_query_entities(query)) == [
self.TextItem, sa.inspect(alias)
]
def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id'))
assert get_query_entities(query) == [sa.inspect(self.Article)]
assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
def test_with_subquery(self):
number_of_articles = (
@@ -93,9 +93,9 @@ class TestGetQueryEntities(TestCase):
).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles)
assert get_query_entities(query) == [self.Article, number_of_articles]
assert list(get_query_entities(query)) == [self.Article, number_of_articles]
def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article)
query = self.session.query(alias)
assert get_query_entities(query) == [alias]
assert list(get_query_entities(query)) == [alias]