Add attr path helpers
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
|||||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||||
|
|
||||||
|
|
||||||
|
0.24.0 (2014-02-18)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Added getdotattr
|
||||||
|
|
||||||
|
|
||||||
0.23.5 (2014-02-15)
|
0.23.5 (2014-02-15)
|
||||||
^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
109
sqlalchemy_utils/path.py
Normal file
109
sqlalchemy_utils/path.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
|
|
||||||
|
class Path(object):
|
||||||
|
def __init__(self, path, separator='.'):
|
||||||
|
if isinstance(path, Path):
|
||||||
|
self.path = path.path
|
||||||
|
else:
|
||||||
|
self.path = path
|
||||||
|
self.separator = separator
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for part in self.path.split(self.separator):
|
||||||
|
yield part
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.path.split(self.separator))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s('%s')" % (self.__class__.__name__, self.path)
|
||||||
|
|
||||||
|
def __getitem__(self, slice):
|
||||||
|
result = self.path.split(self.separator)[slice]
|
||||||
|
if isinstance(result, list):
|
||||||
|
return self.__class__(
|
||||||
|
self.separator.join(result),
|
||||||
|
separator=self.separator
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.path == other.path and self.separator == other.separator
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
||||||
|
|
||||||
|
|
||||||
|
def get_attr(mixed, attr):
|
||||||
|
if isinstance(mixed, InstrumentedAttribute):
|
||||||
|
return getattr(
|
||||||
|
mixed.property.mapper.class_,
|
||||||
|
attr
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return getattr(mixed, attr)
|
||||||
|
|
||||||
|
|
||||||
|
class AttrPath(object):
|
||||||
|
def __init__(self, class_, path):
|
||||||
|
self.class_ = class_
|
||||||
|
self.path = Path(path)
|
||||||
|
self.parts = []
|
||||||
|
last_attr = class_
|
||||||
|
for value in self.path:
|
||||||
|
last_attr = get_attr(last_attr, value)
|
||||||
|
self.parts.append(last_attr)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for part in self.parts:
|
||||||
|
yield part
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
def get_backref(part):
|
||||||
|
prop = part.property
|
||||||
|
backref = prop.backref
|
||||||
|
if backref is None:
|
||||||
|
raise Exception(
|
||||||
|
"Invert failed because property '%s' of class "
|
||||||
|
"%s has no backref." % (
|
||||||
|
prop.key,
|
||||||
|
prop.mapper.class_.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return backref
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
self.parts[-1].mapper.class_,
|
||||||
|
'.'.join(map(get_backref, reversed(self.parts)))
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, slice):
|
||||||
|
result = self.parts[slice]
|
||||||
|
if isinstance(result, list):
|
||||||
|
if result[0] is self.parts[0]:
|
||||||
|
class_ = self.class_
|
||||||
|
else:
|
||||||
|
class_ = result[0].mapper.class_
|
||||||
|
return self.__class__(
|
||||||
|
class_,
|
||||||
|
self.path[slice]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.path)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s(%r, %r)" % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.class_,
|
||||||
|
self.path.path
|
||||||
|
)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.path == other.path and self.class_ == other.class_
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
152
tests/test_path.py
Normal file
152
tests/test_path.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
from pytest import mark
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.path import Path, AttrPath
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAttrPath(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Document(self.Base):
|
||||||
|
__tablename__ = 'document'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
locale = sa.Column(sa.String(10))
|
||||||
|
|
||||||
|
class Section(self.Base):
|
||||||
|
__tablename__ = 'section'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
locale = sa.Column(sa.String(10))
|
||||||
|
|
||||||
|
document_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey(Document.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
document = sa.orm.relationship(Document, backref='sections')
|
||||||
|
|
||||||
|
class SubSection(self.Base):
|
||||||
|
__tablename__ = 'subsection'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
locale = sa.Column(sa.String(10))
|
||||||
|
|
||||||
|
section_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey(Section.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
section = sa.orm.relationship(Section, backref='subsections')
|
||||||
|
|
||||||
|
self.Document = Document
|
||||||
|
self.Section = Section
|
||||||
|
self.SubSection = SubSection
|
||||||
|
|
||||||
|
def test_invert(self):
|
||||||
|
path = ~ AttrPath(self.SubSection, 'section.document')
|
||||||
|
assert path.parts == [
|
||||||
|
self.Document.sections,
|
||||||
|
self.Section.subsections
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
len(AttrPath(self.SubSection, 'section.document')) == 2
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
path = AttrPath(self.SubSection, 'section.document')
|
||||||
|
assert path.class_ == self.SubSection
|
||||||
|
assert path.path == Path('section.document')
|
||||||
|
|
||||||
|
def test_iter(self):
|
||||||
|
path = AttrPath(self.SubSection, 'section.document')
|
||||||
|
assert list(path) == [
|
||||||
|
self.SubSection.section,
|
||||||
|
self.Section.document
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
path = AttrPath(self.SubSection, 'section.document')
|
||||||
|
assert repr(path) == (
|
||||||
|
"AttrPath(<class 'tests.test_path.SubSection'>, "
|
||||||
|
"'section.document')"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_getitem(self):
|
||||||
|
path = AttrPath(self.SubSection, 'section.document')
|
||||||
|
assert path[0] is self.SubSection.section
|
||||||
|
assert path[1] is self.Section.document
|
||||||
|
|
||||||
|
def test_getitem_with_slice(self):
|
||||||
|
path = AttrPath(self.SubSection, 'section.document')
|
||||||
|
assert path[:] == AttrPath(self.SubSection, 'section.document')
|
||||||
|
|
||||||
|
def test_eq(self):
|
||||||
|
assert (
|
||||||
|
AttrPath(self.SubSection, 'section.document') ==
|
||||||
|
AttrPath(self.SubSection, 'section.document')
|
||||||
|
)
|
||||||
|
assert not (
|
||||||
|
AttrPath(self.SubSection, 'section') ==
|
||||||
|
AttrPath(self.SubSection, 'section.document')
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_ne(self):
|
||||||
|
assert not (
|
||||||
|
AttrPath(self.SubSection, 'section.document') !=
|
||||||
|
AttrPath(self.SubSection, 'section.document')
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
AttrPath(self.SubSection, 'section') !=
|
||||||
|
AttrPath(self.SubSection, 'section.document')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPath(object):
|
||||||
|
def test_init(self):
|
||||||
|
path = Path('attr.attr2')
|
||||||
|
assert path.path == 'attr.attr2'
|
||||||
|
|
||||||
|
def test_init_with_path_object(self):
|
||||||
|
path = Path(Path('attr.attr2'))
|
||||||
|
assert path.path == 'attr.attr2'
|
||||||
|
|
||||||
|
def test_iter(self):
|
||||||
|
path = Path('s.s2.s3')
|
||||||
|
assert list(path) == ['s', 's2', 's3']
|
||||||
|
|
||||||
|
@mark.parametrize(('path', 'length'), (
|
||||||
|
(Path('s.s2.s3'), 3),
|
||||||
|
(Path('s.s2'), 2),
|
||||||
|
(Path(''), 0)
|
||||||
|
))
|
||||||
|
def test_len(self, path, length):
|
||||||
|
return len(path) == length
|
||||||
|
|
||||||
|
def test_reversed(self):
|
||||||
|
path = Path('s.s2.s3')
|
||||||
|
assert list(reversed(path)) == ['s3', 's2', 's']
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
path = Path('s.s2')
|
||||||
|
assert repr(path) == "Path('s.s2')"
|
||||||
|
|
||||||
|
def test_getitem(self):
|
||||||
|
path = Path('s.s2')
|
||||||
|
assert path[0] == 's'
|
||||||
|
assert path[1] == 's2'
|
||||||
|
|
||||||
|
def test_getitem_with_slice(self):
|
||||||
|
path = Path('s.s2.s3')
|
||||||
|
assert path[1:] == Path('s2.s3')
|
||||||
|
|
||||||
|
@mark.parametrize(('test', 'result'), (
|
||||||
|
(Path('s.s2') == Path('s.s2'), True),
|
||||||
|
(Path('s.s2') == Path('s.s3'), False)
|
||||||
|
))
|
||||||
|
def test_eq(self, test, result):
|
||||||
|
assert test is result
|
||||||
|
|
||||||
|
@mark.parametrize(('test', 'result'), (
|
||||||
|
(Path('s.s2') != Path('s.s2'), False),
|
||||||
|
(Path('s.s2') != Path('s.s3'), True)
|
||||||
|
))
|
||||||
|
def test_ne(self, test, result):
|
||||||
|
assert test is result
|
Reference in New Issue
Block a user