diff --git a/sqlalchemy_utils/batch.py b/sqlalchemy_utils/batch.py index 620fa95..bbbca33 100644 --- a/sqlalchemy_utils/batch.py +++ b/sqlalchemy_utils/batch.py @@ -37,31 +37,33 @@ class with_backrefs(object): self.path = path + + + class Path(object): """ A class that represents an attribute path. """ def __init__(self, entities, prop, populate_backrefs=False): + self.validate_property(prop) self.property = prop self.entities = entities self.populate_backrefs = populate_backrefs + self.fetcher = self.fetcher_class(self) + + def validate_property(self, prop): if ( - not isinstance(self.property, RelationshipProperty) and - not isinstance(self.property, GenericRelationshipProperty) + not isinstance(prop, RelationshipProperty) and + not isinstance(prop, GenericRelationshipProperty) ): raise PathException( 'Given attribute is not a relationship property.' ) - self.fetcher = self.fetcher_class(self) @property def session(self): return object_session(self.entities[0]) - @property - def parent_model(self): - return self.entities[0].__class__ - @property def model(self): return self.property.mapper.class_ @@ -267,7 +269,7 @@ class Fetcher(object): ) for value in related_entities: backref_dict[local_values(self.prop, value[0])].append( - self.path.session.query(self.path.parent_model).get( + self.path.session.query(self.path.entities[0].__class__).get( tuple(value[1:]) ) ) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 14f4e09..609eeab 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -2,6 +2,10 @@ from functools import partial from funcy import first from toolz import curry import sqlalchemy as sa +from sqlalchemy import inspect +from sqlalchemy.orm.query import _ColumnEntity +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.util import AliasedInsp def remove_property(class_, name): @@ -123,6 +127,83 @@ def remote_column_names(prop): yield remote.name + +def query_labels(query): + """ + Return all labels for given SQLAlchemy query object. + + Example:: + + + query = session.query( + Category, + db.func.count(Article.id).label('articles') + ) + + query_labels(query) # ('articles', ) + + :param query: SQLAlchemy Query object + """ + for entity in query._entities: + if isinstance(entity, _ColumnEntity) and entity._label_name: + yield entity._label_name + + +def query_entities(query): + """ + Return all entities for given SQLAlchemy query object. + + Example:: + + + query = session.query( + Category + ) + + query_entities(query) # ('Category', ) + + :param query: SQLAlchemy Query object + """ + for entity in query._entities: + yield entity.entity_zero.class_ + + for entity in query._join_entities: + if isinstance(entity, Mapper): + yield entity.class_ + else: + yield entity + + +def get_query_entity_by_alias(query, alias): + entities = query_entities(query) + if not alias: + return first(entities) + + for entity in entities: + if isinstance(entity, AliasedInsp): + name = entity.name + else: + name = entity.__table__.name + + if name == alias: + return entity + + +def attrs(expr): + if isinstance(expr, AliasedInsp): + return expr.mapper.attrs + else: + return inspect(expr).attrs + + +def get_expr_attr(expr, attr_name): + if isinstance(expr, AliasedInsp): + return getattr(expr.selectable.c, attr_name) + else: + return getattr(expr, attr_name) + + + def declarative_base(model): """ Returns the declarative base for given model class. diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index 468b774..281153f 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -1,23 +1,13 @@ -from sqlalchemy import inspect -from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.properties import ColumnProperty -from sqlalchemy.orm.query import _ColumnEntity -from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.sql.expression import desc, asc, Label - - -def attrs(expr): - if isinstance(expr, AliasedInsp): - return expr.mapper.attrs - else: - return inspect(expr).attrs - - -def sort_expression(expr, attr_name): - if isinstance(expr, AliasedInsp): - return getattr(expr.selectable.c, attr_name) - else: - return getattr(expr, attr_name) +from sqlalchemy.orm.util import AliasedInsp +from .orm import ( + attrs, + query_labels, + query_entities, + get_query_entity_by_alias, + get_expr_attr +) class QuerySorterException(Exception): @@ -31,57 +21,20 @@ class QuerySorter(object): self.separator = separator self.silent = silent - def inspect_labels_and_entities(self): - for entity in self.query._entities: - # get all label names for queries such as: - # db.session.query( - # Category, - # db.func.count(Article.id).label('articles') - # ) - if isinstance(entity, _ColumnEntity) and entity._label_name: - self.labels.append(entity._label_name) - else: - self.entities.append(entity.entity_zero.class_) - - for mapper in self.query._join_entities: - if isinstance(mapper, Mapper): - self.entities.append(mapper.class_) - else: - self.entities.append(mapper) - - def get_entity_by_alias(self, alias): - if not alias: - return self.entities[0] - - for entity in self.entities: - if isinstance(entity, AliasedInsp): - name = entity.name - else: - name = entity.__table__.name - - if name == alias: - return entity - - def assign_order_by(self, sort): - if not sort: - return self.query - - sort = self.parse_sort_arg(sort) + def assign_order_by(self, entity, attr, func): expr = None - if sort['attr'] in self.labels: - expr = sort['attr'] + if attr in self.labels: + expr = attr else: - entity = self.get_entity_by_alias(sort['entity']) + entity = get_query_entity_by_alias(self.query, entity) if entity: - expr = self.order_by_attr(entity, sort['attr']) + expr = self.order_by_attr(entity, attr) if expr is not None: - return self.query.order_by( - sort['func'](expr) - ) + return self.query.order_by(func(expr)) if not self.silent: raise QuerySorterException( - "Could not sort query with expression '%s'" % sort['attr'] + "Could not sort query with expression '%s'" % attr ) return self.query @@ -93,7 +46,7 @@ class QuerySorter(object): if isinstance(property_.columns[0], Label): expr = property_.columns[0].name else: - expr = sort_expression(entity, property_.key) + expr = get_expr_attr(entity, property_.key) return expr else: return @@ -119,9 +72,14 @@ class QuerySorter(object): def __call__(self, query, *args): self.query = query - self.inspect_labels_and_entities() + self.labels = query_labels(query) + self.entities = query_entities(query) for sort in args: - self.query = self.assign_order_by(sort) + if not sort: + continue + self.query = self.assign_order_by( + **self.parse_sort_arg(sort) + ) return self.query