Fix sort_query support for queries using mappers
This commit is contained in:
@@ -12,6 +12,7 @@ from sqlalchemy.ext.hybrid import hybrid_property
|
|||||||
from sqlalchemy.orm import mapperlib
|
from sqlalchemy.orm import mapperlib
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
from sqlalchemy.orm.exc import UnmappedInstanceError
|
from sqlalchemy.orm.exc import UnmappedInstanceError
|
||||||
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
from sqlalchemy.orm.query import _ColumnEntity
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
from sqlalchemy.orm.session import object_session
|
from sqlalchemy.orm.session import object_session
|
||||||
from sqlalchemy.orm.util import AliasedInsp
|
from sqlalchemy.orm.util import AliasedInsp
|
||||||
@@ -197,13 +198,7 @@ def get_tables(mixed):
|
|||||||
SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping
|
SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping
|
||||||
any of these objects.
|
any of these objects.
|
||||||
"""
|
"""
|
||||||
if isinstance(mixed, sa.orm.util.AliasedClass):
|
return get_mapper(mixed).tables
|
||||||
mapper = sa.inspect(mixed).mapper
|
|
||||||
else:
|
|
||||||
if not isclass(mixed):
|
|
||||||
mixed = mixed.__class__
|
|
||||||
mapper = sa.inspect(mixed)
|
|
||||||
return mapper.tables
|
|
||||||
|
|
||||||
|
|
||||||
def get_columns(mixed):
|
def get_columns(mixed):
|
||||||
@@ -414,17 +409,19 @@ def get_query_entities(query):
|
|||||||
|
|
||||||
:param query: SQLAlchemy Query object
|
:param query: SQLAlchemy Query object
|
||||||
"""
|
"""
|
||||||
return list(
|
return [
|
||||||
map(get_selectable, chain(query._entities, query._join_entities))
|
get_query_entity(entity) for entity in
|
||||||
)
|
chain(query._entities, query._join_entities)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_selectable(mixed):
|
def get_query_entity(mixed):
|
||||||
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
|
if hasattr(mixed, 'expr'):
|
||||||
return mixed
|
expr = mixed.expr
|
||||||
expr = mixed.expr
|
else:
|
||||||
|
expr = mixed
|
||||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||||
return expr.parent
|
return expr.parent.class_
|
||||||
elif isinstance(expr, sa.Column):
|
elif isinstance(expr, sa.Column):
|
||||||
return expr.table
|
return expr.table
|
||||||
elif isinstance(expr, sa.sql.expression.Label):
|
elif isinstance(expr, sa.sql.expression.Label):
|
||||||
@@ -432,17 +429,22 @@ def get_selectable(mixed):
|
|||||||
return mixed.entity_zero
|
return mixed.entity_zero
|
||||||
else:
|
else:
|
||||||
return expr
|
return expr
|
||||||
|
elif isinstance(expr, sa.orm.Mapper):
|
||||||
|
return expr.class_
|
||||||
|
elif isinstance(expr, AliasedInsp):
|
||||||
|
return expr.entity
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def get_query_entity_by_alias(query, alias):
|
def get_query_entity_by_alias(query, alias):
|
||||||
entities = get_query_entities(query)
|
entities = get_query_entities(query)
|
||||||
|
|
||||||
if not alias:
|
if not alias:
|
||||||
return entities[0]
|
return entities[0]
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if isinstance(entity, AliasedInsp):
|
if isinstance(entity, sa.orm.util.AliasedClass):
|
||||||
name = entity.name
|
name = sa.inspect(entity).name
|
||||||
else:
|
else:
|
||||||
name = get_mapper(entity).tables[0].name
|
name = get_mapper(entity).tables[0].name
|
||||||
|
|
||||||
@@ -457,17 +459,61 @@ def get_polymorphic_mappers(mixed):
|
|||||||
return mixed.polymorphic_map.values()
|
return mixed.polymorphic_map.values()
|
||||||
|
|
||||||
|
|
||||||
def get_attrs(expr):
|
def get_query_descriptor(query, entity, attr):
|
||||||
insp = sa.inspect(expr)
|
if attr in query_labels(query):
|
||||||
mapper = get_mapper(expr)
|
return attr
|
||||||
polymorphic_mappers = get_polymorphic_mappers(insp)
|
else:
|
||||||
|
entity = get_query_entity_by_alias(query, entity)
|
||||||
|
if entity:
|
||||||
|
descriptor = get_descriptor(entity, attr)
|
||||||
|
if (
|
||||||
|
hasattr(descriptor, 'property') and
|
||||||
|
isinstance(descriptor.property, sa.orm.RelationshipProperty)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
return descriptor
|
||||||
|
|
||||||
|
|
||||||
|
def get_descriptor(entity, attr):
|
||||||
|
mapper = sa.inspect(entity)
|
||||||
|
|
||||||
|
for key, descriptor in get_all_descriptors(mapper).items():
|
||||||
|
if attr == key:
|
||||||
|
prop = (
|
||||||
|
descriptor.property
|
||||||
|
if hasattr(descriptor, 'property')
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if isinstance(prop, ColumnProperty):
|
||||||
|
if isinstance(entity, sa.orm.util.AliasedClass):
|
||||||
|
for c in mapper.selectable.c:
|
||||||
|
if c.key == attr:
|
||||||
|
return c
|
||||||
|
else:
|
||||||
|
# If the property belongs to a class that uses
|
||||||
|
# polymorphic inheritance we have to take into account
|
||||||
|
# situations where the attribute exists in child class
|
||||||
|
# but not in parent class.
|
||||||
|
return getattr(prop.parent.class_, attr)
|
||||||
|
else:
|
||||||
|
# Handle synonyms, relationship proeprties and hybrid
|
||||||
|
# properties
|
||||||
|
try:
|
||||||
|
return getattr(entity, attr)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_descriptors(expr):
|
||||||
|
insp = sa.inspect(expr)
|
||||||
|
polymorphic_mappers = get_polymorphic_mappers(insp)
|
||||||
if polymorphic_mappers:
|
if polymorphic_mappers:
|
||||||
|
|
||||||
attrs = {}
|
attrs = {}
|
||||||
for submapper in polymorphic_mappers:
|
for submapper in polymorphic_mappers:
|
||||||
attrs.update(submapper.attrs)
|
attrs.update(submapper.all_orm_descriptors)
|
||||||
return attrs
|
return attrs
|
||||||
return mapper.attrs
|
return get_mapper(expr).all_orm_descriptors
|
||||||
|
|
||||||
|
|
||||||
def get_hybrid_properties(model):
|
def get_hybrid_properties(model):
|
||||||
@@ -521,13 +567,6 @@ def get_hybrid_properties(model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_expr_attr(expr, prop):
|
|
||||||
if isinstance(expr, AliasedInsp):
|
|
||||||
return getattr(expr.selectable.c, prop.key)
|
|
||||||
else:
|
|
||||||
return getattr(prop.parent.class_, prop.key)
|
|
||||||
|
|
||||||
|
|
||||||
def get_declarative_base(model):
|
def get_declarative_base(model):
|
||||||
"""
|
"""
|
||||||
Returns the declarative base for given model class.
|
Returns the declarative base for given model class.
|
||||||
|
@@ -1,15 +1,5 @@
|
|||||||
import sqlalchemy as sa
|
from sqlalchemy.sql.expression import desc, asc
|
||||||
from sqlalchemy.orm.properties import ColumnProperty, SynonymProperty
|
from .orm import get_query_descriptor
|
||||||
from sqlalchemy.sql.expression import desc, asc, Label
|
|
||||||
from sqlalchemy.orm.util import AliasedInsp
|
|
||||||
from .orm import (
|
|
||||||
get_attrs,
|
|
||||||
get_expr_attr,
|
|
||||||
get_hybrid_properties,
|
|
||||||
get_query_entity_by_alias,
|
|
||||||
get_query_entities,
|
|
||||||
query_labels,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySorterException(Exception):
|
class QuerySorterException(Exception):
|
||||||
@@ -18,18 +8,11 @@ class QuerySorterException(Exception):
|
|||||||
|
|
||||||
class QuerySorter(object):
|
class QuerySorter(object):
|
||||||
def __init__(self, silent=True, separator='-'):
|
def __init__(self, silent=True, separator='-'):
|
||||||
self.labels = []
|
|
||||||
self.separator = separator
|
self.separator = separator
|
||||||
self.silent = silent
|
self.silent = silent
|
||||||
|
|
||||||
def assign_order_by(self, entity, attr, func):
|
def assign_order_by(self, entity, attr, func):
|
||||||
expr = None
|
expr = get_query_descriptor(self.query, entity, attr)
|
||||||
if attr in self.labels:
|
|
||||||
expr = attr
|
|
||||||
else:
|
|
||||||
entity = get_query_entity_by_alias(self.query, entity)
|
|
||||||
if entity:
|
|
||||||
expr = self.order_by_attr(entity, attr)
|
|
||||||
|
|
||||||
if expr is not None:
|
if expr is not None:
|
||||||
return self.query.order_by(func(expr))
|
return self.query.order_by(func(expr))
|
||||||
@@ -39,30 +22,6 @@ class QuerySorter(object):
|
|||||||
)
|
)
|
||||||
return self.query
|
return self.query
|
||||||
|
|
||||||
def order_by_attr(self, entity, attr):
|
|
||||||
properties = get_attrs(entity)
|
|
||||||
if attr in properties:
|
|
||||||
property_ = properties[attr]
|
|
||||||
|
|
||||||
if isinstance(property_, ColumnProperty):
|
|
||||||
if isinstance(property_.columns[0], Label):
|
|
||||||
return getattr(entity, property_.key)
|
|
||||||
else:
|
|
||||||
return get_expr_attr(entity, property_)
|
|
||||||
elif isinstance(property_, SynonymProperty):
|
|
||||||
return get_expr_attr(entity, property_)
|
|
||||||
return
|
|
||||||
|
|
||||||
mapper = sa.inspect(entity)
|
|
||||||
entity = mapper.entity
|
|
||||||
|
|
||||||
if isinstance(mapper, AliasedInsp):
|
|
||||||
mapper = mapper.mapper
|
|
||||||
|
|
||||||
for key in get_hybrid_properties(mapper).keys():
|
|
||||||
if attr == key:
|
|
||||||
return getattr(entity, attr)
|
|
||||||
|
|
||||||
def parse_sort_arg(self, arg):
|
def parse_sort_arg(self, arg):
|
||||||
if arg[0] == self.separator:
|
if arg[0] == self.separator:
|
||||||
func = desc
|
func = desc
|
||||||
@@ -79,7 +38,6 @@ class QuerySorter(object):
|
|||||||
|
|
||||||
def __call__(self, query, *args):
|
def __call__(self, query, *args):
|
||||||
self.query = query
|
self.query = query
|
||||||
self.labels = query_labels(query)
|
|
||||||
|
|
||||||
for sort in args:
|
for sort in args:
|
||||||
if not sort:
|
if not sort:
|
||||||
|
@@ -41,7 +41,7 @@ class TestGetQueryEntities(TestCase):
|
|||||||
|
|
||||||
def test_mapper(self):
|
def test_mapper(self):
|
||||||
query = self.session.query(sa.inspect(self.TextItem))
|
query = self.session.query(sa.inspect(self.TextItem))
|
||||||
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
|
assert list(get_query_entities(query)) == [self.TextItem]
|
||||||
|
|
||||||
def test_entity(self):
|
def test_entity(self):
|
||||||
query = self.session.query(self.TextItem)
|
query = self.session.query(self.TextItem)
|
||||||
@@ -49,7 +49,7 @@ class TestGetQueryEntities(TestCase):
|
|||||||
|
|
||||||
def test_instrumented_attribute(self):
|
def test_instrumented_attribute(self):
|
||||||
query = self.session.query(self.TextItem.id)
|
query = self.session.query(self.TextItem.id)
|
||||||
assert list(get_query_entities(query)) == [sa.inspect(self.TextItem)]
|
assert list(get_query_entities(query)) == [self.TextItem]
|
||||||
|
|
||||||
def test_column(self):
|
def test_column(self):
|
||||||
query = self.session.query(self.TextItem.__table__.c.id)
|
query = self.session.query(self.TextItem.__table__.c.id)
|
||||||
@@ -65,7 +65,7 @@ class TestGetQueryEntities(TestCase):
|
|||||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||||
)
|
)
|
||||||
assert list(get_query_entities(query)) == [
|
assert list(get_query_entities(query)) == [
|
||||||
self.TextItem, sa.inspect(self.BlogPost)
|
self.TextItem, self.BlogPost
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_joined_aliased_entity(self):
|
def test_joined_aliased_entity(self):
|
||||||
@@ -74,9 +74,7 @@ class TestGetQueryEntities(TestCase):
|
|||||||
query = self.session.query(self.TextItem).join(
|
query = self.session.query(self.TextItem).join(
|
||||||
alias, alias.id == self.TextItem.id
|
alias, alias.id == self.TextItem.id
|
||||||
)
|
)
|
||||||
assert list(get_query_entities(query)) == [
|
assert list(get_query_entities(query)) == [self.TextItem, alias]
|
||||||
self.TextItem, sa.inspect(alias)
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_column_entity_with_label(self):
|
def test_column_entity_with_label(self):
|
||||||
query = self.session.query(self.Article.id.label('id'))
|
query = self.session.query(self.Article.id.label('id'))
|
||||||
|
@@ -156,6 +156,37 @@ class TestSortQuery(TestCase):
|
|||||||
query = sort_query(query, 'some_hybrid')
|
query = sort_query(query, 'some_hybrid')
|
||||||
assert_contains('ORDER BY article.name ASC', query)
|
assert_contains('ORDER BY article.name ASC', query)
|
||||||
|
|
||||||
|
def test_with_mapper_and_column_property(self):
|
||||||
|
class Apple(self.Base):
|
||||||
|
__tablename__ = 'apple'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
article_id = sa.Column(sa.Integer, sa.ForeignKey(self.Article.id))
|
||||||
|
|
||||||
|
self.Article.apples = sa.orm.relationship(Apple)
|
||||||
|
|
||||||
|
self.Article.apple_count = sa.orm.column_property(
|
||||||
|
sa.select([sa.func.count(Apple.id)])
|
||||||
|
.where(Apple.article_id == self.Article.id)
|
||||||
|
.correlate(self.Article.__table__)
|
||||||
|
.label('apple_count'),
|
||||||
|
deferred=True
|
||||||
|
)
|
||||||
|
query = (
|
||||||
|
self.session.query(sa.inspect(self.Article))
|
||||||
|
.outerjoin(self.Article.apples)
|
||||||
|
.options(
|
||||||
|
sa.orm.undefer(self.Article.apple_count)
|
||||||
|
)
|
||||||
|
.options(sa.orm.contains_eager(self.Article.apples))
|
||||||
|
)
|
||||||
|
query = sort_query(query, 'apple_count')
|
||||||
|
assert 'ORDER BY apple_count' in str(query)
|
||||||
|
|
||||||
|
def test_table(self):
|
||||||
|
query = self.session.query(self.Article.__table__)
|
||||||
|
query = sort_query(query, 'name')
|
||||||
|
assert_contains('ORDER BY name', query)
|
||||||
|
|
||||||
|
|
||||||
class TestSortQueryRelationshipCounts(TestCase):
|
class TestSortQueryRelationshipCounts(TestCase):
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user