From 384a6bbd9b4df4b617787af13f763db212c17cf0 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 25 Apr 2016 13:43:15 +0300 Subject: [PATCH] Add LtreeType for PostgreSQL * Add Ltree primitive data type --- CHANGES.rst | 7 + docs/data_types.rst | 13 ++ sqlalchemy_utils/__init__.py | 5 +- sqlalchemy_utils/primitives/__init__.py | 1 + sqlalchemy_utils/primitives/ltree.py | 180 ++++++++++++++++++++++++ sqlalchemy_utils/types/__init__.py | 1 + sqlalchemy_utils/types/ltree.py | 107 ++++++++++++++ tests/primitives/test_ltree.py | 173 +++++++++++++++++++++++ tests/types/test_currency.py | 1 - tests/types/test_ltree.py | 41 ++++++ 10 files changed, 526 insertions(+), 3 deletions(-) create mode 100644 sqlalchemy_utils/primitives/ltree.py create mode 100644 sqlalchemy_utils/types/ltree.py create mode 100644 tests/primitives/test_ltree.py create mode 100644 tests/types/test_ltree.py diff --git a/CHANGES.rst b/CHANGES.rst index 2ce07d9..e3a5d88 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,13 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.32.4 (2016-04-20) +^^^^^^^^^^^^^^^^^^^ + +- Added LtreeType for PostgreSQL ltree extension +- Added Ltree primitive data type + + 0.32.3 (2016-04-20) ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/data_types.rst b/docs/data_types.rst index 887f1d2..eac2b4f 100644 --- a/docs/data_types.rst +++ b/docs/data_types.rst @@ -87,6 +87,19 @@ LocaleType .. autoclass:: LocaleType +LtreeType +--------- + + +.. module:: sqlalchemy_utils.types.ltree + +.. autoclass:: LtreeType + +.. module:: sqlalchemy_utils.primitives.ltree + +.. 
autoclass:: Ltree + + IPAddressType ------------- diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 126bd4c..7ee2a6b 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -55,7 +55,7 @@ from .listeners import ( # noqa ) from .models import Timestamp # noqa from .observer import observes # noqa -from .primitives import Country, Currency, WeekDay, WeekDays # noqa +from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa from .proxy_dict import proxy_dict, ProxyDict # noqa from .query_chain import QueryChain # noqa from .types import ( # noqa @@ -77,6 +77,7 @@ from .types import ( # noqa IPAddressType, JSONType, LocaleType, + LtreeType, NumericRangeType, Password, PasswordType, @@ -93,4 +94,4 @@ from .types import ( # noqa WeekDaysType ) -__version__ = '0.32.3' +__version__ = '0.32.4' diff --git a/sqlalchemy_utils/primitives/__init__.py b/sqlalchemy_utils/primitives/__init__.py index 71a5829..76d768c 100644 --- a/sqlalchemy_utils/primitives/__init__.py +++ b/sqlalchemy_utils/primitives/__init__.py @@ -1,4 +1,5 @@ from .country import Country # noqa from .currency import Currency # noqa +from .ltree import Ltree # noqa from .weekday import WeekDay # noqa from .weekdays import WeekDays # noqa diff --git a/sqlalchemy_utils/primitives/ltree.py b/sqlalchemy_utils/primitives/ltree.py new file mode 100644 index 0000000..5ca1448 --- /dev/null +++ b/sqlalchemy_utils/primitives/ltree.py @@ -0,0 +1,180 @@ +from __future__ import absolute_import + +import re + +import six + +from ..utils import str_coercible + +path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$') + + +@str_coercible +class Ltree(object): + """ + Ltree class wraps a valid string label path. It provides various + convenience properties and methods. + + :: + + from sqlalchemy_utils import Ltree + + Ltree('1.2.3').path # '1.2.3' + + + Ltree always validates the given path. 
+ + :: + + Ltree(None) # raises TypeError + + Ltree('..') # raises ValueError + + + Ltree supports equality operators. + + :: + + Ltree('Countries.Finland') == Ltree('Countries.Finland') + Ltree('Countries.Germany') != Ltree('Countries.Finland') + + + Ltree objects are hashable. + + + :: + + assert hash(Ltree('Finland')) == hash('Finland') + + + Ltree objects have length. + + :: + + assert len(Ltree('1.2')) == 2 + assert len(Ltree('some.one.some.where')) == 4 + + + You can easily find subpath indexes. + + :: + + assert Ltree('1.2.3').index('2.3') == 1 + assert Ltree('1.2.3.4.5').index('3.4') == 2 + + + Ltree objects can be sliced. + + + :: + + assert Ltree('1.2.3')[0:2] == Ltree('1.2') + assert Ltree('1.2.3')[1:] == Ltree('2.3') + + + Finding longest common ancestor. + + + :: + + assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2' + assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1' + + + Ltree objects can be concatenated. + + :: + + assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2') + """ + + def __init__(self, path_or_ltree): + if isinstance(path_or_ltree, Ltree): + self.path = path_or_ltree.path + elif isinstance(path_or_ltree, six.string_types): + self.validate(path_or_ltree) + self.path = path_or_ltree + else: + raise TypeError( + "Ltree() argument must be a string or an Ltree, not '{0}'" + .format( + type(path_or_ltree).__name__ + ) + ) + + @classmethod + def validate(cls, path): + if path_matcher.match(path) is None: + raise ValueError( + "'{0}' is not a valid ltree path.".format(path) + ) + + def __len__(self): + return len(self.path.split('.')) + + def index(self, other): + subpath = Ltree(other).path.split('.') + parts = self.path.split('.') + for index, _ in enumerate(parts): + if parts[index:len(subpath) + index] == subpath: + return index + raise ValueError('subpath not found') + + def __getitem__(self, key): + if isinstance(key, int): + return Ltree(self.path.split('.')[key]) + elif isinstance(key, slice): + return 
Ltree('.'.join(self.path.split('.')[key])) + raise TypeError( + 'Ltree indices must be integers, not {0}'.format( + key.__class__.__name__ + ) + ) + + def lca(self, *others): + """ + Lowest common ancestor, i.e., longest common prefix of paths + + :: + + assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2' + """ + other_parts = [Ltree(other).path.split('.') for other in others] + parts = self.path.split('.') + for index, element in enumerate(parts): + if any(( + other[index] != element or len(other) <= index + 1 + for other in other_parts + )): + if index == 0: + return None + return Ltree('.'.join(parts[0:index])) + + def __add__(self, other): + return Ltree(self.path + '.' + Ltree(other).path) + + def __radd__(self, other): + return Ltree(other) + self + + def __eq__(self, other): + if isinstance(other, Ltree): + return self.path == other.path + elif isinstance(other, six.string_types): + return self.path == other + else: + return NotImplemented + + def __hash__(self): + return hash(self.path) + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.path) + + def __unicode__(self): + return self.path + + def __contains__(self, label): + return label in self.path.split('.') diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 272b001..71dc84f 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -12,6 +12,7 @@ from .encrypted import EncryptedType # noqa from .ip_address import IPAddressType # noqa from .json import JSONType # noqa from .locale import LocaleType # noqa +from .ltree import LtreeType # noqa from .password import Password, PasswordType # noqa from .pg_composite import ( # noqa CompositeArray, diff --git a/sqlalchemy_utils/types/ltree.py b/sqlalchemy_utils/types/ltree.py new file mode 100644 index 0000000..95ebfcb --- /dev/null +++ b/sqlalchemy_utils/types/ltree.py @@ -0,0 +1,107 @@ +from 
__future__ import absolute_import + +from sqlalchemy import types +from sqlalchemy.dialects.postgresql.base import ( + ARRAY, + ischema_names, + PGTypeCompiler +) +from sqlalchemy.sql import expression + +from ..primitives import Ltree +from .scalar_coercible import ScalarCoercible + + +class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible): + """Postgresql LtreeType type. + + The LtreeType datatype can be used for representing labels of data stored + in hierarchical tree-like structure. For more detailed information please + refer to http://www.postgresql.org/docs/current/static/ltree.html + + .. note:: + Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types + may require installation of Postgresql ltree extension on the server + side. Please visit http://www.postgresql.org for details. + """ + + class comparator_factory(types.Concatenable.Comparator): + def ancestor_of(self, other): + if isinstance(other, list): + return self.op('@>')(expression.cast(other, ARRAY(LtreeType))) + else: + return self.op('@>')(other) + + def descendant_of(self, other): + if isinstance(other, list): + return self.op('<@')(expression.cast(other, ARRAY(LtreeType))) + else: + return self.op('<@')(other) + + def lquery(self, other): + if isinstance(other, list): + return self.op('?')(expression.cast(other, ARRAY(LQUERY))) + else: + return self.op('~')(other) + + def ltxtquery(self, other): + return self.op('@')(other) + + def bind_processor(self, dialect): + def process(value): + if value: + return value.path + return process + + def result_processor(self, dialect, coltype): + def process(value): + return self._coerce(value) + return process + + def literal_processor(self, dialect): + def process(value): + value = value.replace("'", "''") + return "'%s'" % value + return process + + __visit_name__ = 'LTREE' + + def _coerce(self, value): + if value: + return Ltree(value) + + +class LQUERY(types.TypeEngine): + """Postgresql LQUERY type. 
+ See :class:`LtreeType` for details. + """ + __visit_name__ = 'LQUERY' + + +class LTXTQUERY(types.TypeEngine): + """Postgresql LTXTQUERY type. + See :class:`LtreeType` for details. + """ + __visit_name__ = 'LTXTQUERY' + + +ischema_names['ltree'] = LtreeType +ischema_names['lquery'] = LQUERY +ischema_names['ltxtquery'] = LTXTQUERY + + +def visit_LTREE(self, type_, **kw): + return 'LTREE' + + +def visit_LQUERY(self, type_, **kw): + return 'LQUERY' + + +def visit_LTXTQUERY(self, type_, **kw): + return 'LTXTQUERY' + + +PGTypeCompiler.visit_LTREE = visit_LTREE +PGTypeCompiler.visit_LQUERY = visit_LQUERY +PGTypeCompiler.visit_LTXTQUERY = visit_LTXTQUERY diff --git a/tests/primitives/test_ltree.py b/tests/primitives/test_ltree.py new file mode 100644 index 0000000..6f498b3 --- /dev/null +++ b/tests/primitives/test_ltree.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +import pytest +import six + +from sqlalchemy_utils import Ltree + + +class TestLtree(object): + def test_init(self): + assert Ltree('path.path') == Ltree(Ltree('path.path')) + + def test_constructor_with_wrong_type(self): + with pytest.raises(TypeError) as e: + Ltree(None) + assert str(e.value) == ( + "Ltree() argument must be a string or an Ltree, not 'NoneType'" + ) + + def test_constructor_with_invalid_code(self): + with pytest.raises(ValueError) as e: + Ltree('..') + assert str(e.value) == "'..' is not a valid ltree path." 
+ + @pytest.mark.parametrize( + 'code', + ( + 'path', + 'path.path', + '1_.2', + '_._', + ) + ) + def test_validate_with_valid_codes(self, code): + Ltree.validate(code) + + @pytest.mark.parametrize( + 'path', + ( + '', + '.', + 'path.', + 'path..path', + 'path.path..path', + 'path.path..', + 'path.äö', + ) + ) + def test_validate_with_invalid_path(self, path): + with pytest.raises(ValueError) as e: + Ltree.validate(path) + assert str(e.value) == ( + "'{0}' is not a valid ltree path.".format(path) + ) + + @pytest.mark.parametrize( + ('path', 'length'), + ( + ('path', 1), + ('1.1', 2), + ('1.2.3', 3) + ) + ) + def test_length(self, path, length): + assert len(Ltree(path)) == length + + @pytest.mark.parametrize( + ('path', 'subpath', 'index'), + ( + ('path.path', 'path', 0), + ('1.2.3', '2.3', 1), + ('1.2.3.4', '2.3', 1), + ('1.2.3.4', '3.4', 2) + ) + ) + def test_index(self, path, subpath, index): + assert Ltree(path).index(subpath) == index + + @pytest.mark.parametrize( + ('path', 'item_slice', 'result'), + ( + ('path.path', 0, 'path'), + ('1.1.2.3', slice(1, 3), '1.2'), + ('1.1.2.3', slice(1, None), '1.2.3'), + ) + ) + def test_getitem(self, path, item_slice, result): + assert Ltree(path)[item_slice] == result + + @pytest.mark.parametrize( + ('path', 'others', 'result'), + ( + ('1.2.3', ['1.2.3', '1.2'], '1'), + ('1.2.3.4.5', ['1.2.3', '1.2.3.4'], '1.2'), + ('1.2.3.4.5', ['3.4', '1.2.3.4'], None), + ) + ) + def test_lca(self, path, others, result): + assert Ltree(path).lca(*others) == result + + @pytest.mark.parametrize( + ('path', 'other', 'result'), + ( + ('1.2.3', '4.5', '1.2.3.4.5'), + ('1', '1', '1.1'), + ) + ) + def test_add(self, path, other, result): + assert Ltree(path) + other == result + + @pytest.mark.parametrize( + ('path', 'other', 'result'), + ( + ('1.2.3', '4.5', '4.5.1.2.3'), + ('1', '1', '1.1'), + ) + ) + def test_radd(self, path, other, result): + assert other + Ltree(path) == result + + @pytest.mark.parametrize( + ('path', 'other', 'result'), 
+ ( + ('1.2.3', '4.5', '1.2.3.4.5'), + ('1', '1', '1.1'), + ) + ) + def test_iadd(self, path, other, result): + ltree = Ltree(path) + ltree += other + assert ltree == result + + @pytest.mark.parametrize( + ('path', 'other', 'result'), + ( + ('1.2.3', '2', True), + ('1.2.3', '3', True), + ('1', '1', True), + ('1', '2', False), + ) + ) + def test_contains(self, path, other, result): + assert (other in Ltree(path)) == result + + def test_getitem_with_other_than_slice_or_in(self): + with pytest.raises(TypeError): + Ltree('1.2')['something'] + + def test_index_raises_value_error_if_subpath_not_found(self): + with pytest.raises(ValueError): + Ltree('1.2').index('3') + + def test_equality_operator(self): + assert Ltree('path.path') == 'path.path' + assert 'path.path' == Ltree('path.path') + assert Ltree('path.path') == Ltree('path.path') + + def test_non_equality_operator(self): + assert Ltree('path.path') != u'path.' + assert not (Ltree('path.path') != 'path.path') + + def test_hash(self): + return hash(Ltree('path')) == hash('path') + + def test_repr(self): + return repr(Ltree('path')) == "Ltree('path')" + + def test_unicode(self): + ltree = Ltree('path.path') + assert six.text_type(ltree) == 'path.path' + + def test_str(self): + ltree = Ltree('path') + assert str(ltree) == 'path' diff --git a/tests/types/test_currency.py b/tests/types/test_currency.py index 0e10631..23eebbc 100644 --- a/tests/types/test_currency.py +++ b/tests/types/test_currency.py @@ -11,7 +11,6 @@ def set_get_locale(): @pytest.fixture def User(Base): - class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/types/test_ltree.py b/tests/types/test_ltree.py new file mode 100644 index 0000000..916fa83 --- /dev/null +++ b/tests/types/test_ltree.py @@ -0,0 +1,41 @@ +import pytest +import sqlalchemy as sa + +from sqlalchemy_utils import Ltree, LtreeType + + +@pytest.fixture +def Section(Base): + class Section(Base): + __tablename__ = 'section' + id = 
sa.Column(sa.Integer, primary_key=True) + path = sa.Column(LtreeType) + + return Section + + +@pytest.fixture +def init_models(Section, connection): + connection.execute('CREATE EXTENSION IF NOT EXISTS ltree') + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestLTREE(object): + def test_saves_path(self, session, Section): + section = Section(path='1.1.2') + + session.add(section) + session.commit() + + user = session.query(Section).first() + assert user.path == '1.1.2' + + def test_scalar_attributes_get_coerced_to_objects(self, Section): + section = Section(path='path.path') + assert isinstance(section.path, Ltree) + + def test_literal_param(self, session, Section): + clause = Section.path == 'path' + compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) + assert compiled == 'section.path = \'path\''