Fix query entities SA 1.0 compatibility

This commit is contained in:
Konsta Vesterinen
2015-05-06 10:52:25 +03:00
parent 64714c9a0c
commit 0c88c2dbfa
2 changed files with 32 additions and 26 deletions

View File

@@ -471,28 +471,34 @@ def get_query_entities(query):
:param query: SQLAlchemy Query object
"""
exprs = [
d['expr']
if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column)
else d['entity']
for d in query.column_descriptions
]
return [
get_query_entity(entity) for entity in
chain(query._entities, query._join_entities)
get_query_entity(expr) for expr in exprs
] + [
get_query_entity(entity) for entity in query._join_entities
]
def get_query_entity(mixed):
if hasattr(mixed, 'expr'):
expr = mixed.expr
else:
expr = mixed
def is_labeled_query(expr):
return (
isinstance(expr, sa.sql.elements.Label) and
isinstance(
list(expr.base_columns)[0],
(sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect)
)
)
def get_query_entity(expr):
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
return expr.parent.class_
elif isinstance(expr, sa.Column):
return expr.table
elif isinstance(expr, sa.sql.expression.Label):
if mixed.entity_zero:
return mixed.entity_zero
else:
return expr
elif isinstance(expr, sa.orm.Mapper):
return expr.class_
elif isinstance(expr, AliasedInsp):
return expr.entity
return expr
@@ -561,7 +567,7 @@ def get_descriptor(entity, attr):
# Handle synonyms, relationship properties and hybrid
# properties
try:
return getattr(entity, attr)
return getattr(mapper.class_, attr)
except AttributeError:
pass

View File

@@ -41,31 +41,31 @@ class TestGetQueryEntities(TestCase):
def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem))
assert list(get_query_entities(query)) == [self.TextItem]
assert get_query_entities(query) == [self.TextItem]
def test_entity(self):
query = self.session.query(self.TextItem)
assert list(get_query_entities(query)) == [self.TextItem]
assert get_query_entities(query) == [self.TextItem]
def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id)
assert list(get_query_entities(query)) == [self.TextItem]
assert get_query_entities(query) == [self.TextItem]
def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id)
assert list(get_query_entities(query)) == [self.TextItem.__table__]
assert get_query_entities(query) == [self.TextItem.__table__]
def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable)
assert list(get_query_entities(query)) == [selectable]
assert get_query_entities(query) == [selectable]
def test_joined_entity(self):
query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id
)
assert list(get_query_entities(query)) == [
self.TextItem, self.BlogPost
assert get_query_entities(query) == [
self.TextItem, sa.inspect(self.BlogPost)
]
def test_joined_aliased_entity(self):
@@ -74,11 +74,11 @@ class TestGetQueryEntities(TestCase):
query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id
)
assert list(get_query_entities(query)) == [self.TextItem, alias]
assert get_query_entities(query) == [self.TextItem, alias]
def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id'))
assert list(get_query_entities(query)) == [sa.inspect(self.Article)]
assert get_query_entities(query) == [self.Article]
def test_with_subquery(self):
number_of_articles = (
@@ -91,7 +91,7 @@ class TestGetQueryEntities(TestCase):
).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles)
assert list(get_query_entities(query)) == [
assert get_query_entities(query) == [
self.Article,
number_of_articles
]
@@ -99,4 +99,4 @@ class TestGetQueryEntities(TestCase):
def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article)
query = self.session.query(alias)
assert list(get_query_entities(query)) == [alias]
assert get_query_entities(query) == [alias]