Added function expressions for to_tsquery and plainto_tsquery

This commit is contained in:
Konsta Vesterinen
2013-08-01 14:50:08 +03:00
parent a14aad51a9
commit a642269b3b
2 changed files with 62 additions and 23 deletions

View File

@@ -16,17 +16,42 @@ def compile_tsvector_match(element, compiler, **kw):
raise Exception(
"Function 'tsvector_match' expects atleast two arguments."
)
if len(args) == 2:
return '(%s) @@ to_tsquery(%s)' % (
compiler.process(args[0]),
compiler.process(args[1])
return '(%s) @@ %s' % (
compiler.process(args[0]),
compiler.process(args[1])
)
class to_tsquery(expression.FunctionElement):
type = sa.types.Unicode()
name = 'to_tsquery'
@compiles(to_tsquery)
def compile_to_tsquery(element, compiler, **kw):
if len(element.clauses) < 1:
raise Exception(
"Function 'to_tsquery' expects atleast one argument."
)
elif len(args) == 3:
return '(%s) @@ to_tsquery(%s, %s)' % (
compiler.process(args[0]),
compiler.process(args[2]),
compiler.process(args[1])
return 'to_tsquery(%s)' % (
', '.join(map(compiler.process, element.clauses))
)
class plainto_tsquery(expression.FunctionElement):
type = sa.types.Unicode()
name = 'plainto_tsquery'
@compiles(plainto_tsquery)
def compile_plainto_tsquery(element, compiler, **kw):
if len(element.clauses) < 1:
raise Exception(
"Function 'plainto_tsquery' expects atleast one argument."
)
return 'plainto_tsquery(%s)' % (
', '.join(map(compiler.process, element.clauses))
)
class tsvector_concat(expression.FunctionElement):

View File

@@ -1,7 +1,12 @@
from pytest import raises
import sqlalchemy as sa
from sqlalchemy_utils.types import TSVectorType
from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat
from sqlalchemy_utils.expressions import (
tsvector_match,
tsvector_concat,
to_tsquery,
plainto_tsquery
)
from tests import TestCase
@@ -32,17 +37,27 @@ class TestMatchTSVector(TSVectorTestCase):
def test_supports_postgres(self):
assert str(tsvector_match(
self.Article.search_vector,
'something',
)) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)'
to_tsquery('something'),
)) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)'
def test_supports_collation_as_3rd_parameter(self):
assert str(tsvector_match(
self.Article.search_vector,
'something',
'finnish'
)) == (
'(article.search_vector) @@ '
'to_tsquery(:tsvector_match_1, :tsvector_match_2)'
class TestToTSQuery(TSVectorTestCase):
def test_requires_atleast_one_parameter(self):
with raises(Exception):
str(to_tsquery())
def test_supports_postgres(self):
assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)'
class TestPlainToTSQuery(TSVectorTestCase):
def test_requires_atleast_one_parameter(self):
with raises(Exception):
str(plainto_tsquery())
def test_supports_postgres(self):
assert str(plainto_tsquery('s')) == (
'plainto_tsquery(:plainto_tsquery_1)'
)
@@ -53,9 +68,8 @@ class TestConcatTSVector(TSVectorTestCase):
self.Article.search_vector,
self.Article.search_vector2
),
'something',
'finnish'
to_tsquery('finnish', 'something'),
)) == (
'(article.search_vector || article.search_vector2) '
'@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)'
'@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
)