Added new expressions tsvector_match and tsvector_concat
This commit is contained in:
39
sqlalchemy_utils/expressions.py
Normal file
39
sqlalchemy_utils/expressions.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.sql import expression
|
||||||
|
from sqlalchemy.ext.compiler import compiles
|
||||||
|
from sqlalchemy_utils.types import TSVectorType
|
||||||
|
|
||||||
|
|
||||||
|
class tsvector_match(expression.FunctionElement):
|
||||||
|
type = sa.types.Unicode()
|
||||||
|
name = 'tsvector_match'
|
||||||
|
|
||||||
|
|
||||||
|
@compiles(tsvector_match)
|
||||||
|
def compile_tsvector_match(element, compiler, **kw):
|
||||||
|
args = list(element.clauses)
|
||||||
|
if len(args) < 2:
|
||||||
|
raise Exception(
|
||||||
|
"Function 'match_tsvector' expects atleast two arguments."
|
||||||
|
)
|
||||||
|
if len(args) == 2:
|
||||||
|
return '(%s) @@ to_tsquery(%s)' % (
|
||||||
|
compiler.process(args[0]),
|
||||||
|
compiler.process(args[1])
|
||||||
|
)
|
||||||
|
elif len(args) == 3:
|
||||||
|
return '(%s) @@ to_tsquery(%s, %s)' % (
|
||||||
|
compiler.process(args[0]),
|
||||||
|
compiler.process(args[2]),
|
||||||
|
compiler.process(args[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class tsvector_concat(expression.FunctionElement):
|
||||||
|
type = TSVectorType()
|
||||||
|
name = 'tsvector_concat'
|
||||||
|
|
||||||
|
|
||||||
|
@compiles(tsvector_concat)
|
||||||
|
def compile_tsvector_concat(element, compiler, **kw):
|
||||||
|
return ' || '.join(map(compiler.process, element.clauses))
|
@@ -1,7 +1,5 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||||
from sqlalchemy import types
|
|
||||||
from sqlalchemy.dialects.postgresql.base import ischema_names
|
|
||||||
from .arrow import ArrowType
|
from .arrow import ArrowType
|
||||||
from .color import ColorType
|
from .color import ColorType
|
||||||
from .email import EmailType
|
from .email import EmailType
|
||||||
@@ -16,6 +14,7 @@ from .password import Password, PasswordType
|
|||||||
from .phone_number import PhoneNumber, PhoneNumberType
|
from .phone_number import PhoneNumber, PhoneNumberType
|
||||||
from .scalar_list import ScalarListException, ScalarListType
|
from .scalar_list import ScalarListException, ScalarListType
|
||||||
from .timezone import TimezoneType
|
from .timezone import TimezoneType
|
||||||
|
from .ts_vector import TSVectorType
|
||||||
from .uuid import UUIDType
|
from .uuid import UUIDType
|
||||||
|
|
||||||
|
|
||||||
@@ -35,21 +34,11 @@ __all__ = (
|
|||||||
ScalarListException,
|
ScalarListException,
|
||||||
ScalarListType,
|
ScalarListType,
|
||||||
TimezoneType,
|
TimezoneType,
|
||||||
|
TSVectorType,
|
||||||
UUIDType,
|
UUIDType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TSVectorType(types.UserDefinedType):
|
|
||||||
"""
|
|
||||||
Text search vector type for postgresql.
|
|
||||||
"""
|
|
||||||
def get_col_spec(self):
|
|
||||||
return 'tsvector'
|
|
||||||
|
|
||||||
|
|
||||||
ischema_names['tsvector'] = TSVectorType
|
|
||||||
|
|
||||||
|
|
||||||
class InstrumentedList(_InstrumentedList):
|
class InstrumentedList(_InstrumentedList):
|
||||||
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
||||||
additional functionality."""
|
additional functionality."""
|
||||||
|
13
sqlalchemy_utils/types/ts_vector.py
Normal file
13
sqlalchemy_utils/types/ts_vector.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql.base import ischema_names
|
||||||
|
|
||||||
|
|
||||||
|
class TSVectorType(sa.types.UserDefinedType):
|
||||||
|
"""
|
||||||
|
Text search vector type for postgresql.
|
||||||
|
"""
|
||||||
|
def get_col_spec(self):
|
||||||
|
return 'tsvector'
|
||||||
|
|
||||||
|
|
||||||
|
ischema_names['tsvector'] = TSVectorType
|
61
tests/test_expressions.py
Normal file
61
tests/test_expressions.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from pytest import raises
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.types import TSVectorType
|
||||||
|
from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TSVectorTestCase(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
content = sa.Column(sa.UnicodeText)
|
||||||
|
search_vector = sa.Column(TSVectorType)
|
||||||
|
search_vector2 = sa.Column(TSVectorType)
|
||||||
|
|
||||||
|
self.Article = Article
|
||||||
|
|
||||||
|
|
||||||
|
class TestMatchTSVector(TSVectorTestCase):
|
||||||
|
def test_raises_exception_if_less_than_2_parameters_given(self):
|
||||||
|
with raises(Exception):
|
||||||
|
str(
|
||||||
|
tsvector_match(
|
||||||
|
self.Article.search_vector,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_supports_postgres(self):
|
||||||
|
assert str(tsvector_match(
|
||||||
|
self.Article.search_vector,
|
||||||
|
'something',
|
||||||
|
)) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)'
|
||||||
|
|
||||||
|
def test_supports_collation_as_3rd_parameter(self):
|
||||||
|
assert str(tsvector_match(
|
||||||
|
self.Article.search_vector,
|
||||||
|
'something',
|
||||||
|
'finnish'
|
||||||
|
)) == (
|
||||||
|
'(article.search_vector) @@ '
|
||||||
|
'to_tsquery(:tsvector_match_1, :tsvector_match_2)'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcatTSVector(TSVectorTestCase):
|
||||||
|
def test_concatenate_search_vectors(self):
|
||||||
|
assert str(tsvector_match(
|
||||||
|
tsvector_concat(
|
||||||
|
self.Article.search_vector,
|
||||||
|
self.Article.search_vector2
|
||||||
|
),
|
||||||
|
'something',
|
||||||
|
'finnish'
|
||||||
|
)) == (
|
||||||
|
'(article.search_vector || article.search_vector2) '
|
||||||
|
'@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)'
|
||||||
|
)
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils import escape_like, defer_except
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
from sqlalchemy_utils.functions import (
|
from sqlalchemy_utils.functions import (
|
||||||
non_indexed_foreign_keys,
|
non_indexed_foreign_keys,
|
||||||
render_statement
|
render_statement,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user