Add attribute listeners

This commit is contained in:
Konsta Vesterinen
2014-02-18 14:11:54 +02:00
parent 41483c8d38
commit 6b64bf24ed
2 changed files with 88 additions and 25 deletions

View File

@@ -2,13 +2,8 @@ from collections import defaultdict
import itertools
import sqlalchemy as sa
import six
def getdotattr(obj, dot_path):
"""
Allows dot-notated strings to be passed to `getattr`
"""
return six.moves.reduce(getattr, dot_path.split('.'), obj)
from .functions import getdotattr
from .path import AttrPath
class AttributeValueGenerator(object):
@@ -59,9 +54,27 @@ class AttributeValueGenerator(object):
"""
Adds generator functions to generator_registry.
"""
for value in class_.__dict__.values():
if hasattr(value, '__generates__'):
self.generator_registry[class_].append(value)
for generator in class_.__dict__.values():
if hasattr(generator, '__generates__'):
self.generator_registry[class_].append(generator)
path = generator.__generates__[1]
column_key = generator.__generates__[0]
if not isinstance(column_key, six.string_types):
column_key = column_key.key
if path:
path = AttrPath(class_, path)
attr = getdotattr(class_, str(path))
if isinstance(attr.property, sa.orm.ColumnProperty):
for attr in path[0:-1]:
@sa.event.listens_for(attr, 'set')
def receive_attr_set(target, value, oldvalue, initiator):
setattr(
target,
column_key,
getattr(value, path[-1].key)
)
def update_generated_properties(self, session, ctx, instances):
for obj in itertools.chain(session.new, session.dirty):

View File

@@ -57,8 +57,29 @@ class TestGeneratesWithFunctionAndClassVariableArg(GeneratesTestCase):
self.Article = Article
class DeepPathGeneratesTestCase(TestCase):
def test_simple_dotted_source_path(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
class TestGeneratesWithSourcePath(TestCase):
self.session.add(document)
self.session.add(section)
self.session.commit()
assert section.locale == 'en'
def test_deep_dotted_source_path(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
subsection = self.SubSection(name=u'Section 1', section=section)
self.session.add(subsection)
self.session.commit()
assert subsection.locale == 'en'
class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
@@ -103,22 +124,51 @@ class TestGeneratesWithSourcePath(TestCase):
self.Section = Section
self.SubSection = SubSection
def test_simple_dotted_source_path(self):
class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
document = sa.orm.relationship(Document)
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section)
@generates(locale, source='section.document.locale')
def copy_locale(self, value):
return value
self.Document = Document
self.Section = Section
self.SubSection = SubSection
def test_change_parent_attribute(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
self.session.add(document)
self.session.add(section)
self.session.commit()
assert section.locale == 'en'
def test_deep_dotted_source_path(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
subsection = self.SubSection(name=u'Section 1', section=section)
self.session.add(subsection)
self.session.commit()
assert subsection.locale == 'en'