From 1899cccfb6166a0127ce286f9c1c9f79c5bcb641 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 1 Nov 2013 15:38:28 +0200 Subject: [PATCH] Refactored generates decorator --- sqlalchemy_utils/decorators.py | 109 ++++++++++++++++++--------------- tests/test_generates.py | 6 +- 2 files changed, 61 insertions(+), 54 deletions(-) diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index 9940b87..25589d6 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -3,11 +3,64 @@ import itertools import sqlalchemy as sa import six -generator_registry = defaultdict(list) -listeners_registered = False + +class AttributeValueGenerator(object): + def __init__(self): + self.reset() + + def reset(self): + self.generator_registry = defaultdict(list) + self.listeners_registered = False + + def generator_wrapper(self, func, attr): + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + if isinstance(attr, six.string_types) and '.' in attr: + parts = attr.split('.') + self.generator_registry[parts[0]].append(wrapper) + wrapper.__generates__ = parts[1] + elif isinstance(attr, sa.orm.attributes.InstrumentedAttribute): + self.generator_registry[attr.class_.__name__].append(wrapper) + wrapper.__generates__ = attr + else: + wrapper.__generates__ = attr + return wrapper + + def register_listeners(self): + if not self.listeners_registered: + sa.event.listen( + sa.orm.mapper, + 'mapper_configured', + self.update_generator_registry + ) + sa.event.listen( + sa.orm.session.Session, + 'before_flush', + self.update_generated_properties + ) + self.listeners_registered = True + + def update_generator_registry(self, mapper, class_): + for value in class_.__dict__.values(): + if hasattr(value, '__generates__'): + self.generator_registry[class_.__name__].append(value) + + def update_generated_properties(self, session, ctx, instances): + for obj in itertools.chain(session.new, session.dirty): + class_ = obj.__class__.__name__ + if class_ in self.generator_registry: + for func in self.generator_registry[class_]: + attr = func.__generates__ + if not isinstance(attr, six.string_types): + attr = attr.name + setattr(obj, attr, func(obj)) -def generates(attr): +generator = AttributeValueGenerator() + + +def generates(attr, generator=generator): """ Many times you may have generated property values. Usual cases include slugs from names or resized thumbnails from images. @@ -75,54 +128,8 @@ def generates(attr): return article.name.lower().replace(' ', '-') """ - register_listeners() + generator.register_listeners() def wraps(func): - def wrapper(self, *args, **kwargs): - return func(self, *args, **kwargs) - - if isinstance(attr, six.string_types) and '.' in attr: - parts = attr.split('.') - generator_registry[parts[0]].append(wrapper) - wrapper.__generates__ = parts[1] - elif isinstance(attr, sa.orm.attributes.InstrumentedAttribute): - generator_registry[attr.class_.__name__].append(wrapper) - wrapper.__generates__ = attr - else: - wrapper.__generates__ = attr - return wrapper + return generator.generator_wrapper(func, attr) return wraps - - -def register_listeners(): - global listeners_registered - - if not listeners_registered: - sa.event.listen( - sa.orm.mapper, - 'mapper_configured', - update_generator_registry - ) - sa.event.listen( - sa.orm.session.Session, - 'before_flush', - update_generated_properties - ) - listeners_registered = True - - -def update_generator_registry(mapper, class_): - for value in class_.__dict__.values(): - if hasattr(value, '__generates__'): - generator_registry[class_.__name__].append(value) - - -def update_generated_properties(session, ctx, instances): - for obj in itertools.chain(session.new, session.dirty): - class_ = obj.__class__.__name__ - if class_ in generator_registry: - for func in generator_registry[class_]: - attr = func.__generates__ - if not isinstance(attr, six.string_types): - attr = attr.name - setattr(obj, attr, func(obj)) diff --git a/tests/test_generates.py b/tests/test_generates.py index 9d02427..29aaedf 100644 --- a/tests/test_generates.py +++ b/tests/test_generates.py @@ -1,13 +1,13 @@ -from collections import defaultdict import sqlalchemy as sa -from sqlalchemy_utils import generates, decorators +from sqlalchemy_utils import generates +from sqlalchemy_utils.decorators import generator from tests import TestCase class GeneratesTestCase(TestCase): def teardown_method(self, method): TestCase.teardown_method(self, method) - decorators.generator_registry = defaultdict(list) + generator.reset() def test_generates_value_before_flush(self): article = self.Article()