From b95a7d855ea4a46a42e1308f5e34eaeca0976150 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 30 Jan 2014 15:02:34 +0200 Subject: [PATCH] Add source param to generates --- sqlalchemy_utils/decorators.py | 27 +++++++++++----- tests/test_generates.py | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index 293d12e..9c44a03 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -4,6 +4,13 @@ import sqlalchemy as sa import six +def getdotattr(obj, dot_path): + """ + Allows dot-notated strings to be passed to `getattr` + """ + return reduce(getattr, dot_path.split('.'), obj) + + class AttributeValueGenerator(object): def __init__(self): self.listener_args = [ @@ -31,15 +38,15 @@ class AttributeValueGenerator(object): self.listeners_registered = False self.generator_registry = defaultdict(list) - def generator_wrapper(self, func, attr): + def generator_wrapper(self, func, attr, source): def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) if isinstance(attr, sa.orm.attributes.InstrumentedAttribute): self.generator_registry[attr.class_].append(wrapper) - wrapper.__generates__ = attr + wrapper.__generates__ = attr, source else: - wrapper.__generates__ = attr + wrapper.__generates__ = attr, source return wrapper def register_listeners(self): @@ -61,17 +68,23 @@ class AttributeValueGenerator(object): class_ = obj.__class__ if class_ in self.generator_registry: for func in self.generator_registry[class_]: - attr = func.__generates__ + attr, source = func.__generates__ if not isinstance(attr, six.string_types): attr = attr.name - setattr(obj, attr, func(obj)) + + if source is None: + setattr(obj, attr, func(obj)) + else: + setattr(obj, attr, func(obj, getdotattr(obj, source))) generator = AttributeValueGenerator() -def generates(attr, generator=generator): +def generates(attr, source=None, generator=generator): """ + Decorator that marks given function as attribute value generator. + Many times you may have generated property values. Usual cases include slugs from names or resized thumbnails from images. @@ -131,5 +144,5 @@ def generates(attr, generator=generator): generator.register_listeners() def wraps(func): - return generator.generator_wrapper(func, attr) + return generator.generator_wrapper(func, attr, source) return wraps diff --git a/tests/test_generates.py b/tests/test_generates.py index 6b4c9e1..0363aa9 100644 --- a/tests/test_generates.py +++ b/tests/test_generates.py @@ -55,3 +55,60 @@ class TestGeneratesWithFunctionAndClassVariableArg(GeneratesTestCase): return self.name.lower().replace(' ', '-') self.Article = Article + + + +class TestGeneratesWithSourcePath(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + class Section(self.Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) + + document = sa.orm.relationship(Document) + + @generates(locale, source='document') + def copy_locale(self, document): + return document.locale + + class SubSection(self.Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship(Section) + + @generates(locale, source='section.document') + def copy_locale(self, document): + return document.locale + + + self.Document = Document + self.Section = Section + self.SubSection = SubSection + + def test_simple_source_paths(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + + self.session.add(document) + self.session.add(section) + self.session.commit() + + assert section.locale == 'en'