Add attribute listeners

This commit is contained in:
Konsta Vesterinen
2014-02-18 14:11:54 +02:00
parent 41483c8d38
commit 6b64bf24ed
2 changed files with 88 additions and 25 deletions

View File

@@ -2,13 +2,8 @@ from collections import defaultdict
import itertools import itertools
import sqlalchemy as sa import sqlalchemy as sa
import six import six
from .functions import getdotattr
from .path import AttrPath
def getdotattr(obj, dot_path):
"""
Allows dot-notated strings to be passed to `getattr`
"""
return six.moves.reduce(getattr, dot_path.split('.'), obj)
class AttributeValueGenerator(object): class AttributeValueGenerator(object):
@@ -59,9 +54,27 @@ class AttributeValueGenerator(object):
""" """
Adds generator functions to generator_registry. Adds generator functions to generator_registry.
""" """
for value in class_.__dict__.values(): for generator in class_.__dict__.values():
if hasattr(value, '__generates__'): if hasattr(generator, '__generates__'):
self.generator_registry[class_].append(value) self.generator_registry[class_].append(generator)
path = generator.__generates__[1]
column_key = generator.__generates__[0]
if not isinstance(column_key, six.string_types):
column_key = column_key.key
if path:
path = AttrPath(class_, path)
attr = getdotattr(class_, str(path))
if isinstance(attr.property, sa.orm.ColumnProperty):
for attr in path[0:-1]:
@sa.event.listens_for(attr, 'set')
def receive_attr_set(target, value, oldvalue, initiator):
setattr(
target,
column_key,
getattr(value, path[-1].key)
)
def update_generated_properties(self, session, ctx, instances): def update_generated_properties(self, session, ctx, instances):
for obj in itertools.chain(session.new, session.dirty): for obj in itertools.chain(session.new, session.dirty):

View File

@@ -57,8 +57,29 @@ class TestGeneratesWithFunctionAndClassVariableArg(GeneratesTestCase):
self.Article = Article self.Article = Article
class DeepPathGeneratesTestCase(TestCase):
def test_simple_dotted_source_path(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
class TestGeneratesWithSourcePath(TestCase): self.session.add(document)
self.session.add(section)
self.session.commit()
assert section.locale == 'en'
def test_deep_dotted_source_path(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
subsection = self.SubSection(name=u'Section 1', section=section)
self.session.add(subsection)
self.session.commit()
assert subsection.locale == 'en'
class TestGeneratesWithSourcePath(DeepPathGeneratesTestCase):
def create_models(self): def create_models(self):
class Document(self.Base): class Document(self.Base):
__tablename__ = 'document' __tablename__ = 'document'
@@ -103,22 +124,51 @@ class TestGeneratesWithSourcePath(TestCase):
self.Section = Section self.Section = Section
self.SubSection = SubSection self.SubSection = SubSection
def test_simple_dotted_source_path(self):
class TestTwoWayAttributeValueGeneration(DeepPathGeneratesTestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
document = sa.orm.relationship(Document)
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section)
@generates(locale, source='section.document.locale')
def copy_locale(self, value):
return value
self.Document = Document
self.Section = Section
self.SubSection = SubSection
def test_change_parent_attribute(self):
document = self.Document(name=u'Document 1', locale='en') document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document) section = self.Section(name=u'Section 1', document=document)
self.session.add(document)
self.session.add(section)
self.session.commit()
assert section.locale == 'en' assert section.locale == 'en'
def test_deep_dotted_source_path(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
subsection = self.SubSection(name=u'Section 1', section=section) subsection = self.SubSection(name=u'Section 1', section=section)
self.session.add(subsection)
self.session.commit()
assert subsection.locale == 'en' assert subsection.locale == 'en'