Fix getdotattr list handling
This commit is contained in:
@@ -220,6 +220,9 @@ def declarative_base(model):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
from operator import attrgetter
|
||||||
|
|
||||||
|
|
||||||
def getdotattr(obj_or_class, dot_path):
|
def getdotattr(obj_or_class, dot_path):
|
||||||
"""
|
"""
|
||||||
Allow dot-notated strings to be passed to `getattr`.
|
Allow dot-notated strings to be passed to `getattr`.
|
||||||
@@ -234,16 +237,23 @@ def getdotattr(obj_or_class, dot_path):
|
|||||||
:param obj_or_class: Any object or class
|
:param obj_or_class: Any object or class
|
||||||
:param dot_path: Attribute path with dot mark as separator
|
:param dot_path: Attribute path with dot mark as separator
|
||||||
"""
|
"""
|
||||||
def get_attr(mixed, attr):
|
last = obj_or_class
|
||||||
if isinstance(mixed, InstrumentedAttribute):
|
|
||||||
return getattr(
|
|
||||||
mixed.property.mapper.class_,
|
|
||||||
attr
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return getattr(mixed, attr)
|
|
||||||
|
|
||||||
return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class)
|
for path in dot_path.split('.'):
|
||||||
|
getter = attrgetter(path)
|
||||||
|
if isinstance(last, list):
|
||||||
|
tmp = []
|
||||||
|
for el in last:
|
||||||
|
if isinstance(el, list):
|
||||||
|
tmp.extend(map(getter, el))
|
||||||
|
else:
|
||||||
|
tmp.append(getter(el))
|
||||||
|
last = tmp
|
||||||
|
elif isinstance(last, InstrumentedAttribute):
|
||||||
|
last = getter(last.property.mapper.class_)
|
||||||
|
else:
|
||||||
|
last = getter(last)
|
||||||
|
return last
|
||||||
|
|
||||||
|
|
||||||
def has_changes(obj, attr):
|
def has_changes(obj, attr):
|
||||||
|
@@ -9,13 +9,11 @@ class TestGetDotAttr(TestCase):
|
|||||||
__tablename__ = 'document'
|
__tablename__ = 'document'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
locale = sa.Column(sa.String(10))
|
|
||||||
|
|
||||||
class Section(self.Base):
|
class Section(self.Base):
|
||||||
__tablename__ = 'section'
|
__tablename__ = 'section'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
locale = sa.Column(sa.String(10))
|
|
||||||
|
|
||||||
document_id = sa.Column(
|
document_id = sa.Column(
|
||||||
sa.Integer, sa.ForeignKey(Document.id)
|
sa.Integer, sa.ForeignKey(Document.id)
|
||||||
@@ -27,7 +25,6 @@ class TestGetDotAttr(TestCase):
|
|||||||
__tablename__ = 'subsection'
|
__tablename__ = 'subsection'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
locale = sa.Column(sa.String(10))
|
|
||||||
|
|
||||||
section_id = sa.Column(
|
section_id = sa.Column(
|
||||||
sa.Integer, sa.ForeignKey(Section.id)
|
sa.Integer, sa.ForeignKey(Section.id)
|
||||||
@@ -35,11 +32,26 @@ class TestGetDotAttr(TestCase):
|
|||||||
|
|
||||||
section = sa.orm.relationship(Section, backref='subsections')
|
section = sa.orm.relationship(Section, backref='subsections')
|
||||||
|
|
||||||
|
class SubSubSection(self.Base):
|
||||||
|
__tablename__ = 'subsubsection'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
locale = sa.Column(sa.String(10))
|
||||||
|
|
||||||
|
subsection_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey(SubSection.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
subsection = sa.orm.relationship(
|
||||||
|
SubSection, backref='subsubsections'
|
||||||
|
)
|
||||||
|
|
||||||
self.Document = Document
|
self.Document = Document
|
||||||
self.Section = Section
|
self.Section = Section
|
||||||
self.SubSection = SubSection
|
self.SubSection = SubSection
|
||||||
|
self.SubSubSection = SubSubSection
|
||||||
|
|
||||||
def test_getdotattr_for_objects(self):
|
def test_simple_objects(self):
|
||||||
document = self.Document(name=u'some document')
|
document = self.Document(name=u'some document')
|
||||||
section = self.Section(document=document)
|
section = self.Section(document=document)
|
||||||
subsection = self.SubSection(section=section)
|
subsection = self.SubSection(section=section)
|
||||||
@@ -49,7 +61,20 @@ class TestGetDotAttr(TestCase):
|
|||||||
'section.document.name'
|
'section.document.name'
|
||||||
) == u'some document'
|
) == u'some document'
|
||||||
|
|
||||||
def test_getdotattr_for_class_paths(self):
|
def test_with_instrumented_lists(self):
|
||||||
|
document = self.Document(name=u'some document')
|
||||||
|
section = self.Section(document=document)
|
||||||
|
subsection = self.SubSection(section=section)
|
||||||
|
subsubsection = self.SubSubSection(subsection=subsection)
|
||||||
|
|
||||||
|
assert getdotattr(document, 'sections.subsections') == [
|
||||||
|
[subsection]
|
||||||
|
]
|
||||||
|
assert getdotattr(document, 'sections.subsections.subsubsections') == [
|
||||||
|
[subsubsection]
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_class_paths(self):
|
||||||
assert getdotattr(self.Section, 'document') is self.Section.document
|
assert getdotattr(self.Section, 'document') is self.Section.document
|
||||||
assert (
|
assert (
|
||||||
getdotattr(self.SubSection, 'section.document') is
|
getdotattr(self.SubSection, 'section.document') is
|
||||||
|
Reference in New Issue
Block a user