Fix query entities SA 1.0 compatibility
This commit is contained in:
@@ -471,28 +471,34 @@ def get_query_entities(query):
|
||||
|
||||
:param query: SQLAlchemy Query object
|
||||
"""
|
||||
exprs = [
|
||||
d['expr']
|
||||
if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column)
|
||||
else d['entity']
|
||||
for d in query.column_descriptions
|
||||
]
|
||||
return [
|
||||
get_query_entity(entity) for entity in
|
||||
chain(query._entities, query._join_entities)
|
||||
get_query_entity(expr) for expr in exprs
|
||||
] + [
|
||||
get_query_entity(entity) for entity in query._join_entities
|
||||
]
|
||||
|
||||
|
||||
def get_query_entity(mixed):
|
||||
if hasattr(mixed, 'expr'):
|
||||
expr = mixed.expr
|
||||
else:
|
||||
expr = mixed
|
||||
def is_labeled_query(expr):
|
||||
return (
|
||||
isinstance(expr, sa.sql.elements.Label) and
|
||||
isinstance(
|
||||
list(expr.base_columns)[0],
|
||||
(sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_query_entity(expr):
|
||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||
return expr.parent.class_
|
||||
elif isinstance(expr, sa.Column):
|
||||
return expr.table
|
||||
elif isinstance(expr, sa.sql.expression.Label):
|
||||
if mixed.entity_zero:
|
||||
return mixed.entity_zero
|
||||
else:
|
||||
return expr
|
||||
elif isinstance(expr, sa.orm.Mapper):
|
||||
return expr.class_
|
||||
elif isinstance(expr, AliasedInsp):
|
||||
return expr.entity
|
||||
return expr
|
||||
@@ -561,7 +567,7 @@ def get_descriptor(entity, attr):
|
||||
# Handle synonyms, relationship properties and hybrid
|
||||
# properties
|
||||
try:
|
||||
return getattr(entity, attr)
|
||||
return getattr(mapper.class_, attr)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
@@ -41,31 +41,31 @@ class TestGetQueryEntities(TestCase):
|
||||
|
||||
def test_mapper(self):
|
||||
query = self.session.query(sa.inspect(self.TextItem))
|
||||
assert list(get_query_entities(query)) == [self.TextItem]
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
|
||||
def test_entity(self):
|
||||
query = self.session.query(self.TextItem)
|
||||
assert list(get_query_entities(query)) == [self.TextItem]
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
query = self.session.query(self.TextItem.id)
|
||||
assert list(get_query_entities(query)) == [self.TextItem]
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
|
||||
def test_column(self):
|
||||
query = self.session.query(self.TextItem.__table__.c.id)
|
||||
assert list(get_query_entities(query)) == [self.TextItem.__table__]
|
||||
assert get_query_entities(query) == [self.TextItem.__table__]
|
||||
|
||||
def test_aliased_selectable(self):
|
||||
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||
query = self.session.query(selectable)
|
||||
assert list(get_query_entities(query)) == [selectable]
|
||||
assert get_query_entities(query) == [selectable]
|
||||
|
||||
def test_joined_entity(self):
|
||||
query = self.session.query(self.TextItem).join(
|
||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||
)
|
||||
assert list(get_query_entities(query)) == [
|
||||
self.TextItem, self.BlogPost
|
||||
assert get_query_entities(query) == [
|
||||
self.TextItem, sa.inspect(self.BlogPost)
|
||||
]
|
||||
|
||||
def test_joined_aliased_entity(self):
|
||||
@@ -74,11 +74,11 @@ class TestGetQueryEntities(TestCase):
|
||||
query = self.session.query(self.TextItem).join(
|
||||
alias, alias.id == self.TextItem.id
|
||||
)
|
||||
assert list(get_query_entities(query)) == [self.TextItem, alias]
|
||||
assert get_query_entities(query) == [self.TextItem, alias]
|
||||
|
||||
def test_column_entity_with_label(self):
|
||||
query = self.session.query(self.Article.id.label('id'))
|
||||
assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
|
||||
assert get_query_entities(query) == [self.Article]
|
||||
|
||||
def test_with_subquery(self):
|
||||
number_of_articles = (
|
||||
@@ -91,7 +91,7 @@ class TestGetQueryEntities(TestCase):
|
||||
).label('number_of_articles')
|
||||
|
||||
query = self.session.query(self.Article, number_of_articles)
|
||||
assert list(get_query_entities(query)) == [
|
||||
assert get_query_entities(query) == [
|
||||
self.Article,
|
||||
number_of_articles
|
||||
]
|
||||
@@ -99,4 +99,4 @@ class TestGetQueryEntities(TestCase):
|
||||
def test_aliased_entity(self):
|
||||
alias = sa.orm.aliased(self.Article)
|
||||
query = self.session.query(alias)
|
||||
assert list(get_query_entities(query)) == [alias]
|
||||
assert get_query_entities(query) == [alias]
|
||||
|
Reference in New Issue
Block a user