65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
import sqlalchemy as sa
|
|
from sqlalchemy.sql import expression
|
|
from sqlalchemy.ext.compiler import compiles
|
|
from sqlalchemy_utils.types import TSVectorType
|
|
|
|
|
|
class tsvector_match(expression.FunctionElement):
    """SQL construct rendered as ``(<first arg>) @@ <second arg>``.

    The SQL text is produced by the ``@compiles`` hook registered for this
    class; the class itself only declares the construct's name and type.
    """

    type = sa.types.Unicode()
    name = 'tsvector_match'
    # The compiled SQL depends only on this class and its argument clauses,
    # so the construct can safely take part in SQLAlchemy 1.4+ statement
    # caching. Without this flag 1.4+ skips caching for the construct and
    # emits a warning; older SQLAlchemy versions never read the attribute.
    inherit_cache = True
|
|
|
|
|
|
@compiles(tsvector_match)
def compile_tsvector_match(element, compiler, **kw):
    """Render ``tsvector_match`` as ``(<first arg>) @@ <second arg>``.

    :param element: the ``tsvector_match`` construct being compiled.
    :param compiler: the SQL compiler used to render the argument clauses.
    :param kw: compilation flags handed in by the compiler; forwarded
        unchanged to every nested ``compiler.process()`` call.
    :returns: the SQL string of the match expression.
    :raises Exception: if fewer than two arguments were supplied.
    """
    args = list(element.clauses)
    if len(args) < 2:
        raise Exception(
            "Function 'tsvector_match' expects atleast two arguments."
        )
    # Only the first two clauses are rendered; further arguments are ignored.
    # ``**kw`` has to be passed on so that flags such as ``literal_binds``
    # reach the nested clauses instead of being dropped at this level.
    return '(%s) @@ %s' % (
        compiler.process(args[0], **kw),
        compiler.process(args[1], **kw)
    )
|
|
|
|
|
|
class to_tsquery(expression.FunctionElement):
    """SQL construct rendered as ``to_tsquery(<arg>, ...)``.

    The SQL text is produced by the ``@compiles`` hook registered for this
    class; the class itself only declares the construct's name and type.
    """

    type = sa.types.Unicode()
    name = 'to_tsquery'
    # The compiled SQL depends only on this class and its argument clauses,
    # so the construct can safely take part in SQLAlchemy 1.4+ statement
    # caching. Without this flag 1.4+ skips caching for the construct and
    # emits a warning; older SQLAlchemy versions never read the attribute.
    inherit_cache = True
|
|
|
|
|
|
@compiles(to_tsquery)
def compile_to_tsquery(element, compiler, **kw):
    """Render ``to_tsquery`` as ``to_tsquery(<arg>, <arg>, ...)``.

    :param element: the ``to_tsquery`` construct being compiled.
    :param compiler: the SQL compiler used to render the argument clauses.
    :param kw: compilation flags handed in by the compiler; forwarded
        unchanged to every nested ``compiler.process()`` call.
    :returns: the SQL string of the function call.
    :raises Exception: if no argument was supplied.
    """
    if len(element.clauses) < 1:
        raise Exception(
            "Function 'to_tsquery' expects atleast one argument."
        )
    # ``**kw`` has to be passed on so that flags such as ``literal_binds``
    # reach the argument clauses instead of being dropped at this level.
    return 'to_tsquery(%s)' % (
        ', '.join(
            compiler.process(clause, **kw) for clause in element.clauses
        )
    )
|
|
|
|
|
|
class plainto_tsquery(expression.FunctionElement):
    """SQL construct rendered as ``plainto_tsquery(<arg>, ...)``.

    The SQL text is produced by the ``@compiles`` hook registered for this
    class; the class itself only declares the construct's name and type.
    """

    type = sa.types.Unicode()
    name = 'plainto_tsquery'
    # The compiled SQL depends only on this class and its argument clauses,
    # so the construct can safely take part in SQLAlchemy 1.4+ statement
    # caching. Without this flag 1.4+ skips caching for the construct and
    # emits a warning; older SQLAlchemy versions never read the attribute.
    inherit_cache = True
|
|
|
|
|
|
@compiles(plainto_tsquery)
def compile_plainto_tsquery(element, compiler, **kw):
    """Render ``plainto_tsquery`` as ``plainto_tsquery(<arg>, <arg>, ...)``.

    :param element: the ``plainto_tsquery`` construct being compiled.
    :param compiler: the SQL compiler used to render the argument clauses.
    :param kw: compilation flags handed in by the compiler; forwarded
        unchanged to every nested ``compiler.process()`` call.
    :returns: the SQL string of the function call.
    :raises Exception: if no argument was supplied.
    """
    if len(element.clauses) < 1:
        raise Exception(
            "Function 'plainto_tsquery' expects atleast one argument."
        )
    # ``**kw`` has to be passed on so that flags such as ``literal_binds``
    # reach the argument clauses instead of being dropped at this level.
    return 'plainto_tsquery(%s)' % (
        ', '.join(
            compiler.process(clause, **kw) for clause in element.clauses
        )
    )
|
|
|
|
|
|
class tsvector_concat(expression.FunctionElement):
    """SQL construct rendered as its arguments joined with ``||``.

    The SQL text is produced by the ``@compiles`` hook registered for this
    class; the class itself only declares the construct's name and type.
    """

    type = TSVectorType()
    name = 'tsvector_concat'
    # The compiled SQL depends only on this class and its argument clauses,
    # so the construct can safely take part in SQLAlchemy 1.4+ statement
    # caching. Without this flag 1.4+ skips caching for the construct and
    # emits a warning; older SQLAlchemy versions never read the attribute.
    inherit_cache = True
|
|
|
|
|
|
@compiles(tsvector_concat)
def compile_tsvector_concat(element, compiler, **kw):
    """Render ``tsvector_concat`` as ``<arg> || <arg> || ...``.

    :param element: the ``tsvector_concat`` construct being compiled.
    :param compiler: the SQL compiler used to render the argument clauses.
    :param kw: compilation flags handed in by the compiler; forwarded
        unchanged to every nested ``compiler.process()`` call.
    :returns: the compiled arguments joined with ``' || '``; an empty string
        when the construct has no arguments.
    """
    # ``**kw`` has to be passed on so that flags such as ``literal_binds``
    # reach the argument clauses instead of being dropped at this level.
    return ' || '.join(
        compiler.process(clause, **kw) for clause in element.clauses
    )
|