diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 12626a4..137f63c 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -220,6 +220,9 @@ def declarative_base(model): return model +from operator import attrgetter + + def getdotattr(obj_or_class, dot_path): """ Allow dot-notated strings to be passed to `getattr`. @@ -234,16 +237,23 @@ def getdotattr(obj_or_class, dot_path): :param obj_or_class: Any object or class :param dot_path: Attribute path with dot mark as separator """ - def get_attr(mixed, attr): - if isinstance(mixed, InstrumentedAttribute): - return getattr( - mixed.property.mapper.class_, - attr - ) - else: - return getattr(mixed, attr) + last = obj_or_class - return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class) + for path in dot_path.split('.'): + getter = attrgetter(path) + if isinstance(last, list): + tmp = [] + for el in last: + if isinstance(el, list): + tmp.extend(map(getter, el)) + else: + tmp.append(getter(el)) + last = tmp + elif isinstance(last, InstrumentedAttribute): + last = getter(last.property.mapper.class_) + else: + last = getter(last) + return last def has_changes(obj, attr): diff --git a/tests/functions/test_getdotattr.py b/tests/functions/test_getdotattr.py index 2eeae43..36f082f 100644 --- a/tests/functions/test_getdotattr.py +++ b/tests/functions/test_getdotattr.py @@ -9,13 +9,11 @@ class TestGetDotAttr(TestCase): __tablename__ = 'document' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) class Section(self.Base): __tablename__ = 'section' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) document_id = sa.Column( sa.Integer, sa.ForeignKey(Document.id) @@ -27,7 +25,6 @@ class TestGetDotAttr(TestCase): __tablename__ = 'subsection' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) section_id = sa.Column( sa.Integer, sa.ForeignKey(Section.id) @@ -35,11 +32,26 @@ class TestGetDotAttr(TestCase): section = sa.orm.relationship(Section, backref='subsections') + class SubSubSection(self.Base): + __tablename__ = 'subsubsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + subsection_id = sa.Column( + sa.Integer, sa.ForeignKey(SubSection.id) + ) + + subsection = sa.orm.relationship( + SubSection, backref='subsubsections' + ) + self.Document = Document self.Section = Section self.SubSection = SubSection + self.SubSubSection = SubSubSection - def test_getdotattr_for_objects(self): + def test_simple_objects(self): document = self.Document(name=u'some document') section = self.Section(document=document) subsection = self.SubSection(section=section) @@ -49,7 +61,20 @@ class TestGetDotAttr(TestCase): 'section.document.name' ) == u'some document' - def test_getdotattr_for_class_paths(self): + def test_with_instrumented_lists(self): + document = self.Document(name=u'some document') + section = self.Section(document=document) + subsection = self.SubSection(section=section) + subsubsection = self.SubSubSection(subsection=subsection) + + assert getdotattr(document, 'sections.subsections') == [ + [subsection] + ] + assert getdotattr(document, 'sections.subsections.subsubsections') == [ + [subsubsection] + ] + + def test_class_paths(self): assert getdotattr(self.Section, 'document') is self.Section.document assert ( getdotattr(self.SubSection, 'section.document') is