diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index e6b2828..135298e 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -67,14 +67,64 @@ class AttributeValueGenerator(object): attr = getdotattr(class_, str(path)) if isinstance(attr.property, sa.orm.ColumnProperty): - for attr in path[0:-1]: - @sa.event.listens_for(attr, 'set') - def receive_attr_set(target, value, oldvalue, initiator): - setattr( - target, - column_key, - getattr(value, path[-1].key) - ) + for attr in path: + self.generate_property_observer( + path, + attr, + column_key + ) + + def generate_property_observer(self, path, attr, property_key): + """ + Generate SQLAlchemy listener that observes given attr within given + path space. + """ + @sa.event.listens_for(attr, 'set') + def receive_attr_set(target, value, oldvalue, initiator): + index = path.index(attr) + if not index: + setattr( + target, + property_key, + getdotattr(value, str(path[1:])) + ) + elif index == len(path) - 1: + inversed_path = ~path[0:-1] + entities = list( + getdotattr( + target, + str(inversed_path) + ) + ) + for entity in entities: + if isinstance(entity, list): + for e in entity: + setattr( + e, + property_key, + value + ) + else: + setattr( + entity, + property_key, + value + ) + else: + inversed_path = ~path[0:-1] + entities = list( + getdotattr( + target, + str(inversed_path[index:]) + ) + ) + for entity in entities: + setattr( + entity, + property_key, + getdotattr(value, str(path[(index + 1):])) + ) + def update_generated_properties(self, session, ctx, instances): for obj in itertools.chain(session.new, session.dirty): diff --git a/sqlalchemy_utils/path.py b/sqlalchemy_utils/path.py index 765fb0e..b56bcd9 100644 --- a/sqlalchemy_utils/path.py +++ b/sqlalchemy_utils/path.py @@ -1,3 +1,4 @@ +import sqlalchemy as sa from sqlalchemy.orm.attributes import InstrumentedAttribute from .utils import str_coercible @@ -86,8 +87,13 @@ class AttrPath(object): ) return backref + if isinstance(self.parts[-1].property, sa.orm.ColumnProperty): + class_ = self.parts[-1].class_ + else: + class_ = self.parts[-1].mapper.class_ + return self.__class__( - self.parts[-1].mapper.class_, + class_, '.'.join(map(get_backref, reversed(self.parts))) ) diff --git a/tests/test_generates.py b/tests/test_generates.py index 24b54f7..a7e7c33 100644 --- a/tests/test_generates.py +++ b/tests/test_generates.py @@ -126,7 +126,7 @@ class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase): -class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase): +class TestInstantAttributeValueGeneration(TestCase): def create_models(self): class Document(self.Base): __tablename__ = 'document' @@ -144,7 +144,7 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase): sa.Integer, sa.ForeignKey(Document.id) ) - document = sa.orm.relationship(Document) + document = sa.orm.relationship(Document, backref='sections') class SubSection(self.Base): __tablename__ = 'subsection' @@ -156,7 +156,7 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase): sa.Integer, sa.ForeignKey(Section.id) ) - section = sa.orm.relationship(Section) + section = sa.orm.relationship(Section, backref='subsections') @generates(locale, source='section.document.locale') def copy_locale(self, value): @@ -169,6 +169,10 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase): def test_change_parent_attribute(self): document = self.Document(name=u'Document 1', locale='en') section = self.Section(name=u'Section 1', document=document) - assert section.locale == 'en' subsection = self.SubSection(name=u'Section 1', section=section) assert subsection.locale == 'en' + document.locale = 'fi' + assert subsection.locale == 'fi' + section.document = self.Document(name=u'Document 2', locale='sv') + assert subsection.locale == 'sv' +