Refactor QuerySorter & batch fetch
This commit is contained in:
@@ -37,31 +37,33 @@ class with_backrefs(object):
|
||||
self.path = path
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class Path(object):
|
||||
"""
|
||||
A class that represents an attribute path.
|
||||
"""
|
||||
def __init__(self, entities, prop, populate_backrefs=False):
|
||||
self.validate_property(prop)
|
||||
self.property = prop
|
||||
self.entities = entities
|
||||
self.populate_backrefs = populate_backrefs
|
||||
self.fetcher = self.fetcher_class(self)
|
||||
|
||||
def validate_property(self, prop):
|
||||
if (
|
||||
not isinstance(self.property, RelationshipProperty) and
|
||||
not isinstance(self.property, GenericRelationshipProperty)
|
||||
not isinstance(prop, RelationshipProperty) and
|
||||
not isinstance(prop, GenericRelationshipProperty)
|
||||
):
|
||||
raise PathException(
|
||||
'Given attribute is not a relationship property.'
|
||||
)
|
||||
self.fetcher = self.fetcher_class(self)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return object_session(self.entities[0])
|
||||
|
||||
@property
|
||||
def parent_model(self):
|
||||
return self.entities[0].__class__
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.property.mapper.class_
|
||||
@@ -267,7 +269,7 @@ class Fetcher(object):
|
||||
)
|
||||
for value in related_entities:
|
||||
backref_dict[local_values(self.prop, value[0])].append(
|
||||
self.path.session.query(self.path.parent_model).get(
|
||||
self.path.session.query(self.path.entities[0].__class__).get(
|
||||
tuple(value[1:])
|
||||
)
|
||||
)
|
||||
|
@@ -2,6 +2,10 @@ from functools import partial
|
||||
from funcy import first
|
||||
from toolz import curry
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
|
||||
|
||||
def remove_property(class_, name):
|
||||
@@ -123,6 +127,83 @@ def remote_column_names(prop):
|
||||
yield remote.name
|
||||
|
||||
|
||||
|
||||
def query_labels(query):
|
||||
"""
|
||||
Return all labels for given SQLAlchemy query object.
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
query = session.query(
|
||||
Category,
|
||||
db.func.count(Article.id).label('articles')
|
||||
)
|
||||
|
||||
query_labels(query) # ('articles', )
|
||||
|
||||
:param query: SQLAlchemy Query object
|
||||
"""
|
||||
for entity in query._entities:
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
yield entity._label_name
|
||||
|
||||
|
||||
def query_entities(query):
|
||||
"""
|
||||
Return all entities for given SQLAlchemy query object.
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
query = session.query(
|
||||
Category
|
||||
)
|
||||
|
||||
query_entities(query) # ('Category', )
|
||||
|
||||
:param query: SQLAlchemy Query object
|
||||
"""
|
||||
for entity in query._entities:
|
||||
yield entity.entity_zero.class_
|
||||
|
||||
for entity in query._join_entities:
|
||||
if isinstance(entity, Mapper):
|
||||
yield entity.class_
|
||||
else:
|
||||
yield entity
|
||||
|
||||
|
||||
def get_query_entity_by_alias(query, alias):
|
||||
entities = query_entities(query)
|
||||
if not alias:
|
||||
return first(entities)
|
||||
|
||||
for entity in entities:
|
||||
if isinstance(entity, AliasedInsp):
|
||||
name = entity.name
|
||||
else:
|
||||
name = entity.__table__.name
|
||||
|
||||
if name == alias:
|
||||
return entity
|
||||
|
||||
|
||||
def attrs(expr):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return expr.mapper.attrs
|
||||
else:
|
||||
return inspect(expr).attrs
|
||||
|
||||
|
||||
def get_expr_attr(expr, attr_name):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return getattr(expr.selectable.c, attr_name)
|
||||
else:
|
||||
return getattr(expr, attr_name)
|
||||
|
||||
|
||||
|
||||
def declarative_base(model):
|
||||
"""
|
||||
Returns the declarative base for given model class.
|
||||
|
@@ -1,23 +1,13 @@
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
from sqlalchemy.sql.expression import desc, asc, Label
|
||||
|
||||
|
||||
def attrs(expr):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return expr.mapper.attrs
|
||||
else:
|
||||
return inspect(expr).attrs
|
||||
|
||||
|
||||
def sort_expression(expr, attr_name):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return getattr(expr.selectable.c, attr_name)
|
||||
else:
|
||||
return getattr(expr, attr_name)
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
from .orm import (
|
||||
attrs,
|
||||
query_labels,
|
||||
query_entities,
|
||||
get_query_entity_by_alias,
|
||||
get_expr_attr
|
||||
)
|
||||
|
||||
|
||||
class QuerySorterException(Exception):
|
||||
@@ -31,57 +21,20 @@ class QuerySorter(object):
|
||||
self.separator = separator
|
||||
self.silent = silent
|
||||
|
||||
def inspect_labels_and_entities(self):
|
||||
for entity in self.query._entities:
|
||||
# get all label names for queries such as:
|
||||
# db.session.query(
|
||||
# Category,
|
||||
# db.func.count(Article.id).label('articles')
|
||||
# )
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
self.labels.append(entity._label_name)
|
||||
else:
|
||||
self.entities.append(entity.entity_zero.class_)
|
||||
|
||||
for mapper in self.query._join_entities:
|
||||
if isinstance(mapper, Mapper):
|
||||
self.entities.append(mapper.class_)
|
||||
else:
|
||||
self.entities.append(mapper)
|
||||
|
||||
def get_entity_by_alias(self, alias):
|
||||
if not alias:
|
||||
return self.entities[0]
|
||||
|
||||
for entity in self.entities:
|
||||
if isinstance(entity, AliasedInsp):
|
||||
name = entity.name
|
||||
else:
|
||||
name = entity.__table__.name
|
||||
|
||||
if name == alias:
|
||||
return entity
|
||||
|
||||
def assign_order_by(self, sort):
|
||||
if not sort:
|
||||
return self.query
|
||||
|
||||
sort = self.parse_sort_arg(sort)
|
||||
def assign_order_by(self, entity, attr, func):
|
||||
expr = None
|
||||
if sort['attr'] in self.labels:
|
||||
expr = sort['attr']
|
||||
if attr in self.labels:
|
||||
expr = attr
|
||||
else:
|
||||
entity = self.get_entity_by_alias(sort['entity'])
|
||||
entity = get_query_entity_by_alias(self.query, entity)
|
||||
if entity:
|
||||
expr = self.order_by_attr(entity, sort['attr'])
|
||||
expr = self.order_by_attr(entity, attr)
|
||||
|
||||
if expr is not None:
|
||||
return self.query.order_by(
|
||||
sort['func'](expr)
|
||||
)
|
||||
return self.query.order_by(func(expr))
|
||||
if not self.silent:
|
||||
raise QuerySorterException(
|
||||
"Could not sort query with expression '%s'" % sort['attr']
|
||||
"Could not sort query with expression '%s'" % attr
|
||||
)
|
||||
return self.query
|
||||
|
||||
@@ -93,7 +46,7 @@ class QuerySorter(object):
|
||||
if isinstance(property_.columns[0], Label):
|
||||
expr = property_.columns[0].name
|
||||
else:
|
||||
expr = sort_expression(entity, property_.key)
|
||||
expr = get_expr_attr(entity, property_.key)
|
||||
return expr
|
||||
else:
|
||||
return
|
||||
@@ -119,9 +72,14 @@ class QuerySorter(object):
|
||||
|
||||
def __call__(self, query, *args):
|
||||
self.query = query
|
||||
self.inspect_labels_and_entities()
|
||||
self.labels = query_labels(query)
|
||||
self.entities = query_entities(query)
|
||||
for sort in args:
|
||||
self.query = self.assign_order_by(sort)
|
||||
if not sort:
|
||||
continue
|
||||
self.query = self.assign_order_by(
|
||||
**self.parse_sort_arg(sort)
|
||||
)
|
||||
return self.query
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user