Add attr path helpers

This commit is contained in:
Konsta Vesterinen
2014-02-18 13:42:07 +02:00
parent 09c2d7ddc7
commit fc7ac78ff1
3 changed files with 267 additions and 0 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.24.0 (2014-02-18)
^^^^^^^^^^^^^^^^^^^
- Added getdotattr
0.23.5 (2014-02-15)
^^^^^^^^^^^^^^^^^^^

109
sqlalchemy_utils/path.py Normal file
View File

@@ -0,0 +1,109 @@
from sqlalchemy.orm.attributes import InstrumentedAttribute
class Path(object):
def __init__(self, path, separator='.'):
if isinstance(path, Path):
self.path = path.path
else:
self.path = path
self.separator = separator
def __iter__(self):
for part in self.path.split(self.separator):
yield part
def __len__(self):
return len(self.path.split(self.separator))
def __repr__(self):
return "%s('%s')" % (self.__class__.__name__, self.path)
def __getitem__(self, slice):
result = self.path.split(self.separator)[slice]
if isinstance(result, list):
return self.__class__(
self.separator.join(result),
separator=self.separator
)
return result
def __eq__(self, other):
return self.path == other.path and self.separator == other.separator
def __ne__(self, other):
return not (self == other)
def get_attr(mixed, attr):
if isinstance(mixed, InstrumentedAttribute):
return getattr(
mixed.property.mapper.class_,
attr
)
else:
return getattr(mixed, attr)
class AttrPath(object):
def __init__(self, class_, path):
self.class_ = class_
self.path = Path(path)
self.parts = []
last_attr = class_
for value in self.path:
last_attr = get_attr(last_attr, value)
self.parts.append(last_attr)
def __iter__(self):
for part in self.parts:
yield part
def __invert__(self):
def get_backref(part):
prop = part.property
backref = prop.backref
if backref is None:
raise Exception(
"Invert failed because property '%s' of class "
"%s has no backref." % (
prop.key,
prop.mapper.class_.__name__
)
)
return backref
return self.__class__(
self.parts[-1].mapper.class_,
'.'.join(map(get_backref, reversed(self.parts)))
)
def __getitem__(self, slice):
result = self.parts[slice]
if isinstance(result, list):
if result[0] is self.parts[0]:
class_ = self.class_
else:
class_ = result[0].mapper.class_
return self.__class__(
class_,
self.path[slice]
)
else:
return result
def __len__(self):
return len(self.path)
def __repr__(self):
return "%s(%r, %r)" % (
self.__class__.__name__,
self.class_,
self.path.path
)
def __eq__(self, other):
return self.path == other.path and self.class_ == other.class_
def __ne__(self, other):
return not (self == other)

152
tests/test_path.py Normal file
View File

@@ -0,0 +1,152 @@
from pytest import mark
import sqlalchemy as sa
from sqlalchemy_utils.path import Path, AttrPath
from tests import TestCase
class TestAttrPath(TestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
document = sa.orm.relationship(Document, backref='sections')
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections')
self.Document = Document
self.Section = Section
self.SubSection = SubSection
def test_invert(self):
path = ~ AttrPath(self.SubSection, 'section.document')
assert path.parts == [
self.Document.sections,
self.Section.subsections
]
def test_len(self):
len(AttrPath(self.SubSection, 'section.document')) == 2
def test_init(self):
path = AttrPath(self.SubSection, 'section.document')
assert path.class_ == self.SubSection
assert path.path == Path('section.document')
def test_iter(self):
path = AttrPath(self.SubSection, 'section.document')
assert list(path) == [
self.SubSection.section,
self.Section.document
]
def test_repr(self):
path = AttrPath(self.SubSection, 'section.document')
assert repr(path) == (
"AttrPath(<class 'tests.test_path.SubSection'>, "
"'section.document')"
)
def test_getitem(self):
path = AttrPath(self.SubSection, 'section.document')
assert path[0] is self.SubSection.section
assert path[1] is self.Section.document
def test_getitem_with_slice(self):
path = AttrPath(self.SubSection, 'section.document')
assert path[:] == AttrPath(self.SubSection, 'section.document')
def test_eq(self):
assert (
AttrPath(self.SubSection, 'section.document') ==
AttrPath(self.SubSection, 'section.document')
)
assert not (
AttrPath(self.SubSection, 'section') ==
AttrPath(self.SubSection, 'section.document')
)
def test_ne(self):
assert not (
AttrPath(self.SubSection, 'section.document') !=
AttrPath(self.SubSection, 'section.document')
)
assert (
AttrPath(self.SubSection, 'section') !=
AttrPath(self.SubSection, 'section.document')
)
class TestPath(object):
def test_init(self):
path = Path('attr.attr2')
assert path.path == 'attr.attr2'
def test_init_with_path_object(self):
path = Path(Path('attr.attr2'))
assert path.path == 'attr.attr2'
def test_iter(self):
path = Path('s.s2.s3')
assert list(path) == ['s', 's2', 's3']
@mark.parametrize(('path', 'length'), (
(Path('s.s2.s3'), 3),
(Path('s.s2'), 2),
(Path(''), 0)
))
def test_len(self, path, length):
return len(path) == length
def test_reversed(self):
path = Path('s.s2.s3')
assert list(reversed(path)) == ['s3', 's2', 's']
def test_repr(self):
path = Path('s.s2')
assert repr(path) == "Path('s.s2')"
def test_getitem(self):
path = Path('s.s2')
assert path[0] == 's'
assert path[1] == 's2'
def test_getitem_with_slice(self):
path = Path('s.s2.s3')
assert path[1:] == Path('s2.s3')
@mark.parametrize(('test', 'result'), (
(Path('s.s2') == Path('s.s2'), True),
(Path('s.s2') == Path('s.s3'), False)
))
def test_eq(self, test, result):
assert test is result
@mark.parametrize(('test', 'result'), (
(Path('s.s2') != Path('s.s2'), False),
(Path('s.s2') != Path('s.s3'), True)
))
def test_ne(self, test, result):
assert test is result