Fix query entities SA 1.0 compatibility

This commit is contained in:
Konsta Vesterinen
2015-05-06 10:52:25 +03:00
parent 64714c9a0c
commit 0c88c2dbfa
2 changed files with 32 additions and 26 deletions

View File

@@ -471,28 +471,34 @@ def get_query_entities(query):
:param query: SQLAlchemy Query object :param query: SQLAlchemy Query object
""" """
exprs = [
d['expr']
if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column)
else d['entity']
for d in query.column_descriptions
]
return [ return [
get_query_entity(entity) for entity in get_query_entity(expr) for expr in exprs
chain(query._entities, query._join_entities) ] + [
get_query_entity(entity) for entity in query._join_entities
] ]
def get_query_entity(mixed): def is_labeled_query(expr):
if hasattr(mixed, 'expr'): return (
expr = mixed.expr isinstance(expr, sa.sql.elements.Label) and
else: isinstance(
expr = mixed list(expr.base_columns)[0],
(sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect)
)
)
def get_query_entity(expr):
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
return expr.parent.class_ return expr.parent.class_
elif isinstance(expr, sa.Column): elif isinstance(expr, sa.Column):
return expr.table return expr.table
elif isinstance(expr, sa.sql.expression.Label):
if mixed.entity_zero:
return mixed.entity_zero
else:
return expr
elif isinstance(expr, sa.orm.Mapper):
return expr.class_
elif isinstance(expr, AliasedInsp): elif isinstance(expr, AliasedInsp):
return expr.entity return expr.entity
return expr return expr
@@ -561,7 +567,7 @@ def get_descriptor(entity, attr):
# Handle synonyms, relationship properties and hybrid # Handle synonyms, relationship properties and hybrid
# properties # properties
try: try:
return getattr(entity, attr) return getattr(mapper.class_, attr)
except AttributeError: except AttributeError:
pass pass

View File

@@ -41,31 +41,31 @@ class TestGetQueryEntities(TestCase):
def test_mapper(self): def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem)) query = self.session.query(sa.inspect(self.TextItem))
assert list(get_query_entities(query)) == [self.TextItem] assert get_query_entities(query) == [self.TextItem]
def test_entity(self): def test_entity(self):
query = self.session.query(self.TextItem) query = self.session.query(self.TextItem)
assert list(get_query_entities(query)) == [self.TextItem] assert get_query_entities(query) == [self.TextItem]
def test_instrumented_attribute(self): def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id) query = self.session.query(self.TextItem.id)
assert list(get_query_entities(query)) == [self.TextItem] assert get_query_entities(query) == [self.TextItem]
def test_column(self): def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id) query = self.session.query(self.TextItem.__table__.c.id)
assert list(get_query_entities(query)) == [self.TextItem.__table__] assert get_query_entities(query) == [self.TextItem.__table__]
def test_aliased_selectable(self): def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable) query = self.session.query(selectable)
assert list(get_query_entities(query)) == [selectable] assert get_query_entities(query) == [selectable]
def test_joined_entity(self): def test_joined_entity(self):
query = self.session.query(self.TextItem).join( query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id self.BlogPost, self.BlogPost.id == self.TextItem.id
) )
assert list(get_query_entities(query)) == [ assert get_query_entities(query) == [
self.TextItem, self.BlogPost self.TextItem, sa.inspect(self.BlogPost)
] ]
def test_joined_aliased_entity(self): def test_joined_aliased_entity(self):
@@ -74,11 +74,11 @@ class TestGetQueryEntities(TestCase):
query = self.session.query(self.TextItem).join( query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id alias, alias.id == self.TextItem.id
) )
assert list(get_query_entities(query)) == [self.TextItem, alias] assert get_query_entities(query) == [self.TextItem, alias]
def test_column_entity_with_label(self): def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id')) query = self.session.query(self.Article.id.label('id'))
assert list(get_query_entities(query)) == [sa.inspect(self.Article)] assert get_query_entities(query) == [self.Article]
def test_with_subquery(self): def test_with_subquery(self):
number_of_articles = ( number_of_articles = (
@@ -91,7 +91,7 @@ class TestGetQueryEntities(TestCase):
).label('number_of_articles') ).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles) query = self.session.query(self.Article, number_of_articles)
assert list(get_query_entities(query)) == [ assert get_query_entities(query) == [
self.Article, self.Article,
number_of_articles number_of_articles
] ]
@@ -99,4 +99,4 @@ class TestGetQueryEntities(TestCase):
def test_aliased_entity(self): def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article) alias = sa.orm.aliased(self.Article)
query = self.session.query(alias) query = self.session.query(alias)
assert list(get_query_entities(query)) == [alias] assert get_query_entities(query) == [alias]