Refactor aggregates, remove SA 0.8 support

This commit is contained in:
Konsta Vesterinen
2015-07-16 12:02:53 +03:00
parent 6ad148aae1
commit 74392872e6

View File

@@ -368,16 +368,14 @@ from weakref import WeakKeyDictionary
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.functions import _FunctionGenerator
from .functions.orm import get_column_key from .functions.orm import get_column_key
from .relationships import chained_join, select_aggregate from .relationships import (
chained_join,
try: path_to_relationships,
# SQLAlchemy 0.9 select_aggregate
from sqlalchemy.sql.functions import _FunctionGenerator )
except ImportError:
# SQLAlchemy 0.8
from sqlalchemy.sql.expression import _FunctionGenerator
aggregated_attrs = WeakKeyDictionary(defaultdict(list)) aggregated_attrs = WeakKeyDictionary(defaultdict(list))
@@ -438,10 +436,13 @@ def aggregate_expression(expr, class_):
class AggregatedValue(object): class AggregatedValue(object):
def __init__(self, class_, attr, relationships, expr): def __init__(self, class_, attr, path, expr):
self.class_ = class_ self.class_ = class_
self.attr = attr self.attr = attr
self.relationships = relationships self.path = path
self.relationships = list(
reversed(path_to_relationships(path, class_))
)
self.expr = aggregate_expression(expr, class_) self.expr = aggregate_expression(expr, class_)
@property @property
@@ -515,22 +516,16 @@ class AggregationManager(object):
def update_generator_registry(self): def update_generator_registry(self):
for class_, attrs in six.iteritems(aggregated_attrs): for class_, attrs in six.iteritems(aggregated_attrs):
for expr, relationship, column in attrs: for expr, path, column in attrs:
relationships = [] value = AggregatedValue(
rel_class = class_ class_=class_,
attr=column,
for path_name in relationship.split('.'): path=path,
rel = getattr(rel_class, path_name) expr=expr(class_)
relationships.append(rel) )
rel_class = rel.mapper.class_ key = value.relationships[0].mapper.class_
self.generator_registry[key].append(
self.generator_registry[rel_class].append( value
AggregatedValue(
class_=class_,
attr=column,
relationships=list(reversed(relationships)),
expr=expr(class_)
)
) )
def construct_aggregate_queries(self, session, ctx): def construct_aggregate_queries(self, session, ctx):