Refactor QuerySorter & batch fetch
This commit is contained in:
@@ -37,31 +37,33 @@ class with_backrefs(object):
|
|||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Path(object):
|
class Path(object):
|
||||||
"""
|
"""
|
||||||
A class that represents an attribute path.
|
A class that represents an attribute path.
|
||||||
"""
|
"""
|
||||||
def __init__(self, entities, prop, populate_backrefs=False):
|
def __init__(self, entities, prop, populate_backrefs=False):
|
||||||
|
self.validate_property(prop)
|
||||||
self.property = prop
|
self.property = prop
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
self.populate_backrefs = populate_backrefs
|
self.populate_backrefs = populate_backrefs
|
||||||
|
self.fetcher = self.fetcher_class(self)
|
||||||
|
|
||||||
|
def validate_property(self, prop):
|
||||||
if (
|
if (
|
||||||
not isinstance(self.property, RelationshipProperty) and
|
not isinstance(prop, RelationshipProperty) and
|
||||||
not isinstance(self.property, GenericRelationshipProperty)
|
not isinstance(prop, GenericRelationshipProperty)
|
||||||
):
|
):
|
||||||
raise PathException(
|
raise PathException(
|
||||||
'Given attribute is not a relationship property.'
|
'Given attribute is not a relationship property.'
|
||||||
)
|
)
|
||||||
self.fetcher = self.fetcher_class(self)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session(self):
|
def session(self):
|
||||||
return object_session(self.entities[0])
|
return object_session(self.entities[0])
|
||||||
|
|
||||||
@property
|
|
||||||
def parent_model(self):
|
|
||||||
return self.entities[0].__class__
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
return self.property.mapper.class_
|
return self.property.mapper.class_
|
||||||
@@ -267,7 +269,7 @@ class Fetcher(object):
|
|||||||
)
|
)
|
||||||
for value in related_entities:
|
for value in related_entities:
|
||||||
backref_dict[local_values(self.prop, value[0])].append(
|
backref_dict[local_values(self.prop, value[0])].append(
|
||||||
self.path.session.query(self.path.parent_model).get(
|
self.path.session.query(self.path.entities[0].__class__).get(
|
||||||
tuple(value[1:])
|
tuple(value[1:])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@@ -2,6 +2,10 @@ from functools import partial
|
|||||||
from funcy import first
|
from funcy import first
|
||||||
from toolz import curry
|
from toolz import curry
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
|
from sqlalchemy.orm.util import AliasedInsp
|
||||||
|
|
||||||
|
|
||||||
def remove_property(class_, name):
|
def remove_property(class_, name):
|
||||||
@@ -123,6 +127,83 @@ def remote_column_names(prop):
|
|||||||
yield remote.name
|
yield remote.name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def query_labels(query):
|
||||||
|
"""
|
||||||
|
Return all labels for given SQLAlchemy query object.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
|
||||||
|
query = session.query(
|
||||||
|
Category,
|
||||||
|
db.func.count(Article.id).label('articles')
|
||||||
|
)
|
||||||
|
|
||||||
|
query_labels(query) # ('articles', )
|
||||||
|
|
||||||
|
:param query: SQLAlchemy Query object
|
||||||
|
"""
|
||||||
|
for entity in query._entities:
|
||||||
|
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||||
|
yield entity._label_name
|
||||||
|
|
||||||
|
|
||||||
|
def query_entities(query):
|
||||||
|
"""
|
||||||
|
Return all entities for given SQLAlchemy query object.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
|
||||||
|
query = session.query(
|
||||||
|
Category
|
||||||
|
)
|
||||||
|
|
||||||
|
query_entities(query) # ('Category', )
|
||||||
|
|
||||||
|
:param query: SQLAlchemy Query object
|
||||||
|
"""
|
||||||
|
for entity in query._entities:
|
||||||
|
yield entity.entity_zero.class_
|
||||||
|
|
||||||
|
for entity in query._join_entities:
|
||||||
|
if isinstance(entity, Mapper):
|
||||||
|
yield entity.class_
|
||||||
|
else:
|
||||||
|
yield entity
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_entity_by_alias(query, alias):
|
||||||
|
entities = query_entities(query)
|
||||||
|
if not alias:
|
||||||
|
return first(entities)
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
if isinstance(entity, AliasedInsp):
|
||||||
|
name = entity.name
|
||||||
|
else:
|
||||||
|
name = entity.__table__.name
|
||||||
|
|
||||||
|
if name == alias:
|
||||||
|
return entity
|
||||||
|
|
||||||
|
|
||||||
|
def attrs(expr):
|
||||||
|
if isinstance(expr, AliasedInsp):
|
||||||
|
return expr.mapper.attrs
|
||||||
|
else:
|
||||||
|
return inspect(expr).attrs
|
||||||
|
|
||||||
|
|
||||||
|
def get_expr_attr(expr, attr_name):
|
||||||
|
if isinstance(expr, AliasedInsp):
|
||||||
|
return getattr(expr.selectable.c, attr_name)
|
||||||
|
else:
|
||||||
|
return getattr(expr, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def declarative_base(model):
|
def declarative_base(model):
|
||||||
"""
|
"""
|
||||||
Returns the declarative base for given model class.
|
Returns the declarative base for given model class.
|
||||||
|
@@ -1,23 +1,13 @@
|
|||||||
from sqlalchemy import inspect
|
|
||||||
from sqlalchemy.orm.mapper import Mapper
|
|
||||||
from sqlalchemy.orm.properties import ColumnProperty
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
from sqlalchemy.orm.query import _ColumnEntity
|
|
||||||
from sqlalchemy.orm.util import AliasedInsp
|
|
||||||
from sqlalchemy.sql.expression import desc, asc, Label
|
from sqlalchemy.sql.expression import desc, asc, Label
|
||||||
|
from sqlalchemy.orm.util import AliasedInsp
|
||||||
|
from .orm import (
|
||||||
def attrs(expr):
|
attrs,
|
||||||
if isinstance(expr, AliasedInsp):
|
query_labels,
|
||||||
return expr.mapper.attrs
|
query_entities,
|
||||||
else:
|
get_query_entity_by_alias,
|
||||||
return inspect(expr).attrs
|
get_expr_attr
|
||||||
|
)
|
||||||
|
|
||||||
def sort_expression(expr, attr_name):
|
|
||||||
if isinstance(expr, AliasedInsp):
|
|
||||||
return getattr(expr.selectable.c, attr_name)
|
|
||||||
else:
|
|
||||||
return getattr(expr, attr_name)
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySorterException(Exception):
|
class QuerySorterException(Exception):
|
||||||
@@ -31,57 +21,20 @@ class QuerySorter(object):
|
|||||||
self.separator = separator
|
self.separator = separator
|
||||||
self.silent = silent
|
self.silent = silent
|
||||||
|
|
||||||
def inspect_labels_and_entities(self):
|
def assign_order_by(self, entity, attr, func):
|
||||||
for entity in self.query._entities:
|
|
||||||
# get all label names for queries such as:
|
|
||||||
# db.session.query(
|
|
||||||
# Category,
|
|
||||||
# db.func.count(Article.id).label('articles')
|
|
||||||
# )
|
|
||||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
|
||||||
self.labels.append(entity._label_name)
|
|
||||||
else:
|
|
||||||
self.entities.append(entity.entity_zero.class_)
|
|
||||||
|
|
||||||
for mapper in self.query._join_entities:
|
|
||||||
if isinstance(mapper, Mapper):
|
|
||||||
self.entities.append(mapper.class_)
|
|
||||||
else:
|
|
||||||
self.entities.append(mapper)
|
|
||||||
|
|
||||||
def get_entity_by_alias(self, alias):
|
|
||||||
if not alias:
|
|
||||||
return self.entities[0]
|
|
||||||
|
|
||||||
for entity in self.entities:
|
|
||||||
if isinstance(entity, AliasedInsp):
|
|
||||||
name = entity.name
|
|
||||||
else:
|
|
||||||
name = entity.__table__.name
|
|
||||||
|
|
||||||
if name == alias:
|
|
||||||
return entity
|
|
||||||
|
|
||||||
def assign_order_by(self, sort):
|
|
||||||
if not sort:
|
|
||||||
return self.query
|
|
||||||
|
|
||||||
sort = self.parse_sort_arg(sort)
|
|
||||||
expr = None
|
expr = None
|
||||||
if sort['attr'] in self.labels:
|
if attr in self.labels:
|
||||||
expr = sort['attr']
|
expr = attr
|
||||||
else:
|
else:
|
||||||
entity = self.get_entity_by_alias(sort['entity'])
|
entity = get_query_entity_by_alias(self.query, entity)
|
||||||
if entity:
|
if entity:
|
||||||
expr = self.order_by_attr(entity, sort['attr'])
|
expr = self.order_by_attr(entity, attr)
|
||||||
|
|
||||||
if expr is not None:
|
if expr is not None:
|
||||||
return self.query.order_by(
|
return self.query.order_by(func(expr))
|
||||||
sort['func'](expr)
|
|
||||||
)
|
|
||||||
if not self.silent:
|
if not self.silent:
|
||||||
raise QuerySorterException(
|
raise QuerySorterException(
|
||||||
"Could not sort query with expression '%s'" % sort['attr']
|
"Could not sort query with expression '%s'" % attr
|
||||||
)
|
)
|
||||||
return self.query
|
return self.query
|
||||||
|
|
||||||
@@ -93,7 +46,7 @@ class QuerySorter(object):
|
|||||||
if isinstance(property_.columns[0], Label):
|
if isinstance(property_.columns[0], Label):
|
||||||
expr = property_.columns[0].name
|
expr = property_.columns[0].name
|
||||||
else:
|
else:
|
||||||
expr = sort_expression(entity, property_.key)
|
expr = get_expr_attr(entity, property_.key)
|
||||||
return expr
|
return expr
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
@@ -119,9 +72,14 @@ class QuerySorter(object):
|
|||||||
|
|
||||||
def __call__(self, query, *args):
|
def __call__(self, query, *args):
|
||||||
self.query = query
|
self.query = query
|
||||||
self.inspect_labels_and_entities()
|
self.labels = query_labels(query)
|
||||||
|
self.entities = query_entities(query)
|
||||||
for sort in args:
|
for sort in args:
|
||||||
self.query = self.assign_order_by(sort)
|
if not sort:
|
||||||
|
continue
|
||||||
|
self.query = self.assign_order_by(
|
||||||
|
**self.parse_sort_arg(sort)
|
||||||
|
)
|
||||||
return self.query
|
return self.query
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user