Added generates decorator, fixes #18

This commit is contained in:
Konsta Vesterinen
2013-10-23 15:07:38 +03:00
parent 70d0fe0383
commit 3b3b5b2e6d
3 changed files with 152 additions and 1 deletions

View File

@@ -1,3 +1,4 @@
from .decorators import generates
from .exceptions import ImproperlyConfigured
from .functions import (
batch_fetch,
@@ -49,6 +50,8 @@ __all__ = (
coercion_listener,
defer_except,
escape_like,
generates,
generic_relationship,
instrumented_list,
merge,
primary_keys,
@@ -60,7 +63,6 @@ __all__ = (
sort_query,
table_name,
with_backrefs,
generic_relationship,
ArrowType,
ColorType,
CountryType,

View File

@@ -0,0 +1,58 @@
from collections import defaultdict
import itertools
import sqlalchemy as sa
import six
generator_registry = defaultdict(list)
listeners_registered = False
def generates(attr):
register_listeners()
def wraps(func):
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
if isinstance(attr, six.string_types) and '.' in attr:
parts = attr.split('.')
generator_registry[parts[0]].append(wrapper)
wrapper.__generates__ = parts[1]
else:
wrapper.__generates__ = attr
return wrapper
return wraps
def register_listeners():
global listeners_registered
if not listeners_registered:
sa.event.listen(
sa.orm.mapper,
'mapper_configured',
update_generator_registry
)
sa.event.listen(
sa.orm.session.Session,
'before_flush',
update_generated_properties
)
listeners_registered = True
def update_generator_registry(mapper, class_):
for value in class_.__dict__.values():
if hasattr(value, '__generates__'):
generator_registry[class_.__name__].append(value)
def update_generated_properties(session, ctx, instances):
for obj in itertools.chain(session.new, session.dirty):
class_ = obj.__class__.__name__
if class_ in generator_registry:
for func in generator_registry[class_]:
attr = func.__generates__
if not isinstance(attr, six.string_types):
attr = attr.name
setattr(obj, attr, func(obj))

91
tests/test_generates.py Normal file
View File

@@ -0,0 +1,91 @@
import sqlalchemy as sa
from sqlalchemy_utils import generates
from tests import TestCase
class TestGeneratesWithBoundMethodAndClassVariableArg(TestCase):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
slug = sa.Column(sa.Unicode(255))
@generates(slug)
def _create_slug(self):
return self.name.lower().replace(' ', '-')
self.Article = Article
def test_generates_value_before_flush(self):
article = self.Article()
article.name = u'some article name'
self.session.add(article)
self.session.flush()
assert article.slug == u'some-article-name'
class TestGeneratesWithBoundMethodAndStringArg(TestCase):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
slug = sa.Column(sa.Unicode(255))
@generates('slug')
def _create_slug(self):
return self.name.lower().replace(' ', '-')
self.Article = Article
def test_generates_value_before_flush(self):
article = self.Article()
article.name = u'some article name'
self.session.add(article)
self.session.flush()
assert article.slug == u'some-article-name'
class TestGeneratesWithFunctionAndStringArg(TestCase):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
slug = sa.Column(sa.Unicode(255))
@generates('Article.slug')
def _create_article_slug(self):
return self.name.lower().replace(' ', '-')
self.Article = Article
def test_generates_value_before_flush(self):
article = self.Article()
article.name = u'some article name'
self.session.add(article)
self.session.flush()
assert article.slug == u'some-article-name'
class TestGeneratesWithFunctionAndClassVariableArg(TestCase):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
slug = sa.Column(sa.Unicode(255))
@generates(Article.slug)
def _create_article_slug(self):
return self.name.lower().replace(' ', '-')
self.Article = Article
def test_generates_value_before_flush(self):
article = self.Article()
article.name = u'some article name'
self.session.add(article)
self.session.flush()
assert article.slug == u'some-article-name'