Add column observers for generates

This commit is contained in:
Konsta Vesterinen
2014-02-18 17:30:19 +02:00
parent 989faeb9a6
commit 81f65eb693
3 changed files with 73 additions and 13 deletions

View File

@@ -67,14 +67,64 @@ class AttributeValueGenerator(object):
attr = getdotattr(class_, str(path)) attr = getdotattr(class_, str(path))
if isinstance(attr.property, sa.orm.ColumnProperty): if isinstance(attr.property, sa.orm.ColumnProperty):
for attr in path[0:-1]: for attr in path:
self.generate_property_observer(
path,
attr,
column_key
)
def generate_property_observer(self, path, attr, property_key):
"""
Generate SQLAlchemy listener that observes given attr within given
path space.
"""
@sa.event.listens_for(attr, 'set') @sa.event.listens_for(attr, 'set')
def receive_attr_set(target, value, oldvalue, initiator): def receive_attr_set(target, value, oldvalue, initiator):
index = path.index(attr)
if not index:
setattr( setattr(
target, target,
column_key, property_key,
getattr(value, path[-1].key) getdotattr(value, str(path[1:]))
) )
elif index == len(path) - 1:
inversed_path = ~path[0:-1]
entities = list(
getdotattr(
target,
str(inversed_path)
)
)
for entity in entities:
if isinstance(entity, list):
for e in entity:
setattr(
e,
property_key,
value
)
else:
setattr(
entity,
property_key,
value
)
else:
inversed_path = ~path[0:-1]
entities = list(
getdotattr(
target,
str(inversed_path[index:])
)
)
for entity in entities:
setattr(
entity,
property_key,
getdotattr(value, str(path[(index + 1):]))
)
def update_generated_properties(self, session, ctx, instances): def update_generated_properties(self, session, ctx, instances):
for obj in itertools.chain(session.new, session.dirty): for obj in itertools.chain(session.new, session.dirty):

View File

@@ -1,3 +1,4 @@
import sqlalchemy as sa
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from .utils import str_coercible from .utils import str_coercible
@@ -86,8 +87,13 @@ class AttrPath(object):
) )
return backref return backref
if isinstance(self.parts[-1].property, sa.orm.ColumnProperty):
class_ = self.parts[-1].class_
else:
class_ = self.parts[-1].mapper.class_
return self.__class__( return self.__class__(
self.parts[-1].mapper.class_, class_,
'.'.join(map(get_backref, reversed(self.parts))) '.'.join(map(get_backref, reversed(self.parts)))
) )

View File

@@ -126,7 +126,7 @@ class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase):
class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase): class TestInstantAttributeValueGeneration(TestCase):
def create_models(self): def create_models(self):
class Document(self.Base): class Document(self.Base):
__tablename__ = 'document' __tablename__ = 'document'
@@ -144,7 +144,7 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
sa.Integer, sa.ForeignKey(Document.id) sa.Integer, sa.ForeignKey(Document.id)
) )
document = sa.orm.relationship(Document) document = sa.orm.relationship(Document, backref='sections')
class SubSection(self.Base): class SubSection(self.Base):
__tablename__ = 'subsection' __tablename__ = 'subsection'
@@ -156,7 +156,7 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
sa.Integer, sa.ForeignKey(Section.id) sa.Integer, sa.ForeignKey(Section.id)
) )
section = sa.orm.relationship(Section) section = sa.orm.relationship(Section, backref='subsections')
@generates(locale, source='section.document.locale') @generates(locale, source='section.document.locale')
def copy_locale(self, value): def copy_locale(self, value):
@@ -169,6 +169,10 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
def test_change_parent_attribute(self): def test_change_parent_attribute(self):
document = self.Document(name=u'Document 1', locale='en') document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document) section = self.Section(name=u'Section 1', document=document)
assert section.locale == 'en'
subsection = self.SubSection(name=u'Section 1', section=section) subsection = self.SubSection(name=u'Section 1', section=section)
assert subsection.locale == 'en' assert subsection.locale == 'en'
document.locale = 'fi'
assert subsection.locale == 'fi'
section.document = self.Document(name=u'Document 2', locale='sv')
assert subsection.locale == 'sv'