Added function expressions for to_tsquery and plainto_tsquery

This commit is contained in:
Konsta Vesterinen
2013-08-01 14:50:08 +03:00
parent a14aad51a9
commit a642269b3b
2 changed files with 62 additions and 23 deletions

View File

@@ -16,16 +16,41 @@ def compile_tsvector_match(element, compiler, **kw):
raise Exception( raise Exception(
"Function 'tsvector_match' expects atleast two arguments." "Function 'tsvector_match' expects atleast two arguments."
) )
if len(args) == 2: return '(%s) @@ %s' % (
return '(%s) @@ to_tsquery(%s)' % (
compiler.process(args[0]), compiler.process(args[0]),
compiler.process(args[1]) compiler.process(args[1])
) )
elif len(args) == 3:
return '(%s) @@ to_tsquery(%s, %s)' % (
compiler.process(args[0]), class to_tsquery(expression.FunctionElement):
compiler.process(args[2]), type = sa.types.Unicode()
compiler.process(args[1]) name = 'to_tsquery'
@compiles(to_tsquery)
def compile_to_tsquery(element, compiler, **kw):
if len(element.clauses) < 1:
raise Exception(
"Function 'to_tsquery' expects atleast one argument."
)
return 'to_tsquery(%s)' % (
', '.join(map(compiler.process, element.clauses))
)
class plainto_tsquery(expression.FunctionElement):
type = sa.types.Unicode()
name = 'plainto_tsquery'
@compiles(plainto_tsquery)
def compile_plainto_tsquery(element, compiler, **kw):
if len(element.clauses) < 1:
raise Exception(
"Function 'plainto_tsquery' expects atleast one argument."
)
return 'plainto_tsquery(%s)' % (
', '.join(map(compiler.process, element.clauses))
) )

View File

@@ -1,7 +1,12 @@
from pytest import raises from pytest import raises
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.types import TSVectorType from sqlalchemy_utils.types import TSVectorType
from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat from sqlalchemy_utils.expressions import (
tsvector_match,
tsvector_concat,
to_tsquery,
plainto_tsquery
)
from tests import TestCase from tests import TestCase
@@ -32,17 +37,27 @@ class TestMatchTSVector(TSVectorTestCase):
def test_supports_postgres(self): def test_supports_postgres(self):
assert str(tsvector_match( assert str(tsvector_match(
self.Article.search_vector, self.Article.search_vector,
'something', to_tsquery('something'),
)) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)' )) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)'
def test_supports_collation_as_3rd_parameter(self):
assert str(tsvector_match( class TestToTSQuery(TSVectorTestCase):
self.Article.search_vector, def test_requires_atleast_one_parameter(self):
'something', with raises(Exception):
'finnish' str(to_tsquery())
)) == (
'(article.search_vector) @@ ' def test_supports_postgres(self):
'to_tsquery(:tsvector_match_1, :tsvector_match_2)' assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)'
class TestPlainToTSQuery(TSVectorTestCase):
def test_requires_atleast_one_parameter(self):
with raises(Exception):
str(plainto_tsquery())
def test_supports_postgres(self):
assert str(plainto_tsquery('s')) == (
'plainto_tsquery(:plainto_tsquery_1)'
) )
@@ -53,9 +68,8 @@ class TestConcatTSVector(TSVectorTestCase):
self.Article.search_vector, self.Article.search_vector,
self.Article.search_vector2 self.Article.search_vector2
), ),
'something', to_tsquery('finnish', 'something'),
'finnish'
)) == ( )) == (
'(article.search_vector || article.search_vector2) ' '(article.search_vector || article.search_vector2) '
'@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)' '@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
) )