Refactored generates decorator

This commit is contained in:
Konsta Vesterinen
2013-11-01 15:38:28 +02:00
parent 8c7dab337e
commit 1899cccfb6
2 changed files with 61 additions and 54 deletions

View File

@@ -3,11 +3,64 @@ import itertools
import sqlalchemy as sa import sqlalchemy as sa
import six import six
generator_registry = defaultdict(list)
listeners_registered = False class AttributeValueGenerator(object):
def __init__(self):
self.reset()
def reset(self):
self.generator_registry = defaultdict(list)
self.listeners_registered = False
def generator_wrapper(self, func, attr):
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
if isinstance(attr, six.string_types) and '.' in attr:
parts = attr.split('.')
self.generator_registry[parts[0]].append(wrapper)
wrapper.__generates__ = parts[1]
elif isinstance(attr, sa.orm.attributes.InstrumentedAttribute):
self.generator_registry[attr.class_.__name__].append(wrapper)
wrapper.__generates__ = attr
else:
wrapper.__generates__ = attr
return wrapper
def register_listeners(self):
if not self.listeners_registered:
sa.event.listen(
sa.orm.mapper,
'mapper_configured',
self.update_generator_registry
)
sa.event.listen(
sa.orm.session.Session,
'before_flush',
self.update_generated_properties
)
self.listeners_registered = True
def update_generator_registry(self, mapper, class_):
for value in class_.__dict__.values():
if hasattr(value, '__generates__'):
self.generator_registry[class_.__name__].append(value)
def update_generated_properties(self, session, ctx, instances):
for obj in itertools.chain(session.new, session.dirty):
class_ = obj.__class__.__name__
if class_ in self.generator_registry:
for func in self.generator_registry[class_]:
attr = func.__generates__
if not isinstance(attr, six.string_types):
attr = attr.name
setattr(obj, attr, func(obj))
def generates(attr): generator = AttributeValueGenerator()
def generates(attr, generator=generator):
""" """
Many times you may have generated property values. Usual cases include Many times you may have generated property values. Usual cases include
slugs from names or resized thumbnails from images. slugs from names or resized thumbnails from images.
@@ -75,54 +128,8 @@ def generates(attr):
return article.name.lower().replace(' ', '-') return article.name.lower().replace(' ', '-')
""" """
register_listeners() generator.register_listeners()
def wraps(func): def wraps(func):
def wrapper(self, *args, **kwargs): return generator.generator_wrapper(func, attr)
return func(self, *args, **kwargs)
if isinstance(attr, six.string_types) and '.' in attr:
parts = attr.split('.')
generator_registry[parts[0]].append(wrapper)
wrapper.__generates__ = parts[1]
elif isinstance(attr, sa.orm.attributes.InstrumentedAttribute):
generator_registry[attr.class_.__name__].append(wrapper)
wrapper.__generates__ = attr
else:
wrapper.__generates__ = attr
return wrapper
return wraps return wraps
def register_listeners():
global listeners_registered
if not listeners_registered:
sa.event.listen(
sa.orm.mapper,
'mapper_configured',
update_generator_registry
)
sa.event.listen(
sa.orm.session.Session,
'before_flush',
update_generated_properties
)
listeners_registered = True
def update_generator_registry(mapper, class_):
for value in class_.__dict__.values():
if hasattr(value, '__generates__'):
generator_registry[class_.__name__].append(value)
def update_generated_properties(session, ctx, instances):
for obj in itertools.chain(session.new, session.dirty):
class_ = obj.__class__.__name__
if class_ in generator_registry:
for func in generator_registry[class_]:
attr = func.__generates__
if not isinstance(attr, six.string_types):
attr = attr.name
setattr(obj, attr, func(obj))

View File

@@ -1,13 +1,13 @@
from collections import defaultdict
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import generates, decorators from sqlalchemy_utils import generates
from sqlalchemy_utils.decorators import generator
from tests import TestCase from tests import TestCase
class GeneratesTestCase(TestCase): class GeneratesTestCase(TestCase):
def teardown_method(self, method): def teardown_method(self, method):
TestCase.teardown_method(self, method) TestCase.teardown_method(self, method)
decorators.generator_registry = defaultdict(list) generator.reset()
def test_generates_value_before_flush(self): def test_generates_value_before_flush(self):
article = self.Article() article = self.Article()