From a60c2424c70531eb51ca6c63f7dacf78ed7fac29 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 1 Aug 2013 10:25:37 +0300 Subject: [PATCH] Added new expressions tsvector_match and tsvector_concat --- sqlalchemy_utils/expressions.py | 39 ++++++++++++++++++ sqlalchemy_utils/types/__init__.py | 15 +------ sqlalchemy_utils/types/ts_vector.py | 13 ++++++ tests/test_expressions.py | 61 +++++++++++++++++++++++++++++ tests/test_utility_functions.py | 2 +- 5 files changed, 116 insertions(+), 14 deletions(-) create mode 100644 sqlalchemy_utils/expressions.py create mode 100644 sqlalchemy_utils/types/ts_vector.py create mode 100644 tests/test_expressions.py diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py new file mode 100644 index 0000000..381f7e3 --- /dev/null +++ b/sqlalchemy_utils/expressions.py @@ -0,0 +1,39 @@ +import sqlalchemy as sa +from sqlalchemy.sql import expression +from sqlalchemy.ext.compiler import compiles +from sqlalchemy_utils.types import TSVectorType + + +class tsvector_match(expression.FunctionElement): + type = sa.types.Unicode() + name = 'tsvector_match' + + +@compiles(tsvector_match) +def compile_tsvector_match(element, compiler, **kw): + args = list(element.clauses) + if len(args) < 2: + raise Exception( + "Function 'match_tsvector' expects atleast two arguments." + ) + if len(args) == 2: + return '(%s) @@ to_tsquery(%s)' % ( + compiler.process(args[0]), + compiler.process(args[1]) + ) + elif len(args) == 3: + return '(%s) @@ to_tsquery(%s, %s)' % ( + compiler.process(args[0]), + compiler.process(args[2]), + compiler.process(args[1]) + ) + + +class tsvector_concat(expression.FunctionElement): + type = TSVectorType() + name = 'tsvector_concat' + + +@compiles(tsvector_concat) +def compile_tsvector_concat(element, compiler, **kw): + return ' || '.join(map(compiler.process, element.clauses)) diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index ac78219..6d5f35b 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -1,7 +1,5 @@ from functools import wraps from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList -from sqlalchemy import types -from sqlalchemy.dialects.postgresql.base import ischema_names from .arrow import ArrowType from .color import ColorType from .email import EmailType @@ -16,6 +14,7 @@ from .password import Password, PasswordType from .phone_number import PhoneNumber, PhoneNumberType from .scalar_list import ScalarListException, ScalarListType from .timezone import TimezoneType +from .ts_vector import TSVectorType from .uuid import UUIDType @@ -35,21 +34,11 @@ __all__ = ( ScalarListException, ScalarListType, TimezoneType, + TSVectorType, UUIDType, ) -class TSVectorType(types.UserDefinedType): - """ - Text search vector type for postgresql. - """ - def get_col_spec(self): - return 'tsvector' - - -ischema_names['tsvector'] = TSVectorType - - class InstrumentedList(_InstrumentedList): """Enhanced version of SQLAlchemy InstrumentedList. Provides some additional functionality.""" diff --git a/sqlalchemy_utils/types/ts_vector.py b/sqlalchemy_utils/types/ts_vector.py new file mode 100644 index 0000000..4d09aa8 --- /dev/null +++ b/sqlalchemy_utils/types/ts_vector.py @@ -0,0 +1,13 @@ +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql.base import ischema_names + + +class TSVectorType(sa.types.UserDefinedType): + """ + Text search vector type for postgresql. + """ + def get_col_spec(self): + return 'tsvector' + + +ischema_names['tsvector'] = TSVectorType diff --git a/tests/test_expressions.py b/tests/test_expressions.py new file mode 100644 index 0000000..382b0c4 --- /dev/null +++ b/tests/test_expressions.py @@ -0,0 +1,61 @@ +from pytest import raises +import sqlalchemy as sa +from sqlalchemy_utils.types import TSVectorType +from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat +from tests import TestCase + + +class TSVectorTestCase(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + content = sa.Column(sa.UnicodeText) + search_vector = sa.Column(TSVectorType) + search_vector2 = sa.Column(TSVectorType) + + self.Article = Article + + +class TestMatchTSVector(TSVectorTestCase): + def test_raises_exception_if_less_than_2_parameters_given(self): + with raises(Exception): + str( + tsvector_match( + self.Article.search_vector, + ) + ) + + def test_supports_postgres(self): + assert str(tsvector_match( + self.Article.search_vector, + 'something', + )) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)' + + def test_supports_collation_as_3rd_parameter(self): + assert str(tsvector_match( + self.Article.search_vector, + 'something', + 'finnish' + )) == ( + '(article.search_vector) @@ ' + 'to_tsquery(:tsvector_match_1, :tsvector_match_2)' + ) + + +class TestConcatTSVector(TSVectorTestCase): + def test_concatenate_search_vectors(self): + assert str(tsvector_match( + tsvector_concat( + self.Article.search_vector, + self.Article.search_vector2 + ), + 'something', + 'finnish' + )) == ( + '(article.search_vector || article.search_vector2) ' + '@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)' + ) diff --git a/tests/test_utility_functions.py b/tests/test_utility_functions.py index 6e845c0..1f1c130 100644 --- a/tests/test_utility_functions.py +++ b/tests/test_utility_functions.py @@ -3,7 +3,7 @@ from sqlalchemy_utils import escape_like, defer_except from tests import TestCase from sqlalchemy_utils.functions import ( non_indexed_foreign_keys, - render_statement + render_statement, )