Add source param to generates

This commit is contained in:
Konsta Vesterinen
2014-01-30 15:02:34 +02:00
parent d04155728e
commit b95a7d855e
2 changed files with 77 additions and 7 deletions

View File

@@ -4,6 +4,13 @@ import sqlalchemy as sa
import six import six
def getdotattr(obj, dot_path):
"""
Allows dot-notated strings to be passed to `getattr`
"""
return reduce(getattr, dot_path.split('.'), obj)
class AttributeValueGenerator(object): class AttributeValueGenerator(object):
def __init__(self): def __init__(self):
self.listener_args = [ self.listener_args = [
@@ -31,15 +38,15 @@ class AttributeValueGenerator(object):
self.listeners_registered = False self.listeners_registered = False
self.generator_registry = defaultdict(list) self.generator_registry = defaultdict(list)
def generator_wrapper(self, func, attr): def generator_wrapper(self, func, attr, source):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
if isinstance(attr, sa.orm.attributes.InstrumentedAttribute): if isinstance(attr, sa.orm.attributes.InstrumentedAttribute):
self.generator_registry[attr.class_].append(wrapper) self.generator_registry[attr.class_].append(wrapper)
wrapper.__generates__ = attr wrapper.__generates__ = attr, source
else: else:
wrapper.__generates__ = attr wrapper.__generates__ = attr, source
return wrapper return wrapper
def register_listeners(self): def register_listeners(self):
@@ -61,17 +68,23 @@ class AttributeValueGenerator(object):
class_ = obj.__class__ class_ = obj.__class__
if class_ in self.generator_registry: if class_ in self.generator_registry:
for func in self.generator_registry[class_]: for func in self.generator_registry[class_]:
attr = func.__generates__ attr, source = func.__generates__
if not isinstance(attr, six.string_types): if not isinstance(attr, six.string_types):
attr = attr.name attr = attr.name
setattr(obj, attr, func(obj))
if source is None:
setattr(obj, attr, func(obj))
else:
setattr(obj, attr, func(obj, getdotattr(obj, source)))
generator = AttributeValueGenerator() generator = AttributeValueGenerator()
def generates(attr, generator=generator): def generates(attr, source=None, generator=generator):
""" """
Decorator that marks given function as attribute value generator.
Many times you may have generated property values. Usual cases include Many times you may have generated property values. Usual cases include
slugs from names or resized thumbnails from images. slugs from names or resized thumbnails from images.
@@ -131,5 +144,5 @@ def generates(attr, generator=generator):
generator.register_listeners() generator.register_listeners()
def wraps(func): def wraps(func):
return generator.generator_wrapper(func, attr) return generator.generator_wrapper(func, attr, source)
return wraps return wraps

View File

@@ -55,3 +55,60 @@ class TestGeneratesWithFunctionAndClassVariableArg(GeneratesTestCase):
return self.name.lower().replace(' ', '-') return self.name.lower().replace(' ', '-')
self.Article = Article self.Article = Article
class TestGeneratesWithSourcePath(TestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
document = sa.orm.relationship(Document)
@generates(locale, source='document')
def copy_locale(self, document):
return document.locale
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section)
@generates(locale, source='section.document')
def copy_locale(self, document):
return document.locale
self.Document = Document
self.Section = Section
self.SubSection = SubSection
def test_simple_source_paths(self):
document = self.Document(name=u'Document 1', locale='en')
section = self.Section(name=u'Section 1', document=document)
self.session.add(document)
self.session.add(section)
self.session.commit()
assert section.locale == 'en'