Fix getdotattr list handling
This commit is contained in:
@@ -220,6 +220,9 @@ def declarative_base(model):
|
||||
return model
|
||||
|
||||
|
||||
from operator import attrgetter
|
||||
|
||||
|
||||
def getdotattr(obj_or_class, dot_path):
|
||||
"""
|
||||
Allow dot-notated strings to be passed to `getattr`.
|
||||
@@ -234,16 +237,23 @@ def getdotattr(obj_or_class, dot_path):
|
||||
:param obj_or_class: Any object or class
|
||||
:param dot_path: Attribute path with dot mark as separator
|
||||
"""
|
||||
def get_attr(mixed, attr):
|
||||
if isinstance(mixed, InstrumentedAttribute):
|
||||
return getattr(
|
||||
mixed.property.mapper.class_,
|
||||
attr
|
||||
)
|
||||
else:
|
||||
return getattr(mixed, attr)
|
||||
last = obj_or_class
|
||||
|
||||
return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class)
|
||||
for path in dot_path.split('.'):
|
||||
getter = attrgetter(path)
|
||||
if isinstance(last, list):
|
||||
tmp = []
|
||||
for el in last:
|
||||
if isinstance(el, list):
|
||||
tmp.extend(map(getter, el))
|
||||
else:
|
||||
tmp.append(getter(el))
|
||||
last = tmp
|
||||
elif isinstance(last, InstrumentedAttribute):
|
||||
last = getter(last.property.mapper.class_)
|
||||
else:
|
||||
last = getter(last)
|
||||
return last
|
||||
|
||||
|
||||
def has_changes(obj, attr):
|
||||
|
@@ -9,13 +9,11 @@ class TestGetDotAttr(TestCase):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
class Section(self.Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
@@ -27,7 +25,6 @@ class TestGetDotAttr(TestCase):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
@@ -35,11 +32,26 @@ class TestGetDotAttr(TestCase):
|
||||
|
||||
section = sa.orm.relationship(Section, backref='subsections')
|
||||
|
||||
class SubSubSection(self.Base):
|
||||
__tablename__ = 'subsubsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
subsection_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(SubSection.id)
|
||||
)
|
||||
|
||||
subsection = sa.orm.relationship(
|
||||
SubSection, backref='subsubsections'
|
||||
)
|
||||
|
||||
self.Document = Document
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
self.SubSubSection = SubSubSection
|
||||
|
||||
def test_getdotattr_for_objects(self):
|
||||
def test_simple_objects(self):
|
||||
document = self.Document(name=u'some document')
|
||||
section = self.Section(document=document)
|
||||
subsection = self.SubSection(section=section)
|
||||
@@ -49,7 +61,20 @@ class TestGetDotAttr(TestCase):
|
||||
'section.document.name'
|
||||
) == u'some document'
|
||||
|
||||
def test_getdotattr_for_class_paths(self):
|
||||
def test_with_instrumented_lists(self):
|
||||
document = self.Document(name=u'some document')
|
||||
section = self.Section(document=document)
|
||||
subsection = self.SubSection(section=section)
|
||||
subsubsection = self.SubSubSection(subsection=subsection)
|
||||
|
||||
assert getdotattr(document, 'sections.subsections') == [
|
||||
[subsection]
|
||||
]
|
||||
assert getdotattr(document, 'sections.subsections.subsubsections') == [
|
||||
[subsubsection]
|
||||
]
|
||||
|
||||
def test_class_paths(self):
|
||||
assert getdotattr(self.Section, 'document') is self.Section.document
|
||||
assert (
|
||||
getdotattr(self.SubSection, 'section.document') is
|
||||
|
Reference in New Issue
Block a user