Added new expressions tsvector_match and tsvector_concat
This commit is contained in:
39
sqlalchemy_utils/expressions.py
Normal file
39
sqlalchemy_utils/expressions.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy_utils.types import TSVectorType
|
||||
|
||||
|
||||
class tsvector_match(expression.FunctionElement):
|
||||
type = sa.types.Unicode()
|
||||
name = 'tsvector_match'
|
||||
|
||||
|
||||
@compiles(tsvector_match)
|
||||
def compile_tsvector_match(element, compiler, **kw):
|
||||
args = list(element.clauses)
|
||||
if len(args) < 2:
|
||||
raise Exception(
|
||||
"Function 'match_tsvector' expects atleast two arguments."
|
||||
)
|
||||
if len(args) == 2:
|
||||
return '(%s) @@ to_tsquery(%s)' % (
|
||||
compiler.process(args[0]),
|
||||
compiler.process(args[1])
|
||||
)
|
||||
elif len(args) == 3:
|
||||
return '(%s) @@ to_tsquery(%s, %s)' % (
|
||||
compiler.process(args[0]),
|
||||
compiler.process(args[2]),
|
||||
compiler.process(args[1])
|
||||
)
|
||||
|
||||
|
||||
class tsvector_concat(expression.FunctionElement):
|
||||
type = TSVectorType()
|
||||
name = 'tsvector_concat'
|
||||
|
||||
|
||||
@compiles(tsvector_concat)
|
||||
def compile_tsvector_concat(element, compiler, **kw):
|
||||
return ' || '.join(map(compiler.process, element.clauses))
|
@@ -1,7 +1,5 @@
|
||||
from functools import wraps
|
||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql.base import ischema_names
|
||||
from .arrow import ArrowType
|
||||
from .color import ColorType
|
||||
from .email import EmailType
|
||||
@@ -16,6 +14,7 @@ from .password import Password, PasswordType
|
||||
from .phone_number import PhoneNumber, PhoneNumberType
|
||||
from .scalar_list import ScalarListException, ScalarListType
|
||||
from .timezone import TimezoneType
|
||||
from .ts_vector import TSVectorType
|
||||
from .uuid import UUIDType
|
||||
|
||||
|
||||
@@ -35,21 +34,11 @@ __all__ = (
|
||||
ScalarListException,
|
||||
ScalarListType,
|
||||
TimezoneType,
|
||||
TSVectorType,
|
||||
UUIDType,
|
||||
)
|
||||
|
||||
|
||||
class TSVectorType(types.UserDefinedType):
|
||||
"""
|
||||
Text search vector type for postgresql.
|
||||
"""
|
||||
def get_col_spec(self):
|
||||
return 'tsvector'
|
||||
|
||||
|
||||
ischema_names['tsvector'] = TSVectorType
|
||||
|
||||
|
||||
class InstrumentedList(_InstrumentedList):
|
||||
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
||||
additional functionality."""
|
||||
|
13
sqlalchemy_utils/types/ts_vector.py
Normal file
13
sqlalchemy_utils/types/ts_vector.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql.base import ischema_names
|
||||
|
||||
|
||||
class TSVectorType(sa.types.UserDefinedType):
|
||||
"""
|
||||
Text search vector type for postgresql.
|
||||
"""
|
||||
def get_col_spec(self):
|
||||
return 'tsvector'
|
||||
|
||||
|
||||
ischema_names['tsvector'] = TSVectorType
|
61
tests/test_expressions.py
Normal file
61
tests/test_expressions.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from pytest import raises
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.types import TSVectorType
|
||||
from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TSVectorTestCase(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
content = sa.Column(sa.UnicodeText)
|
||||
search_vector = sa.Column(TSVectorType)
|
||||
search_vector2 = sa.Column(TSVectorType)
|
||||
|
||||
self.Article = Article
|
||||
|
||||
|
||||
class TestMatchTSVector(TSVectorTestCase):
|
||||
def test_raises_exception_if_less_than_2_parameters_given(self):
|
||||
with raises(Exception):
|
||||
str(
|
||||
tsvector_match(
|
||||
self.Article.search_vector,
|
||||
)
|
||||
)
|
||||
|
||||
def test_supports_postgres(self):
|
||||
assert str(tsvector_match(
|
||||
self.Article.search_vector,
|
||||
'something',
|
||||
)) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)'
|
||||
|
||||
def test_supports_collation_as_3rd_parameter(self):
|
||||
assert str(tsvector_match(
|
||||
self.Article.search_vector,
|
||||
'something',
|
||||
'finnish'
|
||||
)) == (
|
||||
'(article.search_vector) @@ '
|
||||
'to_tsquery(:tsvector_match_1, :tsvector_match_2)'
|
||||
)
|
||||
|
||||
|
||||
class TestConcatTSVector(TSVectorTestCase):
|
||||
def test_concatenate_search_vectors(self):
|
||||
assert str(tsvector_match(
|
||||
tsvector_concat(
|
||||
self.Article.search_vector,
|
||||
self.Article.search_vector2
|
||||
),
|
||||
'something',
|
||||
'finnish'
|
||||
)) == (
|
||||
'(article.search_vector || article.search_vector2) '
|
||||
'@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)'
|
||||
)
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils import escape_like, defer_except
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils.functions import (
|
||||
non_indexed_foreign_keys,
|
||||
render_statement
|
||||
render_statement,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user