Added new expressions tsvector_match and tsvector_concat

This commit is contained in:
Konsta Vesterinen
2013-08-01 10:25:37 +03:00
parent 22f7be96f5
commit a60c2424c7
5 changed files with 116 additions and 14 deletions

View File

@@ -0,0 +1,39 @@
import sqlalchemy as sa
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy_utils.types import TSVectorType
class tsvector_match(expression.FunctionElement):
type = sa.types.Unicode()
name = 'tsvector_match'
@compiles(tsvector_match)
def compile_tsvector_match(element, compiler, **kw):
args = list(element.clauses)
if len(args) < 2:
raise Exception(
"Function 'match_tsvector' expects atleast two arguments."
)
if len(args) == 2:
return '(%s) @@ to_tsquery(%s)' % (
compiler.process(args[0]),
compiler.process(args[1])
)
elif len(args) == 3:
return '(%s) @@ to_tsquery(%s, %s)' % (
compiler.process(args[0]),
compiler.process(args[2]),
compiler.process(args[1])
)
class tsvector_concat(expression.FunctionElement):
type = TSVectorType()
name = 'tsvector_concat'
@compiles(tsvector_concat)
def compile_tsvector_concat(element, compiler, **kw):
return ' || '.join(map(compiler.process, element.clauses))

View File

@@ -1,7 +1,5 @@
from functools import wraps
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from sqlalchemy import types
from sqlalchemy.dialects.postgresql.base import ischema_names
from .arrow import ArrowType
from .color import ColorType
from .email import EmailType
@@ -16,6 +14,7 @@ from .password import Password, PasswordType
from .phone_number import PhoneNumber, PhoneNumberType
from .scalar_list import ScalarListException, ScalarListType
from .timezone import TimezoneType
from .ts_vector import TSVectorType
from .uuid import UUIDType
@@ -35,21 +34,11 @@ __all__ = (
ScalarListException,
ScalarListType,
TimezoneType,
TSVectorType,
UUIDType,
)
class TSVectorType(types.UserDefinedType):
"""
Text search vector type for postgresql.
"""
def get_col_spec(self):
return 'tsvector'
ischema_names['tsvector'] = TSVectorType
class InstrumentedList(_InstrumentedList):
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
additional functionality."""

View File

@@ -0,0 +1,13 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql.base import ischema_names
class TSVectorType(sa.types.UserDefinedType):
"""
Text search vector type for postgresql.
"""
def get_col_spec(self):
return 'tsvector'
ischema_names['tsvector'] = TSVectorType

61
tests/test_expressions.py Normal file
View File

@@ -0,0 +1,61 @@
from pytest import raises
import sqlalchemy as sa
from sqlalchemy_utils.types import TSVectorType
from sqlalchemy_utils.expressions import tsvector_match, tsvector_concat
from tests import TestCase
class TSVectorTestCase(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
search_vector = sa.Column(TSVectorType)
search_vector2 = sa.Column(TSVectorType)
self.Article = Article
class TestMatchTSVector(TSVectorTestCase):
def test_raises_exception_if_less_than_2_parameters_given(self):
with raises(Exception):
str(
tsvector_match(
self.Article.search_vector,
)
)
def test_supports_postgres(self):
assert str(tsvector_match(
self.Article.search_vector,
'something',
)) == '(article.search_vector) @@ to_tsquery(:tsvector_match_1)'
def test_supports_collation_as_3rd_parameter(self):
assert str(tsvector_match(
self.Article.search_vector,
'something',
'finnish'
)) == (
'(article.search_vector) @@ '
'to_tsquery(:tsvector_match_1, :tsvector_match_2)'
)
class TestConcatTSVector(TSVectorTestCase):
def test_concatenate_search_vectors(self):
assert str(tsvector_match(
tsvector_concat(
self.Article.search_vector,
self.Article.search_vector2
),
'something',
'finnish'
)) == (
'(article.search_vector || article.search_vector2) '
'@@ to_tsquery(:tsvector_match_1, :tsvector_match_2)'
)

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils import escape_like, defer_except
from tests import TestCase
from sqlalchemy_utils.functions import (
non_indexed_foreign_keys,
render_statement
render_statement,
)