Fix query entities SA 1.0 compatibility
This commit is contained in:
@@ -471,28 +471,34 @@ def get_query_entities(query):
|
|||||||
|
|
||||||
:param query: SQLAlchemy Query object
|
:param query: SQLAlchemy Query object
|
||||||
"""
|
"""
|
||||||
|
exprs = [
|
||||||
|
d['expr']
|
||||||
|
if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column)
|
||||||
|
else d['entity']
|
||||||
|
for d in query.column_descriptions
|
||||||
|
]
|
||||||
return [
|
return [
|
||||||
get_query_entity(entity) for entity in
|
get_query_entity(expr) for expr in exprs
|
||||||
chain(query._entities, query._join_entities)
|
] + [
|
||||||
|
get_query_entity(entity) for entity in query._join_entities
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_query_entity(mixed):
|
def is_labeled_query(expr):
|
||||||
if hasattr(mixed, 'expr'):
|
return (
|
||||||
expr = mixed.expr
|
isinstance(expr, sa.sql.elements.Label) and
|
||||||
else:
|
isinstance(
|
||||||
expr = mixed
|
list(expr.base_columns)[0],
|
||||||
|
(sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_entity(expr):
|
||||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||||
return expr.parent.class_
|
return expr.parent.class_
|
||||||
elif isinstance(expr, sa.Column):
|
elif isinstance(expr, sa.Column):
|
||||||
return expr.table
|
return expr.table
|
||||||
elif isinstance(expr, sa.sql.expression.Label):
|
|
||||||
if mixed.entity_zero:
|
|
||||||
return mixed.entity_zero
|
|
||||||
else:
|
|
||||||
return expr
|
|
||||||
elif isinstance(expr, sa.orm.Mapper):
|
|
||||||
return expr.class_
|
|
||||||
elif isinstance(expr, AliasedInsp):
|
elif isinstance(expr, AliasedInsp):
|
||||||
return expr.entity
|
return expr.entity
|
||||||
return expr
|
return expr
|
||||||
@@ -561,7 +567,7 @@ def get_descriptor(entity, attr):
|
|||||||
# Handle synonyms, relationship properties and hybrid
|
# Handle synonyms, relationship properties and hybrid
|
||||||
# properties
|
# properties
|
||||||
try:
|
try:
|
||||||
return getattr(entity, attr)
|
return getattr(mapper.class_, attr)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@@ -41,31 +41,31 @@ class TestGetQueryEntities(TestCase):
|
|||||||
|
|
||||||
def test_mapper(self):
|
def test_mapper(self):
|
||||||
query = self.session.query(sa.inspect(self.TextItem))
|
query = self.session.query(sa.inspect(self.TextItem))
|
||||||
assert list(get_query_entities(query)) == [self.TextItem]
|
assert get_query_entities(query) == [self.TextItem]
|
||||||
|
|
||||||
def test_entity(self):
|
def test_entity(self):
|
||||||
query = self.session.query(self.TextItem)
|
query = self.session.query(self.TextItem)
|
||||||
assert list(get_query_entities(query)) == [self.TextItem]
|
assert get_query_entities(query) == [self.TextItem]
|
||||||
|
|
||||||
def test_instrumented_attribute(self):
|
def test_instrumented_attribute(self):
|
||||||
query = self.session.query(self.TextItem.id)
|
query = self.session.query(self.TextItem.id)
|
||||||
assert list(get_query_entities(query)) == [self.TextItem]
|
assert get_query_entities(query) == [self.TextItem]
|
||||||
|
|
||||||
def test_column(self):
|
def test_column(self):
|
||||||
query = self.session.query(self.TextItem.__table__.c.id)
|
query = self.session.query(self.TextItem.__table__.c.id)
|
||||||
assert list(get_query_entities(query)) == [self.TextItem.__table__]
|
assert get_query_entities(query) == [self.TextItem.__table__]
|
||||||
|
|
||||||
def test_aliased_selectable(self):
|
def test_aliased_selectable(self):
|
||||||
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||||
query = self.session.query(selectable)
|
query = self.session.query(selectable)
|
||||||
assert list(get_query_entities(query)) == [selectable]
|
assert get_query_entities(query) == [selectable]
|
||||||
|
|
||||||
def test_joined_entity(self):
|
def test_joined_entity(self):
|
||||||
query = self.session.query(self.TextItem).join(
|
query = self.session.query(self.TextItem).join(
|
||||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||||
)
|
)
|
||||||
assert list(get_query_entities(query)) == [
|
assert get_query_entities(query) == [
|
||||||
self.TextItem, self.BlogPost
|
self.TextItem, sa.inspect(self.BlogPost)
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_joined_aliased_entity(self):
|
def test_joined_aliased_entity(self):
|
||||||
@@ -74,11 +74,11 @@ class TestGetQueryEntities(TestCase):
|
|||||||
query = self.session.query(self.TextItem).join(
|
query = self.session.query(self.TextItem).join(
|
||||||
alias, alias.id == self.TextItem.id
|
alias, alias.id == self.TextItem.id
|
||||||
)
|
)
|
||||||
assert list(get_query_entities(query)) == [self.TextItem, alias]
|
assert get_query_entities(query) == [self.TextItem, alias]
|
||||||
|
|
||||||
def test_column_entity_with_label(self):
|
def test_column_entity_with_label(self):
|
||||||
query = self.session.query(self.Article.id.label('id'))
|
query = self.session.query(self.Article.id.label('id'))
|
||||||
assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
|
assert get_query_entities(query) == [self.Article]
|
||||||
|
|
||||||
def test_with_subquery(self):
|
def test_with_subquery(self):
|
||||||
number_of_articles = (
|
number_of_articles = (
|
||||||
@@ -91,7 +91,7 @@ class TestGetQueryEntities(TestCase):
|
|||||||
).label('number_of_articles')
|
).label('number_of_articles')
|
||||||
|
|
||||||
query = self.session.query(self.Article, number_of_articles)
|
query = self.session.query(self.Article, number_of_articles)
|
||||||
assert list(get_query_entities(query)) == [
|
assert get_query_entities(query) == [
|
||||||
self.Article,
|
self.Article,
|
||||||
number_of_articles
|
number_of_articles
|
||||||
]
|
]
|
||||||
@@ -99,4 +99,4 @@ class TestGetQueryEntities(TestCase):
|
|||||||
def test_aliased_entity(self):
|
def test_aliased_entity(self):
|
||||||
alias = sa.orm.aliased(self.Article)
|
alias = sa.orm.aliased(self.Article)
|
||||||
query = self.session.query(alias)
|
query = self.session.query(alias)
|
||||||
assert list(get_query_entities(query)) == [alias]
|
assert get_query_entities(query) == [alias]
|
||||||
|
Reference in New Issue
Block a user