Fix getdotattr list handling

This commit is contained in:
Konsta Vesterinen
2014-02-18 17:17:33 +02:00
parent a836c816d1
commit 989faeb9a6
2 changed files with 49 additions and 14 deletions

View File

@@ -220,6 +220,9 @@ def declarative_base(model):
return model
from operator import attrgetter
def getdotattr(obj_or_class, dot_path):
"""
Allow dot-notated strings to be passed to `getattr`.
@@ -234,16 +237,23 @@ def getdotattr(obj_or_class, dot_path):
:param obj_or_class: Any object or class
:param dot_path: Attribute path with dot mark as separator
"""
def get_attr(mixed, attr):
if isinstance(mixed, InstrumentedAttribute):
return getattr(
mixed.property.mapper.class_,
attr
)
else:
return getattr(mixed, attr)
last = obj_or_class
return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class)
for path in dot_path.split('.'):
getter = attrgetter(path)
if isinstance(last, list):
tmp = []
for el in last:
if isinstance(el, list):
tmp.extend(map(getter, el))
else:
tmp.append(getter(el))
last = tmp
elif isinstance(last, InstrumentedAttribute):
last = getter(last.property.mapper.class_)
else:
last = getter(last)
return last
def has_changes(obj, attr):

View File

@@ -9,13 +9,11 @@ class TestGetDotAttr(TestCase):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
@@ -27,7 +25,6 @@ class TestGetDotAttr(TestCase):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
@@ -35,11 +32,26 @@ class TestGetDotAttr(TestCase):
section = sa.orm.relationship(Section, backref='subsections')
class SubSubSection(self.Base):
__tablename__ = 'subsubsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
subsection_id = sa.Column(
sa.Integer, sa.ForeignKey(SubSection.id)
)
subsection = sa.orm.relationship(
SubSection, backref='subsubsections'
)
self.Document = Document
self.Section = Section
self.SubSection = SubSection
self.SubSubSection = SubSubSection
def test_getdotattr_for_objects(self):
def test_simple_objects(self):
document = self.Document(name=u'some document')
section = self.Section(document=document)
subsection = self.SubSection(section=section)
@@ -49,7 +61,20 @@ class TestGetDotAttr(TestCase):
'section.document.name'
) == u'some document'
def test_getdotattr_for_class_paths(self):
def test_with_instrumented_lists(self):
document = self.Document(name=u'some document')
section = self.Section(document=document)
subsection = self.SubSection(section=section)
subsubsection = self.SubSubSection(subsection=subsection)
assert getdotattr(document, 'sections.subsections') == [
[subsection]
]
assert getdotattr(document, 'sections.subsections.subsubsections') == [
[subsubsection]
]
def test_class_paths(self):
assert getdotattr(self.Section, 'document') is self.Section.document
assert (
getdotattr(self.SubSection, 'section.document') is