Add LtreeType for PostgreSQL

* Add Ltree primitive data type
This commit is contained in:
Konsta Vesterinen
2016-04-25 13:43:15 +03:00
parent c3a931b549
commit 384a6bbd9b
10 changed files with 526 additions and 3 deletions

View File

@@ -4,6 +4,13 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.32.4 (2016-04-20)
^^^^^^^^^^^^^^^^^^^
- Added LtreeType for PostgreSQL ltree extension
- Added Ltree primitive data type
0.32.3 (2016-04-20)
^^^^^^^^^^^^^^^^^^^

View File

@@ -87,6 +87,19 @@ LocaleType
.. autoclass:: LocaleType
LtreeType
---------
.. module:: sqlalchemy_utils.types.ltree
.. autoclass:: LtreeType
.. module:: sqlalchemy_utils.primitives.ltree
.. autoclass:: Ltree
IPAddressType
-------------

View File

@@ -55,7 +55,7 @@ from .listeners import ( # noqa
)
from .models import Timestamp # noqa
from .observer import observes # noqa
from .primitives import Country, Currency, WeekDay, WeekDays # noqa
from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa
from .proxy_dict import proxy_dict, ProxyDict # noqa
from .query_chain import QueryChain # noqa
from .types import ( # noqa
@@ -77,6 +77,7 @@ from .types import ( # noqa
IPAddressType,
JSONType,
LocaleType,
LtreeType,
NumericRangeType,
Password,
PasswordType,
@@ -93,4 +94,4 @@ from .types import ( # noqa
WeekDaysType
)
__version__ = '0.32.3'
__version__ = '0.32.4'

View File

@@ -1,4 +1,5 @@
from .country import Country # noqa
from .currency import Currency # noqa
from .ltree import Ltree # noqa
from .weekday import WeekDay # noqa
from .weekdays import WeekDays # noqa

View File

@@ -0,0 +1,180 @@
from __future__ import absolute_import
import re
import six
from ..utils import str_coercible
path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')
@str_coercible
class Ltree(object):
"""
Ltree class wraps a valid string label path. It provides various
convenience properties and methods.
::
from sqlalchemy_utils import Ltree
Ltree('1.2.3').path # '1.2.3'
Ltree always validates the given code.
::
Ltree(None) # raises TypeError
Ltree('..') # raises ValueError
Ltree supports equality operators.
::
Ltree('Countries.Finland') == Ltree('Countries.Finland')
Ltree('Countries.Germany') != Ltree('Countries.Finland')
Ltree objects are hashable.
::
assert hash(Ltree('Finland')) == hash('Finland')
Ltree objects have length.
::
assert len(Ltree('1.2')) 2
assert len(Ltree('some.one.some.where')) # 4
You can easily find subpath indexes.
::
assert Ltree('1.2.3').index('2.3') == 1
assert Ltree('1.2.3.4.5').index('3.4') == 2
Ltree objects can be sliced.
::
assert Ltree('1.2.3')[0:2] == Ltree('1.2')
assert Ltree('1.2.3')[1:] == Ltree('2.3')
Finding longest common ancestor.
::
assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'
Ltree objects can be concatenated.1
::
assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
"""
def __init__(self, path_or_ltree):
if isinstance(path_or_ltree, Ltree):
self.path = path_or_ltree.path
elif isinstance(path_or_ltree, six.string_types):
self.validate(path_or_ltree)
self.path = path_or_ltree
else:
raise TypeError(
"Ltree() argument must be a string or an Ltree, not '{0}'"
.format(
type(path_or_ltree).__name__
)
)
@classmethod
def validate(cls, path):
if path_matcher.match(path) is None:
raise ValueError(
"'{0}' is not a valid ltree path.".format(path)
)
def __len__(self):
return len(self.path.split('.'))
def index(self, other):
subpath = Ltree(other).path.split('.')
parts = self.path.split('.')
for index, _ in enumerate(parts):
if parts[index:len(subpath) + index] == subpath:
return index
raise ValueError('subpath not found')
def __getitem__(self, key):
if isinstance(key, int):
return Ltree(self.path.split('.')[key])
elif isinstance(key, slice):
return Ltree('.'.join(self.path.split('.')[key]))
raise TypeError(
'Ltree indices must be integers, not {0}'.format(
key.__class__.__name__
)
)
def lca(self, *others):
"""
Lowest common ancestor, i.e., longest common prefix of paths
::
assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
"""
other_parts = [Ltree(other).path.split('.') for other in others]
parts = self.path.split('.')
for index, element in enumerate(parts):
if any((
other[index] != element or len(other) <= index + 1
for other in other_parts
)):
if index == 0:
return None
return Ltree('.'.join(parts[0:index]))
def __add__(self, other):
return Ltree(self.path + '.' + Ltree(other).path)
def __radd__(self, other):
return Ltree(other) + self
def __eq__(self, other):
if isinstance(other, Ltree):
return self.path == other.path
elif isinstance(other, six.string_types):
return self.path == other
else:
return NotImplemented
def __hash__(self):
return hash(self.path)
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self.path)
def __unicode__(self):
return self.path
def __contains__(self, label):
return label in self.path.split('.')

View File

@@ -12,6 +12,7 @@ from .encrypted import EncryptedType # noqa
from .ip_address import IPAddressType # noqa
from .json import JSONType # noqa
from .locale import LocaleType # noqa
from .ltree import LtreeType # noqa
from .password import Password, PasswordType # noqa
from .pg_composite import ( # noqa
CompositeArray,

View File

@@ -0,0 +1,107 @@
from __future__ import absolute_import
from sqlalchemy import types
from sqlalchemy.dialects.postgresql.base import (
ARRAY,
ischema_names,
PGTypeCompiler
)
from sqlalchemy.sql import expression
from ..primitives import Ltree
from .scalar_coercible import ScalarCoercible
class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible):
"""Postgresql LtreeType type.
The LtreeType datatype can be used for representing labels of data stored
in hierarchial tree-like structure. For more detailed information please
refer to http://www.postgresql.org/docs/current/static/ltree.html
.. note::
Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types
may require installation of Postgresql ltree extension on the server
side. Please visit http://www.postgres.org for details.
"""
class comparator_factory(types.Concatenable.Comparator):
def ancestor_of(self, other):
if isinstance(other, list):
return self.op('@>')(expression.cast(other, ARRAY(LtreeType)))
else:
return self.op('@>')(other)
def descendant_of(self, other):
if isinstance(other, list):
return self.op('<@')(expression.cast(other, ARRAY(LtreeType)))
else:
return self.op('<@')(other)
def lquery(self, other):
if isinstance(other, list):
return self.op('?')(expression.cast(other, ARRAY(LQUERY)))
else:
return self.op('~')(other)
def ltxtquery(self, other):
return self.op('@')(other)
def bind_processor(self, dialect):
def process(value):
if value:
return value.path
return process
def result_processor(self, dialect, coltype):
def process(value):
return self._coerce(value)
return process
def literal_processor(self, dialect):
def process(value):
value = value.replace("'", "''")
return "'%s'" % value
return process
__visit_name__ = 'LTREE'
def _coerce(self, value):
if value:
return Ltree(value)
class LQUERY(types.TypeEngine):
"""Postresql LQUERY type.
See :class:`LTREE` for details.
"""
__visit_name__ = 'LQUERY'
class LTXTQUERY(types.TypeEngine):
"""Postresql LTXTQUERY type.
See :class:`LTREE` for details.
"""
__visit_name__ = 'LTXTQUERY'
ischema_names['ltree'] = LtreeType
ischema_names['lquery'] = LQUERY
ischema_names['ltxtquery'] = LTXTQUERY
def visit_LTREE(self, type_, **kw):
return 'LTREE'
def visit_LQUERY(self, type_, **kw):
return 'LQUERY'
def visit_LTXTQUERY(self, type_, **kw):
return 'LTXTQUERY'
PGTypeCompiler.visit_LTREE = visit_LTREE
PGTypeCompiler.visit_LQUERY = visit_LQUERY
PGTypeCompiler.visit_LTXTQUERY = visit_LTXTQUERY

View File

@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
import pytest
import six
from sqlalchemy_utils import Ltree
class TestLtree(object):
def test_init(self):
assert Ltree('path.path') == Ltree(Ltree('path.path'))
def test_constructor_with_wrong_type(self):
with pytest.raises(TypeError) as e:
Ltree(None)
assert str(e.value) == (
"Ltree() argument must be a string or an Ltree, not 'NoneType'"
)
def test_constructor_with_invalid_code(self):
with pytest.raises(ValueError) as e:
Ltree('..')
assert str(e.value) == "'..' is not a valid ltree path."
@pytest.mark.parametrize(
'code',
(
'path',
'path.path',
'1_.2',
'_._',
)
)
def test_validate_with_valid_codes(self, code):
Ltree.validate(code)
@pytest.mark.parametrize(
'path',
(
'',
'.',
'path.',
'path..path',
'path.path..path',
'path.path..',
'path.äö',
)
)
def test_validate_with_invalid_path(self, path):
with pytest.raises(ValueError) as e:
Ltree.validate(path)
assert str(e.value) == (
"'{0}' is not a valid ltree path.".format(path)
)
@pytest.mark.parametrize(
('path', 'length'),
(
('path', 1),
('1.1', 2),
('1.2.3', 3)
)
)
def test_length(self, path, length):
assert len(Ltree(path)) == length
@pytest.mark.parametrize(
('path', 'subpath', 'index'),
(
('path.path', 'path', 0),
('1.2.3', '2.3', 1),
('1.2.3.4', '2.3', 1),
('1.2.3.4', '3.4', 2)
)
)
def test_index(self, path, subpath, index):
assert Ltree(path).index(subpath) == index
@pytest.mark.parametrize(
('path', 'item_slice', 'result'),
(
('path.path', 0, 'path'),
('1.1.2.3', slice(1, 3), '1.2'),
('1.1.2.3', slice(1, None), '1.2.3'),
)
)
def test_getitem(self, path, item_slice, result):
assert Ltree(path)[item_slice] == result
@pytest.mark.parametrize(
('path', 'others', 'result'),
(
('1.2.3', ['1.2.3', '1.2'], '1'),
('1.2.3.4.5', ['1.2.3', '1.2.3.4'], '1.2'),
('1.2.3.4.5', ['3.4', '1.2.3.4'], None),
)
)
def test_lca(self, path, others, result):
assert Ltree(path).lca(*others) == result
@pytest.mark.parametrize(
('path', 'other', 'result'),
(
('1.2.3', '4.5', '1.2.3.4.5'),
('1', '1', '1.1'),
)
)
def test_add(self, path, other, result):
assert Ltree(path) + other == result
@pytest.mark.parametrize(
('path', 'other', 'result'),
(
('1.2.3', '4.5', '4.5.1.2.3'),
('1', '1', '1.1'),
)
)
def test_radd(self, path, other, result):
assert other + Ltree(path) == result
@pytest.mark.parametrize(
('path', 'other', 'result'),
(
('1.2.3', '4.5', '1.2.3.4.5'),
('1', '1', '1.1'),
)
)
def test_iadd(self, path, other, result):
ltree = Ltree(path)
ltree += other
assert ltree == result
@pytest.mark.parametrize(
('path', 'other', 'result'),
(
('1.2.3', '2', True),
('1.2.3', '3', True),
('1', '1', True),
('1', '2', False),
)
)
def test_contains(self, path, other, result):
assert (other in Ltree(path)) == result
def test_getitem_with_other_than_slice_or_in(self):
with pytest.raises(TypeError):
Ltree('1.2')['something']
def test_index_raises_value_error_if_subpath_not_found(self):
with pytest.raises(ValueError):
Ltree('1.2').index('3')
def test_equality_operator(self):
assert Ltree('path.path') == 'path.path'
assert 'path.path' == Ltree('path.path')
assert Ltree('path.path') == Ltree('path.path')
def test_non_equality_operator(self):
assert Ltree('path.path') != u'path.'
assert not (Ltree('path.path') != 'path.path')
def test_hash(self):
return hash(Ltree('path')) == hash('path')
def test_repr(self):
return repr(Ltree('path')) == "Ltree('path')"
def test_unicode(self):
ltree = Ltree('path.path')
assert six.text_type(ltree) == 'path.path'
def test_str(self):
ltree = Ltree('path')
assert str(ltree) == 'path'

View File

@@ -11,7 +11,6 @@ def set_get_locale():
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)

41
tests/types/test_ltree.py Normal file
View File

@@ -0,0 +1,41 @@
import pytest
import sqlalchemy as sa
from sqlalchemy_utils import Ltree, LtreeType
@pytest.fixture
def Section(Base):
class Section(Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
path = sa.Column(LtreeType)
return Section
@pytest.fixture
def init_models(Section, connection):
connection.execute('CREATE EXTENSION IF NOT EXISTS ltree')
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestLTREE(object):
def test_saves_path(self, session, Section):
section = Section(path='1.1.2')
session.add(section)
session.commit()
user = session.query(Section).first()
assert user.path == '1.1.2'
def test_scalar_attributes_get_coerced_to_objects(self, Section):
section = Section(path='path.path')
assert isinstance(section.path, Ltree)
def test_literal_param(self, session, Section):
clause = Section.path == 'path'
compiled = str(clause.compile(compile_kwargs={'literal_binds': True}))
assert compiled == 'section.path = \'path\''