diff --git a/sqlalchemy_utils/path.py b/sqlalchemy_utils/path.py index 9538245..765fb0e 100644 --- a/sqlalchemy_utils/path.py +++ b/sqlalchemy_utils/path.py @@ -11,18 +11,25 @@ class Path(object): self.path = path self.separator = separator + @property + def parts(self): + return self.path.split(self.separator) + def __iter__(self): - for part in self.path.split(self.separator): + for part in self.parts: yield part def __len__(self): - return len(self.path.split(self.separator)) + return len(self.parts) def __repr__(self): return "%s('%s')" % (self.__class__.__name__, self.path) + def index(self, element): + return self.parts.index(element) + def __getitem__(self, slice): - result = self.path.split(self.separator)[slice] + result = self.parts[slice] if isinstance(result, list): return self.__class__( self.separator.join(result), @@ -84,6 +91,11 @@ class AttrPath(object): '.'.join(map(get_backref, reversed(self.parts))) ) + def index(self, element): + for index, el in enumerate(self.parts): + if el is element: + return index + def __getitem__(self, slice): result = self.parts[slice] if isinstance(result, list): diff --git a/tests/test_path.py b/tests/test_path.py index 0dac1bc..b4c4a41 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -70,6 +70,11 @@ class TestAttrPath(TestCase): "'section.document')" ) + def test_index(self): + path = AttrPath(self.SubSection, 'section.document') + assert path.index(self.Section.document) == 1 + assert path.index(self.SubSection.section) == 0 + def test_getitem(self): path = AttrPath(self.SubSection, 'section.document') assert path[0] is self.SubSection.section @@ -139,6 +144,9 @@ class TestPath(object): def test_str(self): assert str(Path('s.s2')) == 's.s2' + def test_index(self): + assert Path('s.s2.s3').index('s2') == 1 + def test_unicode(self): assert unicode(Path('s.s2')) == u's.s2'