Refactor QuerySorter & batch fetch

This commit is contained in:
Konsta Vesterinen
2013-12-26 03:18:22 +02:00
parent c911f50581
commit f1e9897fc5
3 changed files with 114 additions and 73 deletions

View File

@@ -37,31 +37,33 @@ class with_backrefs(object):
self.path = path self.path = path
class Path(object): class Path(object):
""" """
A class that represents an attribute path. A class that represents an attribute path.
""" """
def __init__(self, entities, prop, populate_backrefs=False): def __init__(self, entities, prop, populate_backrefs=False):
self.validate_property(prop)
self.property = prop self.property = prop
self.entities = entities self.entities = entities
self.populate_backrefs = populate_backrefs self.populate_backrefs = populate_backrefs
self.fetcher = self.fetcher_class(self)
def validate_property(self, prop):
if ( if (
not isinstance(self.property, RelationshipProperty) and not isinstance(prop, RelationshipProperty) and
not isinstance(self.property, GenericRelationshipProperty) not isinstance(prop, GenericRelationshipProperty)
): ):
raise PathException( raise PathException(
'Given attribute is not a relationship property.' 'Given attribute is not a relationship property.'
) )
self.fetcher = self.fetcher_class(self)
@property @property
def session(self): def session(self):
return object_session(self.entities[0]) return object_session(self.entities[0])
@property
def parent_model(self):
return self.entities[0].__class__
@property @property
def model(self): def model(self):
return self.property.mapper.class_ return self.property.mapper.class_
@@ -267,7 +269,7 @@ class Fetcher(object):
) )
for value in related_entities: for value in related_entities:
backref_dict[local_values(self.prop, value[0])].append( backref_dict[local_values(self.prop, value[0])].append(
self.path.session.query(self.path.parent_model).get( self.path.session.query(self.path.entities[0].__class__).get(
tuple(value[1:]) tuple(value[1:])
) )
) )

View File

@@ -2,6 +2,10 @@ from functools import partial
from funcy import first from funcy import first
from toolz import curry from toolz import curry
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.util import AliasedInsp
def remove_property(class_, name): def remove_property(class_, name):
@@ -123,6 +127,83 @@ def remote_column_names(prop):
yield remote.name yield remote.name
def query_labels(query):
"""
Return all labels for given SQLAlchemy query object.
Example::
query = session.query(
Category,
db.func.count(Article.id).label('articles')
)
query_labels(query) # ('articles', )
:param query: SQLAlchemy Query object
"""
for entity in query._entities:
if isinstance(entity, _ColumnEntity) and entity._label_name:
yield entity._label_name
def query_entities(query):
"""
Return all entities for given SQLAlchemy query object.
Example::
query = session.query(
Category
)
query_entities(query) # ('Category', )
:param query: SQLAlchemy Query object
"""
for entity in query._entities:
yield entity.entity_zero.class_
for entity in query._join_entities:
if isinstance(entity, Mapper):
yield entity.class_
else:
yield entity
def get_query_entity_by_alias(query, alias):
entities = query_entities(query)
if not alias:
return first(entities)
for entity in entities:
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
if name == alias:
return entity
def attrs(expr):
if isinstance(expr, AliasedInsp):
return expr.mapper.attrs
else:
return inspect(expr).attrs
def get_expr_attr(expr, attr_name):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name)
else:
return getattr(expr, attr_name)
def declarative_base(model): def declarative_base(model):
""" """
Returns the declarative base for given model class. Returns the declarative base for given model class.

View File

@@ -1,23 +1,13 @@
from sqlalchemy import inspect
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy.sql.expression import desc, asc, Label from sqlalchemy.sql.expression import desc, asc, Label
from sqlalchemy.orm.util import AliasedInsp
from .orm import (
def attrs(expr): attrs,
if isinstance(expr, AliasedInsp): query_labels,
return expr.mapper.attrs query_entities,
else: get_query_entity_by_alias,
return inspect(expr).attrs get_expr_attr
)
def sort_expression(expr, attr_name):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name)
else:
return getattr(expr, attr_name)
class QuerySorterException(Exception): class QuerySorterException(Exception):
@@ -31,57 +21,20 @@ class QuerySorter(object):
self.separator = separator self.separator = separator
self.silent = silent self.silent = silent
def inspect_labels_and_entities(self): def assign_order_by(self, entity, attr, func):
for entity in self.query._entities:
# get all label names for queries such as:
# db.session.query(
# Category,
# db.func.count(Article.id).label('articles')
# )
if isinstance(entity, _ColumnEntity) and entity._label_name:
self.labels.append(entity._label_name)
else:
self.entities.append(entity.entity_zero.class_)
for mapper in self.query._join_entities:
if isinstance(mapper, Mapper):
self.entities.append(mapper.class_)
else:
self.entities.append(mapper)
def get_entity_by_alias(self, alias):
if not alias:
return self.entities[0]
for entity in self.entities:
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
if name == alias:
return entity
def assign_order_by(self, sort):
if not sort:
return self.query
sort = self.parse_sort_arg(sort)
expr = None expr = None
if sort['attr'] in self.labels: if attr in self.labels:
expr = sort['attr'] expr = attr
else: else:
entity = self.get_entity_by_alias(sort['entity']) entity = get_query_entity_by_alias(self.query, entity)
if entity: if entity:
expr = self.order_by_attr(entity, sort['attr']) expr = self.order_by_attr(entity, attr)
if expr is not None: if expr is not None:
return self.query.order_by( return self.query.order_by(func(expr))
sort['func'](expr)
)
if not self.silent: if not self.silent:
raise QuerySorterException( raise QuerySorterException(
"Could not sort query with expression '%s'" % sort['attr'] "Could not sort query with expression '%s'" % attr
) )
return self.query return self.query
@@ -93,7 +46,7 @@ class QuerySorter(object):
if isinstance(property_.columns[0], Label): if isinstance(property_.columns[0], Label):
expr = property_.columns[0].name expr = property_.columns[0].name
else: else:
expr = sort_expression(entity, property_.key) expr = get_expr_attr(entity, property_.key)
return expr return expr
else: else:
return return
@@ -119,9 +72,14 @@ class QuerySorter(object):
def __call__(self, query, *args): def __call__(self, query, *args):
self.query = query self.query = query
self.inspect_labels_and_entities() self.labels = query_labels(query)
self.entities = query_entities(query)
for sort in args: for sort in args:
self.query = self.assign_order_by(sort) if not sort:
continue
self.query = self.assign_order_by(
**self.parse_sort_arg(sort)
)
return self.query return self.query