diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index 135298e..5f6512e 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -90,40 +90,42 @@ class AttributeValueGenerator(object): ) elif index == len(path) - 1: inversed_path = ~path[0:-1] - entities = list( - getdotattr( - target, - str(inversed_path) - ) + entities = getdotattr( + target, + str(inversed_path) ) - for entity in entities: - if isinstance(entity, list): - for e in entity: + if entities: + if not isinstance(entities, list): + entities = [entities] + for entity in entities: + if isinstance(entity, list): + for e in entity: + setattr( + e, + property_key, + value + ) + else: setattr( - e, + entity, property_key, value ) - else: + else: + inversed_path = ~path[0:-1] + entities = getdotattr( + target, + str(inversed_path[index:]) + ) + if entities: + if not isinstance(entities, list): + entities = [entities] + for entity in entities: setattr( entity, property_key, - value + getdotattr(value, str(path[(index + 1):])) ) - else: - inversed_path = ~path[0:-1] - entities = list( - getdotattr( - target, - str(inversed_path[index:]) - ) - ) - for entity in entities: - setattr( - entity, - property_key, - getdotattr(value, str(path[(index + 1):])) - ) def update_generated_properties(self, session, ctx, instances): diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 137f63c..d43b52a 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -251,6 +251,8 @@ def getdotattr(obj_or_class, dot_path): last = tmp elif isinstance(last, InstrumentedAttribute): last = getter(last.property.mapper.class_) + elif last is None: + return None else: last = getter(last) return last diff --git a/sqlalchemy_utils/path.py b/sqlalchemy_utils/path.py index b56bcd9..a963878 100644 --- a/sqlalchemy_utils/path.py +++ b/sqlalchemy_utils/path.py @@ -85,7 +85,10 @@ class AttrPath(object): prop.parent.class_.__name__ ) ) - return backref + if isinstance(backref, tuple): + return backref[0] + else: + return backref if isinstance(self.parts[-1].property, sa.orm.ColumnProperty): class_ = self.parts[-1].class_ diff --git a/tests/test_generates.py b/tests/test_generates.py index a7e7c33..2dd162b 100644 --- a/tests/test_generates.py +++ b/tests/test_generates.py @@ -170,9 +170,86 @@ class TestInstantAttributeValueGeneration(TestCase): document = self.Document(name=u'Document 1', locale='en') section = self.Section(name=u'Section 1', document=document) subsection = self.SubSection(name=u'Section 1', section=section) - assert subsection.locale == 'en' document.locale = 'fi' assert subsection.locale == 'fi' + + def test_simple_assignment(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + subsection = self.SubSection(name=u'Section 1', section=section) + assert subsection.locale == 'en' + + def test_intermediate_object_reference(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + subsection = self.SubSection(name=u'Section 1', section=section) section.document = self.Document(name=u'Document 2', locale='sv') assert subsection.locale == 'sv' + +class TestInstantAttrGenerationWithScalars(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + class Section(self.Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) + + document = sa.orm.relationship( + Document, + backref=sa.orm.backref('section', uselist=False) + ) + + class SubSection(self.Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship( + Section, + backref=sa.orm.backref('subsection', uselist=False) + ) + + @generates(locale, source='section.document.locale') + def copy_locale(self, value): + return value + + + self.Document = Document + self.Section = Section + self.SubSection = SubSection + + def test_change_parent_attribute(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + subsection = self.SubSection(name=u'Section 1', section=section) + document.locale = 'fi' + assert subsection.locale == 'fi' + + def test_simple_assignment(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + subsection = self.SubSection(name=u'Section 1', section=section) + assert subsection.locale == 'en' + + def test_intermediate_object_reference(self): + document = self.Document(name=u'Document 1', locale='en') + section = self.Section(name=u'Section 1', document=document) + subsection = self.SubSection(name=u'Section 1', section=section) + section.document = self.Document(name=u'Document 2', locale='sv') + assert subsection.locale == 'sv'