Refactor QuerySorter & batch fetch

This commit is contained in:
Konsta Vesterinen
2013-12-26 03:18:22 +02:00
parent c911f50581
commit f1e9897fc5
3 changed files with 114 additions and 73 deletions

View File

@@ -37,31 +37,33 @@ class with_backrefs(object):
self.path = path
class Path(object):
"""
A class that represents an attribute path.
"""
def __init__(self, entities, prop, populate_backrefs=False):
self.validate_property(prop)
self.property = prop
self.entities = entities
self.populate_backrefs = populate_backrefs
self.fetcher = self.fetcher_class(self)
def validate_property(self, prop):
if (
not isinstance(self.property, RelationshipProperty) and
not isinstance(self.property, GenericRelationshipProperty)
not isinstance(prop, RelationshipProperty) and
not isinstance(prop, GenericRelationshipProperty)
):
raise PathException(
'Given attribute is not a relationship property.'
)
self.fetcher = self.fetcher_class(self)
@property
def session(self):
return object_session(self.entities[0])
@property
def parent_model(self):
return self.entities[0].__class__
@property
def model(self):
return self.property.mapper.class_
@@ -267,7 +269,7 @@ class Fetcher(object):
)
for value in related_entities:
backref_dict[local_values(self.prop, value[0])].append(
self.path.session.query(self.path.parent_model).get(
self.path.session.query(self.path.entities[0].__class__).get(
tuple(value[1:])
)
)

View File

@@ -2,6 +2,10 @@ from functools import partial
from funcy import first
from toolz import curry
import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.util import AliasedInsp
def remove_property(class_, name):
@@ -123,6 +127,83 @@ def remote_column_names(prop):
yield remote.name
def query_labels(query):
"""
Return all labels for given SQLAlchemy query object.
Example::
query = session.query(
Category,
db.func.count(Article.id).label('articles')
)
query_labels(query) # ('articles', )
:param query: SQLAlchemy Query object
"""
for entity in query._entities:
if isinstance(entity, _ColumnEntity) and entity._label_name:
yield entity._label_name
def query_entities(query):
"""
Return all entities for given SQLAlchemy query object.
Example::
query = session.query(
Category
)
query_entities(query) # ('Category', )
:param query: SQLAlchemy Query object
"""
for entity in query._entities:
yield entity.entity_zero.class_
for entity in query._join_entities:
if isinstance(entity, Mapper):
yield entity.class_
else:
yield entity
def get_query_entity_by_alias(query, alias):
entities = query_entities(query)
if not alias:
return first(entities)
for entity in entities:
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
if name == alias:
return entity
def attrs(expr):
if isinstance(expr, AliasedInsp):
return expr.mapper.attrs
else:
return inspect(expr).attrs
def get_expr_attr(expr, attr_name):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name)
else:
return getattr(expr, attr_name)
def declarative_base(model):
"""
Returns the declarative base for given model class.

View File

@@ -1,23 +1,13 @@
from sqlalchemy import inspect
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy.sql.expression import desc, asc, Label
def attrs(expr):
if isinstance(expr, AliasedInsp):
return expr.mapper.attrs
else:
return inspect(expr).attrs
def sort_expression(expr, attr_name):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name)
else:
return getattr(expr, attr_name)
from sqlalchemy.orm.util import AliasedInsp
from .orm import (
attrs,
query_labels,
query_entities,
get_query_entity_by_alias,
get_expr_attr
)
class QuerySorterException(Exception):
@@ -31,57 +21,20 @@ class QuerySorter(object):
self.separator = separator
self.silent = silent
def inspect_labels_and_entities(self):
for entity in self.query._entities:
# get all label names for queries such as:
# db.session.query(
# Category,
# db.func.count(Article.id).label('articles')
# )
if isinstance(entity, _ColumnEntity) and entity._label_name:
self.labels.append(entity._label_name)
else:
self.entities.append(entity.entity_zero.class_)
for mapper in self.query._join_entities:
if isinstance(mapper, Mapper):
self.entities.append(mapper.class_)
else:
self.entities.append(mapper)
def get_entity_by_alias(self, alias):
if not alias:
return self.entities[0]
for entity in self.entities:
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
if name == alias:
return entity
def assign_order_by(self, sort):
if not sort:
return self.query
sort = self.parse_sort_arg(sort)
def assign_order_by(self, entity, attr, func):
expr = None
if sort['attr'] in self.labels:
expr = sort['attr']
if attr in self.labels:
expr = attr
else:
entity = self.get_entity_by_alias(sort['entity'])
entity = get_query_entity_by_alias(self.query, entity)
if entity:
expr = self.order_by_attr(entity, sort['attr'])
expr = self.order_by_attr(entity, attr)
if expr is not None:
return self.query.order_by(
sort['func'](expr)
)
return self.query.order_by(func(expr))
if not self.silent:
raise QuerySorterException(
"Could not sort query with expression '%s'" % sort['attr']
"Could not sort query with expression '%s'" % attr
)
return self.query
@@ -93,7 +46,7 @@ class QuerySorter(object):
if isinstance(property_.columns[0], Label):
expr = property_.columns[0].name
else:
expr = sort_expression(entity, property_.key)
expr = get_expr_attr(entity, property_.key)
return expr
else:
return
@@ -119,9 +72,14 @@ class QuerySorter(object):
def __call__(self, query, *args):
self.query = query
self.inspect_labels_and_entities()
self.labels = query_labels(query)
self.entities = query_entities(query)
for sort in args:
self.query = self.assign_order_by(sort)
if not sort:
continue
self.query = self.assign_order_by(
**self.parse_sort_arg(sort)
)
return self.query