Add LtreeType for PostgreSQL
* Add Ltree primitive data type
This commit is contained in:
@@ -4,6 +4,13 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.32.4 (2016-04-20)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added LtreeType for PostgreSQL ltree extension
|
||||
- Added Ltree primitive data type
|
||||
|
||||
|
||||
0.32.3 (2016-04-20)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -87,6 +87,19 @@ LocaleType
|
||||
.. autoclass:: LocaleType
|
||||
|
||||
|
||||
LtreeType
|
||||
---------
|
||||
|
||||
|
||||
.. module:: sqlalchemy_utils.types.ltree
|
||||
|
||||
.. autoclass:: LtreeType
|
||||
|
||||
.. module:: sqlalchemy_utils.primitives.ltree
|
||||
|
||||
.. autoclass:: Ltree
|
||||
|
||||
|
||||
IPAddressType
|
||||
-------------
|
||||
|
||||
|
@@ -55,7 +55,7 @@ from .listeners import ( # noqa
|
||||
)
|
||||
from .models import Timestamp # noqa
|
||||
from .observer import observes # noqa
|
||||
from .primitives import Country, Currency, WeekDay, WeekDays # noqa
|
||||
from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa
|
||||
from .proxy_dict import proxy_dict, ProxyDict # noqa
|
||||
from .query_chain import QueryChain # noqa
|
||||
from .types import ( # noqa
|
||||
@@ -77,6 +77,7 @@ from .types import ( # noqa
|
||||
IPAddressType,
|
||||
JSONType,
|
||||
LocaleType,
|
||||
LtreeType,
|
||||
NumericRangeType,
|
||||
Password,
|
||||
PasswordType,
|
||||
@@ -93,4 +94,4 @@ from .types import ( # noqa
|
||||
WeekDaysType
|
||||
)
|
||||
|
||||
__version__ = '0.32.3'
|
||||
__version__ = '0.32.4'
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from .country import Country # noqa
|
||||
from .currency import Currency # noqa
|
||||
from .ltree import Ltree # noqa
|
||||
from .weekday import WeekDay # noqa
|
||||
from .weekdays import WeekDays # noqa
|
||||
|
180
sqlalchemy_utils/primitives/ltree.py
Normal file
180
sqlalchemy_utils/primitives/ltree.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import re
|
||||
|
||||
import six
|
||||
|
||||
from ..utils import str_coercible
|
||||
|
||||
path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')
|
||||
|
||||
|
||||
@str_coercible
|
||||
class Ltree(object):
|
||||
"""
|
||||
Ltree class wraps a valid string label path. It provides various
|
||||
convenience properties and methods.
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import Ltree
|
||||
|
||||
Ltree('1.2.3').path # '1.2.3'
|
||||
|
||||
|
||||
Ltree always validates the given code.
|
||||
|
||||
::
|
||||
|
||||
Ltree(None) # raises TypeError
|
||||
|
||||
Ltree('..') # raises ValueError
|
||||
|
||||
|
||||
Ltree supports equality operators.
|
||||
|
||||
::
|
||||
|
||||
Ltree('Countries.Finland') == Ltree('Countries.Finland')
|
||||
Ltree('Countries.Germany') != Ltree('Countries.Finland')
|
||||
|
||||
|
||||
Ltree objects are hashable.
|
||||
|
||||
|
||||
::
|
||||
|
||||
assert hash(Ltree('Finland')) == hash('Finland')
|
||||
|
||||
|
||||
Ltree objects have length.
|
||||
|
||||
::
|
||||
|
||||
assert len(Ltree('1.2')) 2
|
||||
assert len(Ltree('some.one.some.where')) # 4
|
||||
|
||||
|
||||
You can easily find subpath indexes.
|
||||
|
||||
::
|
||||
|
||||
assert Ltree('1.2.3').index('2.3') == 1
|
||||
assert Ltree('1.2.3.4.5').index('3.4') == 2
|
||||
|
||||
|
||||
Ltree objects can be sliced.
|
||||
|
||||
|
||||
::
|
||||
|
||||
assert Ltree('1.2.3')[0:2] == Ltree('1.2')
|
||||
assert Ltree('1.2.3')[1:] == Ltree('2.3')
|
||||
|
||||
|
||||
Finding longest common ancestor.
|
||||
|
||||
|
||||
::
|
||||
|
||||
assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
|
||||
assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'
|
||||
|
||||
|
||||
Ltree objects can be concatenated.1
|
||||
|
||||
::
|
||||
|
||||
assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
|
||||
"""
|
||||
|
||||
def __init__(self, path_or_ltree):
|
||||
if isinstance(path_or_ltree, Ltree):
|
||||
self.path = path_or_ltree.path
|
||||
elif isinstance(path_or_ltree, six.string_types):
|
||||
self.validate(path_or_ltree)
|
||||
self.path = path_or_ltree
|
||||
else:
|
||||
raise TypeError(
|
||||
"Ltree() argument must be a string or an Ltree, not '{0}'"
|
||||
.format(
|
||||
type(path_or_ltree).__name__
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, path):
|
||||
if path_matcher.match(path) is None:
|
||||
raise ValueError(
|
||||
"'{0}' is not a valid ltree path.".format(path)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.path.split('.'))
|
||||
|
||||
def index(self, other):
|
||||
subpath = Ltree(other).path.split('.')
|
||||
parts = self.path.split('.')
|
||||
for index, _ in enumerate(parts):
|
||||
if parts[index:len(subpath) + index] == subpath:
|
||||
return index
|
||||
raise ValueError('subpath not found')
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, int):
|
||||
return Ltree(self.path.split('.')[key])
|
||||
elif isinstance(key, slice):
|
||||
return Ltree('.'.join(self.path.split('.')[key]))
|
||||
raise TypeError(
|
||||
'Ltree indices must be integers, not {0}'.format(
|
||||
key.__class__.__name__
|
||||
)
|
||||
)
|
||||
|
||||
def lca(self, *others):
|
||||
"""
|
||||
Lowest common ancestor, i.e., longest common prefix of paths
|
||||
|
||||
::
|
||||
|
||||
assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
|
||||
"""
|
||||
other_parts = [Ltree(other).path.split('.') for other in others]
|
||||
parts = self.path.split('.')
|
||||
for index, element in enumerate(parts):
|
||||
if any((
|
||||
other[index] != element or len(other) <= index + 1
|
||||
for other in other_parts
|
||||
)):
|
||||
if index == 0:
|
||||
return None
|
||||
return Ltree('.'.join(parts[0:index]))
|
||||
|
||||
def __add__(self, other):
|
||||
return Ltree(self.path + '.' + Ltree(other).path)
|
||||
|
||||
def __radd__(self, other):
|
||||
return Ltree(other) + self
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Ltree):
|
||||
return self.path == other.path
|
||||
elif isinstance(other, six.string_types):
|
||||
return self.path == other
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.path)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%r)' % (self.__class__.__name__, self.path)
|
||||
|
||||
def __unicode__(self):
|
||||
return self.path
|
||||
|
||||
def __contains__(self, label):
|
||||
return label in self.path.split('.')
|
@@ -12,6 +12,7 @@ from .encrypted import EncryptedType # noqa
|
||||
from .ip_address import IPAddressType # noqa
|
||||
from .json import JSONType # noqa
|
||||
from .locale import LocaleType # noqa
|
||||
from .ltree import LtreeType # noqa
|
||||
from .password import Password, PasswordType # noqa
|
||||
from .pg_composite import ( # noqa
|
||||
CompositeArray,
|
||||
|
107
sqlalchemy_utils/types/ltree.py
Normal file
107
sqlalchemy_utils/types/ltree.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql.base import (
|
||||
ARRAY,
|
||||
ischema_names,
|
||||
PGTypeCompiler
|
||||
)
|
||||
from sqlalchemy.sql import expression
|
||||
|
||||
from ..primitives import Ltree
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible):
|
||||
"""Postgresql LtreeType type.
|
||||
|
||||
The LtreeType datatype can be used for representing labels of data stored
|
||||
in hierarchial tree-like structure. For more detailed information please
|
||||
refer to http://www.postgresql.org/docs/current/static/ltree.html
|
||||
|
||||
.. note::
|
||||
Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types
|
||||
may require installation of Postgresql ltree extension on the server
|
||||
side. Please visit http://www.postgres.org for details.
|
||||
"""
|
||||
|
||||
class comparator_factory(types.Concatenable.Comparator):
|
||||
def ancestor_of(self, other):
|
||||
if isinstance(other, list):
|
||||
return self.op('@>')(expression.cast(other, ARRAY(LtreeType)))
|
||||
else:
|
||||
return self.op('@>')(other)
|
||||
|
||||
def descendant_of(self, other):
|
||||
if isinstance(other, list):
|
||||
return self.op('<@')(expression.cast(other, ARRAY(LtreeType)))
|
||||
else:
|
||||
return self.op('<@')(other)
|
||||
|
||||
def lquery(self, other):
|
||||
if isinstance(other, list):
|
||||
return self.op('?')(expression.cast(other, ARRAY(LQUERY)))
|
||||
else:
|
||||
return self.op('~')(other)
|
||||
|
||||
def ltxtquery(self, other):
|
||||
return self.op('@')(other)
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if value:
|
||||
return value.path
|
||||
return process
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
return self._coerce(value)
|
||||
return process
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
def process(value):
|
||||
value = value.replace("'", "''")
|
||||
return "'%s'" % value
|
||||
return process
|
||||
|
||||
__visit_name__ = 'LTREE'
|
||||
|
||||
def _coerce(self, value):
|
||||
if value:
|
||||
return Ltree(value)
|
||||
|
||||
|
||||
class LQUERY(types.TypeEngine):
|
||||
"""Postresql LQUERY type.
|
||||
See :class:`LTREE` for details.
|
||||
"""
|
||||
__visit_name__ = 'LQUERY'
|
||||
|
||||
|
||||
class LTXTQUERY(types.TypeEngine):
|
||||
"""Postresql LTXTQUERY type.
|
||||
See :class:`LTREE` for details.
|
||||
"""
|
||||
__visit_name__ = 'LTXTQUERY'
|
||||
|
||||
|
||||
ischema_names['ltree'] = LtreeType
|
||||
ischema_names['lquery'] = LQUERY
|
||||
ischema_names['ltxtquery'] = LTXTQUERY
|
||||
|
||||
|
||||
def visit_LTREE(self, type_, **kw):
|
||||
return 'LTREE'
|
||||
|
||||
|
||||
def visit_LQUERY(self, type_, **kw):
|
||||
return 'LQUERY'
|
||||
|
||||
|
||||
def visit_LTXTQUERY(self, type_, **kw):
|
||||
return 'LTXTQUERY'
|
||||
|
||||
|
||||
PGTypeCompiler.visit_LTREE = visit_LTREE
|
||||
PGTypeCompiler.visit_LQUERY = visit_LQUERY
|
||||
PGTypeCompiler.visit_LTXTQUERY = visit_LTXTQUERY
|
173
tests/primitives/test_ltree.py
Normal file
173
tests/primitives/test_ltree.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import pytest
|
||||
import six
|
||||
|
||||
from sqlalchemy_utils import Ltree
|
||||
|
||||
|
||||
class TestLtree(object):
|
||||
def test_init(self):
|
||||
assert Ltree('path.path') == Ltree(Ltree('path.path'))
|
||||
|
||||
def test_constructor_with_wrong_type(self):
|
||||
with pytest.raises(TypeError) as e:
|
||||
Ltree(None)
|
||||
assert str(e.value) == (
|
||||
"Ltree() argument must be a string or an Ltree, not 'NoneType'"
|
||||
)
|
||||
|
||||
def test_constructor_with_invalid_code(self):
|
||||
with pytest.raises(ValueError) as e:
|
||||
Ltree('..')
|
||||
assert str(e.value) == "'..' is not a valid ltree path."
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'code',
|
||||
(
|
||||
'path',
|
||||
'path.path',
|
||||
'1_.2',
|
||||
'_._',
|
||||
)
|
||||
)
|
||||
def test_validate_with_valid_codes(self, code):
|
||||
Ltree.validate(code)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'path',
|
||||
(
|
||||
'',
|
||||
'.',
|
||||
'path.',
|
||||
'path..path',
|
||||
'path.path..path',
|
||||
'path.path..',
|
||||
'path.äö',
|
||||
)
|
||||
)
|
||||
def test_validate_with_invalid_path(self, path):
|
||||
with pytest.raises(ValueError) as e:
|
||||
Ltree.validate(path)
|
||||
assert str(e.value) == (
|
||||
"'{0}' is not a valid ltree path.".format(path)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'length'),
|
||||
(
|
||||
('path', 1),
|
||||
('1.1', 2),
|
||||
('1.2.3', 3)
|
||||
)
|
||||
)
|
||||
def test_length(self, path, length):
|
||||
assert len(Ltree(path)) == length
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'subpath', 'index'),
|
||||
(
|
||||
('path.path', 'path', 0),
|
||||
('1.2.3', '2.3', 1),
|
||||
('1.2.3.4', '2.3', 1),
|
||||
('1.2.3.4', '3.4', 2)
|
||||
)
|
||||
)
|
||||
def test_index(self, path, subpath, index):
|
||||
assert Ltree(path).index(subpath) == index
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'item_slice', 'result'),
|
||||
(
|
||||
('path.path', 0, 'path'),
|
||||
('1.1.2.3', slice(1, 3), '1.2'),
|
||||
('1.1.2.3', slice(1, None), '1.2.3'),
|
||||
)
|
||||
)
|
||||
def test_getitem(self, path, item_slice, result):
|
||||
assert Ltree(path)[item_slice] == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'others', 'result'),
|
||||
(
|
||||
('1.2.3', ['1.2.3', '1.2'], '1'),
|
||||
('1.2.3.4.5', ['1.2.3', '1.2.3.4'], '1.2'),
|
||||
('1.2.3.4.5', ['3.4', '1.2.3.4'], None),
|
||||
)
|
||||
)
|
||||
def test_lca(self, path, others, result):
|
||||
assert Ltree(path).lca(*others) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'other', 'result'),
|
||||
(
|
||||
('1.2.3', '4.5', '1.2.3.4.5'),
|
||||
('1', '1', '1.1'),
|
||||
)
|
||||
)
|
||||
def test_add(self, path, other, result):
|
||||
assert Ltree(path) + other == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'other', 'result'),
|
||||
(
|
||||
('1.2.3', '4.5', '4.5.1.2.3'),
|
||||
('1', '1', '1.1'),
|
||||
)
|
||||
)
|
||||
def test_radd(self, path, other, result):
|
||||
assert other + Ltree(path) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'other', 'result'),
|
||||
(
|
||||
('1.2.3', '4.5', '1.2.3.4.5'),
|
||||
('1', '1', '1.1'),
|
||||
)
|
||||
)
|
||||
def test_iadd(self, path, other, result):
|
||||
ltree = Ltree(path)
|
||||
ltree += other
|
||||
assert ltree == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('path', 'other', 'result'),
|
||||
(
|
||||
('1.2.3', '2', True),
|
||||
('1.2.3', '3', True),
|
||||
('1', '1', True),
|
||||
('1', '2', False),
|
||||
)
|
||||
)
|
||||
def test_contains(self, path, other, result):
|
||||
assert (other in Ltree(path)) == result
|
||||
|
||||
def test_getitem_with_other_than_slice_or_in(self):
|
||||
with pytest.raises(TypeError):
|
||||
Ltree('1.2')['something']
|
||||
|
||||
def test_index_raises_value_error_if_subpath_not_found(self):
|
||||
with pytest.raises(ValueError):
|
||||
Ltree('1.2').index('3')
|
||||
|
||||
def test_equality_operator(self):
|
||||
assert Ltree('path.path') == 'path.path'
|
||||
assert 'path.path' == Ltree('path.path')
|
||||
assert Ltree('path.path') == Ltree('path.path')
|
||||
|
||||
def test_non_equality_operator(self):
|
||||
assert Ltree('path.path') != u'path.'
|
||||
assert not (Ltree('path.path') != 'path.path')
|
||||
|
||||
def test_hash(self):
|
||||
return hash(Ltree('path')) == hash('path')
|
||||
|
||||
def test_repr(self):
|
||||
return repr(Ltree('path')) == "Ltree('path')"
|
||||
|
||||
def test_unicode(self):
|
||||
ltree = Ltree('path.path')
|
||||
assert six.text_type(ltree) == 'path.path'
|
||||
|
||||
def test_str(self):
|
||||
ltree = Ltree('path')
|
||||
assert str(ltree) == 'path'
|
@@ -11,7 +11,6 @@ def set_get_locale():
|
||||
|
||||
@pytest.fixture
|
||||
def User(Base):
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
41
tests/types/test_ltree.py
Normal file
41
tests/types/test_ltree.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import Ltree, LtreeType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Section(Base):
|
||||
class Section(Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
path = sa.Column(LtreeType)
|
||||
|
||||
return Section
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_models(Section, connection):
|
||||
connection.execute('CREATE EXTENSION IF NOT EXISTS ltree')
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('postgresql_dsn')
|
||||
class TestLTREE(object):
|
||||
def test_saves_path(self, session, Section):
|
||||
section = Section(path='1.1.2')
|
||||
|
||||
session.add(section)
|
||||
session.commit()
|
||||
|
||||
user = session.query(Section).first()
|
||||
assert user.path == '1.1.2'
|
||||
|
||||
def test_scalar_attributes_get_coerced_to_objects(self, Section):
|
||||
section = Section(path='path.path')
|
||||
assert isinstance(section.path, Ltree)
|
||||
|
||||
def test_literal_param(self, session, Section):
|
||||
clause = Section.path == 'path'
|
||||
compiled = str(clause.compile(compile_kwargs={'literal_binds': True}))
|
||||
assert compiled == 'section.path = \'path\''
|
Reference in New Issue
Block a user