diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py index f348f22..53a1070 100644 --- a/sqlalchemy_utils/expressions.py +++ b/sqlalchemy_utils/expressions.py @@ -16,17 +16,42 @@ def compile_tsvector_match(element, compiler, **kw): raise Exception( "Function 'tsvector_match' expects atleast two arguments." ) - if len(args) == 2: - return '(%s) @@ to_tsquery(%s)' % ( - compiler.process(args[0]), - compiler.process(args[1]) + return '(%s) @@ %s' % ( + compiler.process(args[0]), + compiler.process(args[1]) + ) + + +class to_tsquery(expression.FunctionElement): + type = sa.types.Unicode() + name = 'to_tsquery' + + +@compiles(to_tsquery) +def compile_to_tsquery(element, compiler, **kw): + if len(element.clauses) < 1: + raise Exception( + "Function 'to_tsquery' expects atleast one argument." ) - elif len(args) == 3: - return '(%s) @@ to_tsquery(%s, %s)' % ( - compiler.process(args[0]), - compiler.process(args[2]), - compiler.process(args[1]) + return 'to_tsquery(%s)' % ( + ', '.join(map(compiler.process, element.clauses)) + ) + + +class plainto_tsquery(expression.FunctionElement): + type = sa.types.Unicode() + name = 'plainto_tsquery' + + +@compiles(plainto_tsquery) +def compile_plainto_tsquery(element, compiler, **kw): + if len(element.clauses) < 1: + raise Exception( + "Function 'plainto_tsquery' expects atleast one argument." ) + return 'plainto_tsquery(%s)' % ( + ', '.join(map(compiler.process, element.clauses)) + ) class tsvector_concat(expression.FunctionElement): diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 382b0c4..7718b53 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,7 +1,12 @@ from pytest import raises import sqlalchemy as sa from sqlalchemy_utils.types import TSVectorType -from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat +from sqlalchemy_utils.expressions import ( + tsvector_match, + tsvector_concat, + to_tsquery, + plainto_tsquery +) from tests import TestCase @@ -32,17 +37,27 @@ class TestMatchTSVector(TSVectorTestCase): def test_supports_postgres(self): assert str(tsvector_match( self.Article.search_vector, - 'something', - )) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)' + to_tsquery('something'), + )) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)' - def test_supports_collation_as_3rd_parameter(self): - assert str(tsvector_match( - self.Article.search_vector, - 'something', - 'finnish' - )) == ( - '(article.search_vector) @@ ' - 'to_tsquery(:tsvector_match_1, :tsvector_match_2)' + +class TestToTSQuery(TSVectorTestCase): + def test_requires_atleast_one_parameter(self): + with raises(Exception): + str(to_tsquery()) + + def test_supports_postgres(self): + assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)' + + +class TestPlainToTSQuery(TSVectorTestCase): + def test_requires_atleast_one_parameter(self): + with raises(Exception): + str(plainto_tsquery()) + + def test_supports_postgres(self): + assert str(plainto_tsquery('s')) == ( + 'plainto_tsquery(:plainto_tsquery_1)' ) @@ -53,9 +68,8 @@ class TestConcatTSVector(TSVectorTestCase): self.Article.search_vector, self.Article.search_vector2 ), - 'something', - 'finnish' + to_tsquery('finnish', 'something'), )) == ( '(article.search_vector || article.search_vector2) ' - '@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)' + '@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)' )