Added function expressions for to_tsquery and plainto_tsquery
This commit is contained in:
@@ -16,17 +16,42 @@ def compile_tsvector_match(element, compiler, **kw):
|
||||
raise Exception(
|
||||
"Function 'tsvector_match' expects atleast two arguments."
|
||||
)
|
||||
if len(args) == 2:
|
||||
return '(%s) @@ to_tsquery(%s)' % (
|
||||
compiler.process(args[0]),
|
||||
compiler.process(args[1])
|
||||
return '(%s) @@ %s' % (
|
||||
compiler.process(args[0]),
|
||||
compiler.process(args[1])
|
||||
)
|
||||
|
||||
|
||||
class to_tsquery(expression.FunctionElement):
|
||||
type = sa.types.Unicode()
|
||||
name = 'to_tsquery'
|
||||
|
||||
|
||||
@compiles(to_tsquery)
|
||||
def compile_to_tsquery(element, compiler, **kw):
|
||||
if len(element.clauses) < 1:
|
||||
raise Exception(
|
||||
"Function 'to_tsquery' expects atleast one argument."
|
||||
)
|
||||
elif len(args) == 3:
|
||||
return '(%s) @@ to_tsquery(%s, %s)' % (
|
||||
compiler.process(args[0]),
|
||||
compiler.process(args[2]),
|
||||
compiler.process(args[1])
|
||||
return 'to_tsquery(%s)' % (
|
||||
', '.join(map(compiler.process, element.clauses))
|
||||
)
|
||||
|
||||
|
||||
class plainto_tsquery(expression.FunctionElement):
|
||||
type = sa.types.Unicode()
|
||||
name = 'plainto_tsquery'
|
||||
|
||||
|
||||
@compiles(plainto_tsquery)
|
||||
def compile_plainto_tsquery(element, compiler, **kw):
|
||||
if len(element.clauses) < 1:
|
||||
raise Exception(
|
||||
"Function 'plainto_tsquery' expects atleast one argument."
|
||||
)
|
||||
return 'plainto_tsquery(%s)' % (
|
||||
', '.join(map(compiler.process, element.clauses))
|
||||
)
|
||||
|
||||
|
||||
class tsvector_concat(expression.FunctionElement):
|
||||
|
@@ -1,7 +1,12 @@
|
||||
from pytest import raises
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.types import TSVectorType
|
||||
from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat
|
||||
from sqlalchemy_utils.expressions import (
|
||||
tsvector_match,
|
||||
tsvector_concat,
|
||||
to_tsquery,
|
||||
plainto_tsquery
|
||||
)
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -32,17 +37,27 @@ class TestMatchTSVector(TSVectorTestCase):
|
||||
def test_supports_postgres(self):
|
||||
assert str(tsvector_match(
|
||||
self.Article.search_vector,
|
||||
'something',
|
||||
)) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)'
|
||||
to_tsquery('something'),
|
||||
)) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)'
|
||||
|
||||
def test_supports_collation_as_3rd_parameter(self):
|
||||
assert str(tsvector_match(
|
||||
self.Article.search_vector,
|
||||
'something',
|
||||
'finnish'
|
||||
)) == (
|
||||
'(article.search_vector) @@ '
|
||||
'to_tsquery(:tsvector_match_1, :tsvector_match_2)'
|
||||
|
||||
class TestToTSQuery(TSVectorTestCase):
|
||||
def test_requires_atleast_one_parameter(self):
|
||||
with raises(Exception):
|
||||
str(to_tsquery())
|
||||
|
||||
def test_supports_postgres(self):
|
||||
assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)'
|
||||
|
||||
|
||||
class TestPlainToTSQuery(TSVectorTestCase):
|
||||
def test_requires_atleast_one_parameter(self):
|
||||
with raises(Exception):
|
||||
str(plainto_tsquery())
|
||||
|
||||
def test_supports_postgres(self):
|
||||
assert str(plainto_tsquery('s')) == (
|
||||
'plainto_tsquery(:plainto_tsquery_1)'
|
||||
)
|
||||
|
||||
|
||||
@@ -53,9 +68,8 @@ class TestConcatTSVector(TSVectorTestCase):
|
||||
self.Article.search_vector,
|
||||
self.Article.search_vector2
|
||||
),
|
||||
'something',
|
||||
'finnish'
|
||||
to_tsquery('finnish', 'something'),
|
||||
)) == (
|
||||
'(article.search_vector || article.search_vector2) '
|
||||
'@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)'
|
||||
'@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
|
||||
)
|
||||
|
Reference in New Issue
Block a user