Add attribute listeners
This commit is contained in:
@@ -2,13 +2,8 @@ from collections import defaultdict
|
||||
import itertools
|
||||
import sqlalchemy as sa
|
||||
import six
|
||||
|
||||
|
||||
def getdotattr(obj, dot_path):
|
||||
"""
|
||||
Allows dot-notated strings to be passed to `getattr`
|
||||
"""
|
||||
return six.moves.reduce(getattr, dot_path.split('.'), obj)
|
||||
from .functions import getdotattr
|
||||
from .path import AttrPath
|
||||
|
||||
|
||||
class AttributeValueGenerator(object):
|
||||
@@ -59,9 +54,27 @@ class AttributeValueGenerator(object):
|
||||
"""
|
||||
Adds generator functions to generator_registry.
|
||||
"""
|
||||
for value in class_.__dict__.values():
|
||||
if hasattr(value, '__generates__'):
|
||||
self.generator_registry[class_].append(value)
|
||||
for generator in class_.__dict__.values():
|
||||
if hasattr(generator, '__generates__'):
|
||||
self.generator_registry[class_].append(generator)
|
||||
|
||||
path = generator.__generates__[1]
|
||||
column_key = generator.__generates__[0]
|
||||
if not isinstance(column_key, six.string_types):
|
||||
column_key = column_key.key
|
||||
if path:
|
||||
path = AttrPath(class_, path)
|
||||
attr = getdotattr(class_, str(path))
|
||||
|
||||
if isinstance(attr.property, sa.orm.ColumnProperty):
|
||||
for attr in path[0:-1]:
|
||||
@sa.event.listens_for(attr, 'set')
|
||||
def receive_attr_set(target, value, oldvalue, initiator):
|
||||
setattr(
|
||||
target,
|
||||
column_key,
|
||||
getattr(value, path[-1].key)
|
||||
)
|
||||
|
||||
def update_generated_properties(self, session, ctx, instances):
|
||||
for obj in itertools.chain(session.new, session.dirty):
|
||||
|
@@ -57,8 +57,29 @@ class TestGeneratesWithFunctionAndClassVariableArg(GeneratesTestCase):
|
||||
self.Article = Article
|
||||
|
||||
|
||||
class DeepPathGeneratesTestCase(TestCase):
|
||||
def test_simple_dotted_source_path(self):
|
||||
document = self.Document(name=u'Document 1', locale='en')
|
||||
section = self.Section(name=u'Section 1', document=document)
|
||||
|
||||
class TestGeneratesWithSourcePath(TestCase):
|
||||
self.session.add(document)
|
||||
self.session.add(section)
|
||||
self.session.commit()
|
||||
|
||||
assert section.locale == 'en'
|
||||
|
||||
def test_deep_dotted_source_path(self):
|
||||
document = self.Document(name=u'Document 1', locale='en')
|
||||
section = self.Section(name=u'Section 1', document=document)
|
||||
subsection = self.SubSection(name=u'Section 1', section=section)
|
||||
|
||||
self.session.add(subsection)
|
||||
self.session.commit()
|
||||
|
||||
assert subsection.locale == 'en'
|
||||
|
||||
|
||||
class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase):
|
||||
def create_models(self):
|
||||
class Document(self.Base):
|
||||
__tablename__ = 'document'
|
||||
@@ -103,22 +124,51 @@ class TestGeneratesWithSourcePath(TestCase):
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
|
||||
def test_simple_dotted_source_path(self):
|
||||
|
||||
|
||||
class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
|
||||
def create_models(self):
|
||||
class Document(self.Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
class Section(self.Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
|
||||
document = sa.orm.relationship(Document)
|
||||
|
||||
class SubSection(self.Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
section = sa.orm.relationship(Section)
|
||||
|
||||
@generates(locale, source='section.document.locale')
|
||||
def copy_locale(self, value):
|
||||
return value
|
||||
|
||||
self.Document = Document
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
|
||||
def test_change_parent_attribute(self):
|
||||
document = self.Document(name=u'Document 1', locale='en')
|
||||
section = self.Section(name=u'Section 1', document=document)
|
||||
|
||||
self.session.add(document)
|
||||
self.session.add(section)
|
||||
self.session.commit()
|
||||
|
||||
assert section.locale == 'en'
|
||||
|
||||
def test_deep_dotted_source_path(self):
|
||||
document = self.Document(name=u'Document 1', locale='en')
|
||||
section = self.Section(name=u'Section 1', document=document)
|
||||
subsection = self.SubSection(name=u'Section 1', section=section)
|
||||
|
||||
self.session.add(subsection)
|
||||
self.session.commit()
|
||||
|
||||
assert subsection.locale == 'en'
|
||||
|
Reference in New Issue
Block a user