Add attr path helpers
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.24.0 (2014-02-18)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added getdotattr
|
||||
|
||||
|
||||
0.23.5 (2014-02-15)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
109
sqlalchemy_utils/path.py
Normal file
109
sqlalchemy_utils/path.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
|
||||
class Path(object):
|
||||
def __init__(self, path, separator='.'):
|
||||
if isinstance(path, Path):
|
||||
self.path = path.path
|
||||
else:
|
||||
self.path = path
|
||||
self.separator = separator
|
||||
|
||||
def __iter__(self):
|
||||
for part in self.path.split(self.separator):
|
||||
yield part
|
||||
|
||||
def __len__(self):
|
||||
return len(self.path.split(self.separator))
|
||||
|
||||
def __repr__(self):
|
||||
return "%s('%s')" % (self.__class__.__name__, self.path)
|
||||
|
||||
def __getitem__(self, slice):
|
||||
result = self.path.split(self.separator)[slice]
|
||||
if isinstance(result, list):
|
||||
return self.__class__(
|
||||
self.separator.join(result),
|
||||
separator=self.separator
|
||||
)
|
||||
return result
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.path == other.path and self.separator == other.separator
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
|
||||
def get_attr(mixed, attr):
|
||||
if isinstance(mixed, InstrumentedAttribute):
|
||||
return getattr(
|
||||
mixed.property.mapper.class_,
|
||||
attr
|
||||
)
|
||||
else:
|
||||
return getattr(mixed, attr)
|
||||
|
||||
|
||||
class AttrPath(object):
|
||||
def __init__(self, class_, path):
|
||||
self.class_ = class_
|
||||
self.path = Path(path)
|
||||
self.parts = []
|
||||
last_attr = class_
|
||||
for value in self.path:
|
||||
last_attr = get_attr(last_attr, value)
|
||||
self.parts.append(last_attr)
|
||||
|
||||
def __iter__(self):
|
||||
for part in self.parts:
|
||||
yield part
|
||||
|
||||
def __invert__(self):
|
||||
def get_backref(part):
|
||||
prop = part.property
|
||||
backref = prop.backref
|
||||
if backref is None:
|
||||
raise Exception(
|
||||
"Invert failed because property '%s' of class "
|
||||
"%s has no backref." % (
|
||||
prop.key,
|
||||
prop.mapper.class_.__name__
|
||||
)
|
||||
)
|
||||
return backref
|
||||
|
||||
return self.__class__(
|
||||
self.parts[-1].mapper.class_,
|
||||
'.'.join(map(get_backref, reversed(self.parts)))
|
||||
)
|
||||
|
||||
def __getitem__(self, slice):
|
||||
result = self.parts[slice]
|
||||
if isinstance(result, list):
|
||||
if result[0] is self.parts[0]:
|
||||
class_ = self.class_
|
||||
else:
|
||||
class_ = result[0].mapper.class_
|
||||
return self.__class__(
|
||||
class_,
|
||||
self.path[slice]
|
||||
)
|
||||
else:
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
return len(self.path)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%r, %r)" % (
|
||||
self.__class__.__name__,
|
||||
self.class_,
|
||||
self.path.path
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.path == other.path and self.class_ == other.class_
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
152
tests/test_path.py
Normal file
152
tests/test_path.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from pytest import mark
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.path import Path, AttrPath
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAttrPath(TestCase):
|
||||
def create_models(self):
|
||||
class Document(self.Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
class Section(self.Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
|
||||
document = sa.orm.relationship(Document, backref='sections')
|
||||
|
||||
class SubSection(self.Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
section = sa.orm.relationship(Section, backref='subsections')
|
||||
|
||||
self.Document = Document
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
|
||||
def test_invert(self):
|
||||
path = ~ AttrPath(self.SubSection, 'section.document')
|
||||
assert path.parts == [
|
||||
self.Document.sections,
|
||||
self.Section.subsections
|
||||
]
|
||||
|
||||
def test_len(self):
|
||||
len(AttrPath(self.SubSection, 'section.document')) == 2
|
||||
|
||||
def test_init(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path.class_ == self.SubSection
|
||||
assert path.path == Path('section.document')
|
||||
|
||||
def test_iter(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert list(path) == [
|
||||
self.SubSection.section,
|
||||
self.Section.document
|
||||
]
|
||||
|
||||
def test_repr(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert repr(path) == (
|
||||
"AttrPath(<class 'tests.test_path.SubSection'>, "
|
||||
"'section.document')"
|
||||
)
|
||||
|
||||
def test_getitem(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path[0] is self.SubSection.section
|
||||
assert path[1] is self.Section.document
|
||||
|
||||
def test_getitem_with_slice(self):
|
||||
path = AttrPath(self.SubSection, 'section.document')
|
||||
assert path[:] == AttrPath(self.SubSection, 'section.document')
|
||||
|
||||
def test_eq(self):
|
||||
assert (
|
||||
AttrPath(self.SubSection, 'section.document') ==
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
)
|
||||
assert not (
|
||||
AttrPath(self.SubSection, 'section') ==
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
)
|
||||
|
||||
def test_ne(self):
|
||||
assert not (
|
||||
AttrPath(self.SubSection, 'section.document') !=
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
)
|
||||
assert (
|
||||
AttrPath(self.SubSection, 'section') !=
|
||||
AttrPath(self.SubSection, 'section.document')
|
||||
)
|
||||
|
||||
|
||||
class TestPath(object):
|
||||
def test_init(self):
|
||||
path = Path('attr.attr2')
|
||||
assert path.path == 'attr.attr2'
|
||||
|
||||
def test_init_with_path_object(self):
|
||||
path = Path(Path('attr.attr2'))
|
||||
assert path.path == 'attr.attr2'
|
||||
|
||||
def test_iter(self):
|
||||
path = Path('s.s2.s3')
|
||||
assert list(path) == ['s', 's2', 's3']
|
||||
|
||||
@mark.parametrize(('path', 'length'), (
|
||||
(Path('s.s2.s3'), 3),
|
||||
(Path('s.s2'), 2),
|
||||
(Path(''), 0)
|
||||
))
|
||||
def test_len(self, path, length):
|
||||
return len(path) == length
|
||||
|
||||
def test_reversed(self):
|
||||
path = Path('s.s2.s3')
|
||||
assert list(reversed(path)) == ['s3', 's2', 's']
|
||||
|
||||
def test_repr(self):
|
||||
path = Path('s.s2')
|
||||
assert repr(path) == "Path('s.s2')"
|
||||
|
||||
def test_getitem(self):
|
||||
path = Path('s.s2')
|
||||
assert path[0] == 's'
|
||||
assert path[1] == 's2'
|
||||
|
||||
def test_getitem_with_slice(self):
|
||||
path = Path('s.s2.s3')
|
||||
assert path[1:] == Path('s2.s3')
|
||||
|
||||
@mark.parametrize(('test', 'result'), (
|
||||
(Path('s.s2') == Path('s.s2'), True),
|
||||
(Path('s.s2') == Path('s.s3'), False)
|
||||
))
|
||||
def test_eq(self, test, result):
|
||||
assert test is result
|
||||
|
||||
@mark.parametrize(('test', 'result'), (
|
||||
(Path('s.s2') != Path('s.s2'), False),
|
||||
(Path('s.s2') != Path('s.s3'), True)
|
||||
))
|
||||
def test_ne(self, test, result):
|
||||
assert test is result
|
Reference in New Issue
Block a user