Fix sort_query support for queries using mappers

This commit is contained in:
Konsta Vesterinen
2014-08-28 23:03:27 +03:00
parent 0340ea1533
commit 01b58b5599
4 changed files with 107 additions and 81 deletions

View File

@@ -12,6 +12,7 @@ from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import mapperlib from sqlalchemy.orm import mapperlib
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import UnmappedInstanceError from sqlalchemy.orm.exc import UnmappedInstanceError
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.orm.util import AliasedInsp
@@ -197,13 +198,7 @@ def get_tables(mixed):
SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping
any of these objects. any of these objects.
""" """
if isinstance(mixed, sa.orm.util.AliasedClass): return get_mapper(mixed).tables
mapper = sa.inspect(mixed).mapper
else:
if not isclass(mixed):
mixed = mixed.__class__
mapper = sa.inspect(mixed)
return mapper.tables
def get_columns(mixed): def get_columns(mixed):
@@ -414,17 +409,19 @@ def get_query_entities(query):
:param query: SQLAlchemy Query object :param query: SQLAlchemy Query object
""" """
return list( return [
map(get_selectable, chain(query._entities, query._join_entities)) get_query_entity(entity) for entity in
) chain(query._entities, query._join_entities)
]
def get_selectable(mixed): def get_query_entity(mixed):
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)): if hasattr(mixed, 'expr'):
return mixed expr = mixed.expr
expr = mixed.expr else:
expr = mixed
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
return expr.parent return expr.parent.class_
elif isinstance(expr, sa.Column): elif isinstance(expr, sa.Column):
return expr.table return expr.table
elif isinstance(expr, sa.sql.expression.Label): elif isinstance(expr, sa.sql.expression.Label):
@@ -432,17 +429,22 @@ def get_selectable(mixed):
return mixed.entity_zero return mixed.entity_zero
else: else:
return expr return expr
elif isinstance(expr, sa.orm.Mapper):
return expr.class_
elif isinstance(expr, AliasedInsp):
return expr.entity
return expr return expr
def get_query_entity_by_alias(query, alias): def get_query_entity_by_alias(query, alias):
entities = get_query_entities(query) entities = get_query_entities(query)
if not alias: if not alias:
return entities[0] return entities[0]
for entity in entities: for entity in entities:
if isinstance(entity, AliasedInsp): if isinstance(entity, sa.orm.util.AliasedClass):
name = entity.name name = sa.inspect(entity).name
else: else:
name = get_mapper(entity).tables[0].name name = get_mapper(entity).tables[0].name
@@ -457,17 +459,61 @@ def get_polymorphic_mappers(mixed):
return mixed.polymorphic_map.values() return mixed.polymorphic_map.values()
def get_attrs(expr): def get_query_descriptor(query, entity, attr):
insp = sa.inspect(expr) if attr in query_labels(query):
mapper = get_mapper(expr) return attr
polymorphic_mappers = get_polymorphic_mappers(insp) else:
entity = get_query_entity_by_alias(query, entity)
if entity:
descriptor = get_descriptor(entity, attr)
if (
hasattr(descriptor, 'property') and
isinstance(descriptor.property, sa.orm.RelationshipProperty)
):
return
return descriptor
def get_descriptor(entity, attr):
mapper = sa.inspect(entity)
for key, descriptor in get_all_descriptors(mapper).items():
if attr == key:
prop = (
descriptor.property
if hasattr(descriptor, 'property')
else None
)
if isinstance(prop, ColumnProperty):
if isinstance(entity, sa.orm.util.AliasedClass):
for c in mapper.selectable.c:
if c.key == attr:
return c
else:
# If the property belongs to a class that uses
# polymorphic inheritance we have to take into account
# situations where the attribute exists in child class
# but not in parent class.
return getattr(prop.parent.class_, attr)
else:
# Handle synonyms, relationship proeprties and hybrid
# properties
try:
return getattr(entity, attr)
except AttributeError:
pass
def get_all_descriptors(expr):
insp = sa.inspect(expr)
polymorphic_mappers = get_polymorphic_mappers(insp)
if polymorphic_mappers: if polymorphic_mappers:
attrs = {} attrs = {}
for submapper in polymorphic_mappers: for submapper in polymorphic_mappers:
attrs.update(submapper.attrs) attrs.update(submapper.all_orm_descriptors)
return attrs return attrs
return mapper.attrs return get_mapper(expr).all_orm_descriptors
def get_hybrid_properties(model): def get_hybrid_properties(model):
@@ -521,13 +567,6 @@ def get_hybrid_properties(model):
) )
def get_expr_attr(expr, prop):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, prop.key)
else:
return getattr(prop.parent.class_, prop.key)
def get_declarative_base(model): def get_declarative_base(model):
""" """
Returns the declarative base for given model class. Returns the declarative base for given model class.

View File

@@ -1,15 +1,5 @@
import sqlalchemy as sa from sqlalchemy.sql.expression import desc, asc
from sqlalchemy.orm.properties import ColumnProperty, SynonymProperty from .orm import get_query_descriptor
from sqlalchemy.sql.expression import desc, asc, Label
from sqlalchemy.orm.util import AliasedInsp
from .orm import (
get_attrs,
get_expr_attr,
get_hybrid_properties,
get_query_entity_by_alias,
get_query_entities,
query_labels,
)
class QuerySorterException(Exception): class QuerySorterException(Exception):
@@ -18,18 +8,11 @@ class QuerySorterException(Exception):
class QuerySorter(object): class QuerySorter(object):
def __init__(self, silent=True, separator='-'): def __init__(self, silent=True, separator='-'):
self.labels = []
self.separator = separator self.separator = separator
self.silent = silent self.silent = silent
def assign_order_by(self, entity, attr, func): def assign_order_by(self, entity, attr, func):
expr = None expr = get_query_descriptor(self.query, entity, attr)
if attr in self.labels:
expr = attr
else:
entity = get_query_entity_by_alias(self.query, entity)
if entity:
expr = self.order_by_attr(entity, attr)
if expr is not None: if expr is not None:
return self.query.order_by(func(expr)) return self.query.order_by(func(expr))
@@ -39,30 +22,6 @@ class QuerySorter(object):
) )
return self.query return self.query
def order_by_attr(self, entity, attr):
properties = get_attrs(entity)
if attr in properties:
property_ = properties[attr]
if isinstance(property_, ColumnProperty):
if isinstance(property_.columns[0], Label):
return getattr(entity, property_.key)
else:
return get_expr_attr(entity, property_)
elif isinstance(property_, SynonymProperty):
return get_expr_attr(entity, property_)
return
mapper = sa.inspect(entity)
entity = mapper.entity
if isinstance(mapper, AliasedInsp):
mapper = mapper.mapper
for key in get_hybrid_properties(mapper).keys():
if attr == key:
return getattr(entity, attr)
def parse_sort_arg(self, arg): def parse_sort_arg(self, arg):
if arg[0] == self.separator: if arg[0] == self.separator:
func = desc func = desc
@@ -79,7 +38,6 @@ class QuerySorter(object):
def __call__(self, query, *args): def __call__(self, query, *args):
self.query = query self.query = query
self.labels = query_labels(query)
for sort in args: for sort in args:
if not sort: if not sort:

View File

@@ -41,7 +41,7 @@ class TestGetQueryEntities(TestCase):
def test_mapper(self): def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem)) query = self.session.query(sa.inspect(self.TextItem))
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)] assert list(get_query_entities(query)) == [self.TextItem]
def test_entity(self): def test_entity(self):
query = self.session.query(self.TextItem) query = self.session.query(self.TextItem)
@@ -49,7 +49,7 @@ class TestGetQueryEntities(TestCase):
def test_instrumented_attribute(self): def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id) query = self.session.query(self.TextItem.id)
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)] assert list(get_query_entities(query)) == [self.TextItem]
def test_column(self): def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id) query = self.session.query(self.TextItem.__table__.c.id)
@@ -65,7 +65,7 @@ class TestGetQueryEntities(TestCase):
self.BlogPost, self.BlogPost.id == self.TextItem.id self.BlogPost, self.BlogPost.id == self.TextItem.id
) )
assert list(get_query_entities(query)) == [ assert list(get_query_entities(query)) == [
self.TextItem, sa.inspect(self.BlogPost) self.TextItem, self.BlogPost
] ]
def test_joined_aliased_entity(self): def test_joined_aliased_entity(self):
@@ -74,9 +74,7 @@ class TestGetQueryEntities(TestCase):
query = self.session.query(self.TextItem).join( query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id alias, alias.id == self.TextItem.id
) )
assert list(get_query_entities(query)) == [ assert list(get_query_entities(query)) == [self.TextItem, alias]
self.TextItem, sa.inspect(alias)
]
def test_column_entity_with_label(self): def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id')) query = self.session.query(self.Article.id.label('id'))

View File

@@ -156,6 +156,37 @@ class TestSortQuery(TestCase):
query = sort_query(query, 'some_hybrid') query = sort_query(query, 'some_hybrid')
assert_contains('ORDER BY article.name ASC', query) assert_contains('ORDER BY article.name ASC', query)
def test_with_mapper_and_column_property(self):
class Apple(self.Base):
__tablename__ = 'apple'
id = sa.Column(sa.Integer, primary_key=True)
article_id = sa.Column(sa.Integer, sa.ForeignKey(self.Article.id))
self.Article.apples = sa.orm.relationship(Apple)
self.Article.apple_count = sa.orm.column_property(
sa.select([sa.func.count(Apple.id)])
.where(Apple.article_id == self.Article.id)
.correlate(self.Article.__table__)
.label('apple_count'),
deferred=True
)
query = (
self.session.query(sa.inspect(self.Article))
.outerjoin(self.Article.apples)
.options(
sa.orm.undefer(self.Article.apple_count)
)
.options(sa.orm.contains_eager(self.Article.apples))
)
query = sort_query(query, 'apple_count')
assert 'ORDER BY apple_count' in str(query)
def test_table(self):
query = self.session.query(self.Article.__table__)
query = sort_query(query, 'name')
assert_contains('ORDER BY name', query)
class TestSortQueryRelationshipCounts(TestCase): class TestSortQueryRelationshipCounts(TestCase):
""" """