From 74392872e6c1f55f8996d4349c6ddd5d9ec5105d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 16 Jul 2015 12:02:53 +0300 Subject: [PATCH] Refactor aggregates, remove SA 0.8 support --- sqlalchemy_utils/aggregates.py | 47 +++++++++++++++------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index c02b1f2..69e1087 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -368,16 +368,14 @@ from weakref import WeakKeyDictionary import six import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.sql.functions import _FunctionGenerator from .functions.orm import get_column_key -from .relationships import chained_join, select_aggregate - -try: - # SQLAlchemy 0.9 - from sqlalchemy.sql.functions import _FunctionGenerator -except ImportError: - # SQLAlchemy 0.8 - from sqlalchemy.sql.expression import _FunctionGenerator +from .relationships import ( + chained_join, + path_to_relationships, + select_aggregate +) aggregated_attrs = WeakKeyDictionary(defaultdict(list)) @@ -438,10 +436,13 @@ def aggregate_expression(expr, class_): class AggregatedValue(object): - def __init__(self, class_, attr, relationships, expr): + def __init__(self, class_, attr, path, expr): self.class_ = class_ self.attr = attr - self.relationships = relationships + self.path = path + self.relationships = list( + reversed(path_to_relationships(path, class_)) + ) self.expr = aggregate_expression(expr, class_) @property @@ -515,22 +516,16 @@ class AggregationManager(object): def update_generator_registry(self): for class_, attrs in six.iteritems(aggregated_attrs): - for expr, relationship, column in attrs: - relationships = [] - rel_class = class_ - - for path_name in relationship.split('.'): - rel = getattr(rel_class, path_name) - relationships.append(rel) - rel_class = rel.mapper.class_ - - self.generator_registry[rel_class].append( - AggregatedValue( - class_=class_, - attr=column, - relationships=list(reversed(relationships)), - expr=expr(class_) - ) + for expr, path, column in attrs: + value = AggregatedValue( + class_=class_, + attr=column, + path=path, + expr=expr(class_) + ) + key = value.relationships[0].mapper.class_ + self.generator_registry[key].append( + value ) def construct_aggregate_queries(self, session, ctx):