diff --git a/CHANGES.rst b/CHANGES.rst index 8ac071a..75d284c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.24.0 (2014-02-18) +^^^^^^^^^^^^^^^^^^^ + +- Added getdotattr + + 0.23.5 (2014-02-15) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/path.py b/sqlalchemy_utils/path.py new file mode 100644 index 0000000..7b8ad2b --- /dev/null +++ b/sqlalchemy_utils/path.py @@ -0,0 +1,109 @@ +from sqlalchemy.orm.attributes import InstrumentedAttribute + + +class Path(object): + def __init__(self, path, separator='.'): + if isinstance(path, Path): + self.path = path.path + else: + self.path = path + self.separator = separator + + def __iter__(self): + for part in self.path.split(self.separator): + yield part + + def __len__(self): + return len(self.path.split(self.separator)) + + def __repr__(self): + return "%s('%s')" % (self.__class__.__name__, self.path) + + def __getitem__(self, slice): + result = self.path.split(self.separator)[slice] + if isinstance(result, list): + return self.__class__( + self.separator.join(result), + separator=self.separator + ) + return result + + def __eq__(self, other): + return self.path == other.path and self.separator == other.separator + + def __ne__(self, other): + return not (self == other) + + +def get_attr(mixed, attr): + if isinstance(mixed, InstrumentedAttribute): + return getattr( + mixed.property.mapper.class_, + attr + ) + else: + return getattr(mixed, attr) + + +class AttrPath(object): + def __init__(self, class_, path): + self.class_ = class_ + self.path = Path(path) + self.parts = [] + last_attr = class_ + for value in self.path: + last_attr = get_attr(last_attr, value) + self.parts.append(last_attr) + + def __iter__(self): + for part in self.parts: + yield part + + def __invert__(self): + def get_backref(part): + prop = part.property + backref = prop.backref + if backref is None: + raise Exception( + "Invert failed because property '%s' of class " + "%s has no backref." % ( + prop.key, + prop.mapper.class_.__name__ + ) + ) + return backref + + return self.__class__( + self.parts[-1].mapper.class_, + '.'.join(map(get_backref, reversed(self.parts))) + ) + + def __getitem__(self, slice): + result = self.parts[slice] + if isinstance(result, list): + if result[0] is self.parts[0]: + class_ = self.class_ + else: + class_ = result[0].mapper.class_ + return self.__class__( + class_, + self.path[slice] + ) + else: + return result + + def __len__(self): + return len(self.path) + + def __repr__(self): + return "%s(%r, %r)" % ( + self.__class__.__name__, + self.class_, + self.path.path + ) + + def __eq__(self, other): + return self.path == other.path and self.class_ == other.class_ + + def __ne__(self, other): + return not (self == other) diff --git a/tests/test_path.py b/tests/test_path.py new file mode 100644 index 0000000..660c032 --- /dev/null +++ b/tests/test_path.py @@ -0,0 +1,152 @@ +from pytest import mark +import sqlalchemy as sa +from sqlalchemy_utils.path import Path, AttrPath +from tests import TestCase + + +class TestAttrPath(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + class Section(self.Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) + + document = sa.orm.relationship(Document, backref='sections') + + class SubSection(self.Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship(Section, backref='subsections') + + self.Document = Document + self.Section = Section + self.SubSection = SubSection + + def test_invert(self): + path = ~ AttrPath(self.SubSection, 'section.document') + assert path.parts == [ + self.Document.sections, + self.Section.subsections + ] + + def test_len(self): + len(AttrPath(self.SubSection, 'section.document')) == 2 + + def test_init(self): + path = AttrPath(self.SubSection, 'section.document') + assert path.class_ == self.SubSection + assert path.path == Path('section.document') + + def test_iter(self): + path = AttrPath(self.SubSection, 'section.document') + assert list(path) == [ + self.SubSection.section, + self.Section.document + ] + + def test_repr(self): + path = AttrPath(self.SubSection, 'section.document') + assert repr(path) == ( + "AttrPath(, " + "'section.document')" + ) + + def test_getitem(self): + path = AttrPath(self.SubSection, 'section.document') + assert path[0] is self.SubSection.section + assert path[1] is self.Section.document + + def test_getitem_with_slice(self): + path = AttrPath(self.SubSection, 'section.document') + assert path[:] == AttrPath(self.SubSection, 'section.document') + + def test_eq(self): + assert ( + AttrPath(self.SubSection, 'section.document') == + AttrPath(self.SubSection, 'section.document') + ) + assert not ( + AttrPath(self.SubSection, 'section') == + AttrPath(self.SubSection, 'section.document') + ) + + def test_ne(self): + assert not ( + AttrPath(self.SubSection, 'section.document') != + AttrPath(self.SubSection, 'section.document') + ) + assert ( + AttrPath(self.SubSection, 'section') != + AttrPath(self.SubSection, 'section.document') + ) + + +class TestPath(object): + def test_init(self): + path = Path('attr.attr2') + assert path.path == 'attr.attr2' + + def test_init_with_path_object(self): + path = Path(Path('attr.attr2')) + assert path.path == 'attr.attr2' + + def test_iter(self): + path = Path('s.s2.s3') + assert list(path) == ['s', 's2', 's3'] + + @mark.parametrize(('path', 'length'), ( + (Path('s.s2.s3'), 3), + (Path('s.s2'), 2), + (Path(''), 0) + )) + def test_len(self, path, length): + return len(path) == length + + def test_reversed(self): + path = Path('s.s2.s3') + assert list(reversed(path)) == ['s3', 's2', 's'] + + def test_repr(self): + path = Path('s.s2') + assert repr(path) == "Path('s.s2')" + + def test_getitem(self): + path = Path('s.s2') + assert path[0] == 's' + assert path[1] == 's2' + + def test_getitem_with_slice(self): + path = Path('s.s2.s3') + assert path[1:] == Path('s2.s3') + + @mark.parametrize(('test', 'result'), ( + (Path('s.s2') == Path('s.s2'), True), + (Path('s.s2') == Path('s.s3'), False) + )) + def test_eq(self, test, result): + assert test is result + + @mark.parametrize(('test', 'result'), ( + (Path('s.s2') != Path('s.s2'), False), + (Path('s.s2') != Path('s.s3'), True) + )) + def test_ne(self, test, result): + assert test is result