From 3b3b5b2e6d469e95512c6ea206f0619de34e60dc Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 23 Oct 2013 15:07:38 +0300 Subject: [PATCH] Added generates decorator, fixes #18 --- sqlalchemy_utils/__init__.py | 4 +- sqlalchemy_utils/decorators.py | 58 ++++++++++++++++++++++ tests/test_generates.py | 91 ++++++++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 sqlalchemy_utils/decorators.py create mode 100644 tests/test_generates.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 0f6a6f4..4ced43c 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,3 +1,4 @@ +from .decorators import generates from .exceptions import ImproperlyConfigured from .functions import ( batch_fetch, @@ -49,6 +50,8 @@ __all__ = ( coercion_listener, defer_except, escape_like, + generates, + generic_relationship, instrumented_list, merge, primary_keys, @@ -60,7 +63,6 @@ __all__ = ( sort_query, table_name, with_backrefs, - generic_relationship, ArrowType, ColorType, CountryType, diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py new file mode 100644 index 0000000..378053b --- /dev/null +++ b/sqlalchemy_utils/decorators.py @@ -0,0 +1,58 @@ +from collections import defaultdict +import itertools +import sqlalchemy as sa +import six + +generator_registry = defaultdict(list) +listeners_registered = False + + +def generates(attr): + register_listeners() + + def wraps(func): + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + if isinstance(attr, six.string_types) and '.' in attr: + parts = attr.split('.') + generator_registry[parts[0]].append(wrapper) + wrapper.__generates__ = parts[1] + else: + wrapper.__generates__ = attr + return wrapper + return wraps + + +def register_listeners(): + global listeners_registered + + if not listeners_registered: + sa.event.listen( + sa.orm.mapper, + 'mapper_configured', + update_generator_registry + ) + sa.event.listen( + sa.orm.session.Session, + 'before_flush', + update_generated_properties + ) + listeners_registered = True + + +def update_generator_registry(mapper, class_): + for value in class_.__dict__.values(): + if hasattr(value, '__generates__'): + generator_registry[class_.__name__].append(value) + + +def update_generated_properties(session, ctx, instances): + for obj in itertools.chain(session.new, session.dirty): + class_ = obj.__class__.__name__ + if class_ in generator_registry: + for func in generator_registry[class_]: + attr = func.__generates__ + if not isinstance(attr, six.string_types): + attr = attr.name + setattr(obj, attr, func(obj)) diff --git a/tests/test_generates.py b/tests/test_generates.py new file mode 100644 index 0000000..156a245 --- /dev/null +++ b/tests/test_generates.py @@ -0,0 +1,91 @@ +import sqlalchemy as sa +from sqlalchemy_utils import generates +from tests import TestCase + + +class TestGeneratesWithBoundMethodAndClassVariableArg(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + slug = sa.Column(sa.Unicode(255)) + + @generates(slug) + def _create_slug(self): + return self.name.lower().replace(' ', '-') + + self.Article = Article + + def test_generates_value_before_flush(self): + article = self.Article() + article.name = u'some article name' + self.session.add(article) + self.session.flush() + assert article.slug == u'some-article-name' + + +class TestGeneratesWithBoundMethodAndStringArg(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + slug = sa.Column(sa.Unicode(255)) + + @generates('slug') + def _create_slug(self): + return self.name.lower().replace(' ', '-') + + self.Article = Article + + def test_generates_value_before_flush(self): + article = self.Article() + article.name = u'some article name' + self.session.add(article) + self.session.flush() + assert article.slug == u'some-article-name' + + +class TestGeneratesWithFunctionAndStringArg(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + slug = sa.Column(sa.Unicode(255)) + + @generates('Article.slug') + def _create_article_slug(self): + return self.name.lower().replace(' ', '-') + + self.Article = Article + + def test_generates_value_before_flush(self): + article = self.Article() + article.name = u'some article name' + self.session.add(article) + self.session.flush() + assert article.slug == u'some-article-name' + + +class TestGeneratesWithFunctionAndClassVariableArg(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + slug = sa.Column(sa.Unicode(255)) + + @generates(Article.slug) + def _create_article_slug(self): + return self.name.lower().replace(' ', '-') + + self.Article = Article + + def test_generates_value_before_flush(self): + article = self.Article() + article.name = u'some article name' + self.session.add(article) + self.session.flush() + assert article.slug == u'some-article-name'