From 6b64bf24ed6932b55c53e97e9db0745dc71c09a3 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 18 Feb 2014 14:11:54 +0200 Subject: [PATCH] Add attribute listeners --- sqlalchemy_utils/decorators.py | 33 +++++++++----- tests/test_generates.py | 80 +++++++++++++++++++++++++++------- 2 files changed, 88 insertions(+), 25 deletions(-) diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index 469ddbe..e6b2828 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -2,13 +2,8 @@ from collections import defaultdict import itertools import sqlalchemy as sa import six - - -def getdotattr(obj, dot_path): - """ - Allows dot-notated strings to be passed to `getattr` - """ - return six.moves.reduce(getattr, dot_path.split('.'), obj) +from .functions import getdotattr +from .path import AttrPath class AttributeValueGenerator(object): @@ -59,9 +54,27 @@ class AttributeValueGenerator(object): """ Adds generator functions to generator_registry. """ - for value in class_.__dict__.values(): - if hasattr(value, '__generates__'): - self.generator_registry[class_].append(value) + for generator in class_.__dict__.values(): + if hasattr(generator, '__generates__'): + self.generator_registry[class_].append(generator) + + path = generator.__generates__[1] + column_key = generator.__generates__[0] + if not isinstance(column_key, six.string_types): + column_key = column_key.key + if path: + path = AttrPath(class_, path) + attr = getdotattr(class_, str(path)) + + if isinstance(attr.property, sa.orm.ColumnProperty): + for attr in path[0:-1]: + @sa.event.listens_for(attr, 'set') + def receive_attr_set(target, value, oldvalue, initiator): + setattr( + target, + column_key, + getattr(value, path[-1].key) + ) def update_generated_properties(self, session, ctx, instances): for obj in itertools.chain(session.new, session.dirty): diff --git a/tests/test_generates.py b/tests/test_generates.py index 29a0cb4..24b54f7 100644 --- a/tests/test_generates.py +++ b/tests/test_generates.py @@ -57,8 +57,29 @@ class TestGeneratesWithFunctionAndClassVariableArg(GeneratesTestCase): self.Article = Article +class DeepPathGeneratesTestCase(TestCase): + def test_simple_dotted_source_path(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) -class TestGeneratesWithSourcePath(TestCase): + self.session.add(document) + self.session.add(section) + self.session.commit() + + assert section.locale == 'en' + + def test_deep_dotted_source_path(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + subsection = self.SubSection(name=u'Section 1', section=section) + + self.session.add(subsection) + self.session.commit() + + assert subsection.locale == 'en' + + +class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase): def create_models(self): class Document(self.Base): __tablename__ = 'document' @@ -103,22 +124,51 @@ class TestGeneratesWithSourcePath(TestCase): self.Section = Section self.SubSection = SubSection - def test_simple_dotted_source_path(self): + + +class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + class Section(self.Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) + + document = sa.orm.relationship(Document) + + class SubSection(self.Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship(Section) + + @generates(locale, source='section.document.locale') + def copy_locale(self, value): + return value + + self.Document = Document + self.Section = Section + self.SubSection = SubSection + + def test_change_parent_attribute(self): document = self.Document(name=u'Document 1', locale='en') section = self.Section(name=u'Section 1', document=document) - - self.session.add(document) - self.session.add(section) - self.session.commit() - assert section.locale == 'en' - - def test_deep_dotted_source_path(self): - document = self.Document(name=u'Document 1', locale='en') - section = self.Section(name=u'Section 1', document=document) subsection = self.SubSection(name=u'Section 1', section=section) - - self.session.add(subsection) - self.session.commit() - assert subsection.locale == 'en'