Add column observers for generates
This commit is contained in:
@@ -67,14 +67,64 @@ class AttributeValueGenerator(object):
|
|||||||
attr = getdotattr(class_, str(path))
|
attr = getdotattr(class_, str(path))
|
||||||
|
|
||||||
if isinstance(attr.property, sa.orm.ColumnProperty):
|
if isinstance(attr.property, sa.orm.ColumnProperty):
|
||||||
for attr in path[0:-1]:
|
for attr in path:
|
||||||
|
self.generate_property_observer(
|
||||||
|
path,
|
||||||
|
attr,
|
||||||
|
column_key
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_property_observer(self, path, attr, property_key):
|
||||||
|
"""
|
||||||
|
Generate SQLAlchemy listener that observes given attr within given
|
||||||
|
path space.
|
||||||
|
"""
|
||||||
@sa.event.listens_for(attr, 'set')
|
@sa.event.listens_for(attr, 'set')
|
||||||
def receive_attr_set(target, value, oldvalue, initiator):
|
def receive_attr_set(target, value, oldvalue, initiator):
|
||||||
|
index = path.index(attr)
|
||||||
|
if not index:
|
||||||
setattr(
|
setattr(
|
||||||
target,
|
target,
|
||||||
column_key,
|
property_key,
|
||||||
getattr(value, path[-1].key)
|
getdotattr(value, str(path[1:]))
|
||||||
)
|
)
|
||||||
|
elif index == len(path) - 1:
|
||||||
|
inversed_path = ~path[0:-1]
|
||||||
|
entities = list(
|
||||||
|
getdotattr(
|
||||||
|
target,
|
||||||
|
str(inversed_path)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for entity in entities:
|
||||||
|
if isinstance(entity, list):
|
||||||
|
for e in entity:
|
||||||
|
setattr(
|
||||||
|
e,
|
||||||
|
property_key,
|
||||||
|
value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
setattr(
|
||||||
|
entity,
|
||||||
|
property_key,
|
||||||
|
value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inversed_path = ~path[0:-1]
|
||||||
|
entities = list(
|
||||||
|
getdotattr(
|
||||||
|
target,
|
||||||
|
str(inversed_path[index:])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for entity in entities:
|
||||||
|
setattr(
|
||||||
|
entity,
|
||||||
|
property_key,
|
||||||
|
getdotattr(value, str(path[(index + 1):]))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_generated_properties(self, session, ctx, instances):
|
def update_generated_properties(self, session, ctx, instances):
|
||||||
for obj in itertools.chain(session.new, session.dirty):
|
for obj in itertools.chain(session.new, session.dirty):
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
from .utils import str_coercible
|
from .utils import str_coercible
|
||||||
|
|
||||||
@@ -86,8 +87,13 @@ class AttrPath(object):
|
|||||||
)
|
)
|
||||||
return backref
|
return backref
|
||||||
|
|
||||||
|
if isinstance(self.parts[-1].property, sa.orm.ColumnProperty):
|
||||||
|
class_ = self.parts[-1].class_
|
||||||
|
else:
|
||||||
|
class_ = self.parts[-1].mapper.class_
|
||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
self.parts[-1].mapper.class_,
|
class_,
|
||||||
'.'.join(map(get_backref, reversed(self.parts)))
|
'.'.join(map(get_backref, reversed(self.parts)))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -126,7 +126,7 @@ class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
|
class TestInstantAttributeValueGeneration(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Document(self.Base):
|
class Document(self.Base):
|
||||||
__tablename__ = 'document'
|
__tablename__ = 'document'
|
||||||
@@ -144,7 +144,7 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
|
|||||||
sa.Integer, sa.ForeignKey(Document.id)
|
sa.Integer, sa.ForeignKey(Document.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
document = sa.orm.relationship(Document)
|
document = sa.orm.relationship(Document, backref='sections')
|
||||||
|
|
||||||
class SubSection(self.Base):
|
class SubSection(self.Base):
|
||||||
__tablename__ = 'subsection'
|
__tablename__ = 'subsection'
|
||||||
@@ -156,7 +156,7 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
|
|||||||
sa.Integer, sa.ForeignKey(Section.id)
|
sa.Integer, sa.ForeignKey(Section.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
section = sa.orm.relationship(Section)
|
section = sa.orm.relationship(Section, backref='subsections')
|
||||||
|
|
||||||
@generates(locale, source='section.document.locale')
|
@generates(locale, source='section.document.locale')
|
||||||
def copy_locale(self, value):
|
def copy_locale(self, value):
|
||||||
@@ -169,6 +169,10 @@ class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
|
|||||||
def test_change_parent_attribute(self):
|
def test_change_parent_attribute(self):
|
||||||
document = self.Document(name=u'Document 1', locale='en')
|
document = self.Document(name=u'Document 1', locale='en')
|
||||||
section = self.Section(name=u'Section 1', document=document)
|
section = self.Section(name=u'Section 1', document=document)
|
||||||
assert section.locale == 'en'
|
|
||||||
subsection = self.SubSection(name=u'Section 1', section=section)
|
subsection = self.SubSection(name=u'Section 1', section=section)
|
||||||
assert subsection.locale == 'en'
|
assert subsection.locale == 'en'
|
||||||
|
document.locale = 'fi'
|
||||||
|
assert subsection.locale == 'fi'
|
||||||
|
section.document = self.Document(name=u'Document 2', locale='sv')
|
||||||
|
assert subsection.locale == 'sv'
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user