Rewrite TSVectorType to cope with SA's own TSVECTOR

This commit is contained in:
Konsta Vesterinen
2015-01-02 18:35:45 +02:00
parent 0e40618ff9
commit 20ea950a4f
6 changed files with 139 additions and 147 deletions

View File

@@ -4,6 +4,15 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.29.0 (2015-01-02)
^^^^^^^^^^^^^^^^^^^
- Removed TSVectorType.match_tsquery (now replaced by TSVectorType.match to be compatible with SQLAlchemy)
- Removed undocumented function tsvector_concat
- Added support for TSVectorType concatenation through OR operator
- Added documentation for TSVectorType (#102)
0.28.3 (2014-12-17)
^^^^^^^^^^^^^^^^^^^

View File

@@ -103,6 +103,14 @@ TimezoneType
.. autoclass:: TimezoneType
TSVectorType
^^^^^^^^^^^^
.. module:: sqlalchemy_utils.types.ts_vector
.. autoclass:: TSVectorType
URLType
^^^^^^^

View File

@@ -6,7 +6,6 @@ from sqlalchemy.sql.expression import (
_literal_as_text
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy_utils.types import TSVectorType
class explain(Executable, ClauseElement):
@@ -65,66 +64,6 @@ def pg_explain(element, compiler, **kw):
return text
class tsvector_match(expression.FunctionElement):
type = sa.types.Unicode()
name = 'tsvector_match'
@compiles(tsvector_match)
def compile_tsvector_match(element, compiler, **kw):
args = list(element.clauses)
if len(args) < 2:
raise Exception(
"Function 'tsvector_match' expects atleast two arguments."
)
return '(%s) @@ %s' % (
compiler.process(args[0]),
compiler.process(args[1])
)
class to_tsquery(expression.FunctionElement):
type = sa.types.Unicode()
name = 'to_tsquery'
@compiles(to_tsquery)
def compile_to_tsquery(element, compiler, **kw):
if len(element.clauses) < 1:
raise Exception(
"Function 'to_tsquery' expects atleast one argument."
)
return 'to_tsquery(%s)' % (
', '.join(map(compiler.process, element.clauses))
)
class plainto_tsquery(expression.FunctionElement):
type = sa.types.Unicode()
name = 'plainto_tsquery'
@compiles(plainto_tsquery)
def compile_plainto_tsquery(element, compiler, **kw):
if len(element.clauses) < 1:
raise Exception(
"Function 'plainto_tsquery' expects atleast one argument."
)
return 'plainto_tsquery(%s)' % (
', '.join(map(compiler.process, element.clauses))
)
class tsvector_concat(expression.FunctionElement):
type = TSVectorType()
name = 'tsvector_concat'
@compiles(tsvector_concat)
def compile_tsvector_concat(element, compiler, **kw):
return ' || '.join(map(compiler.process, element.clauses))
class array_get(expression.FunctionElement):
name = 'array_get'

View File

@@ -1,24 +1,100 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql.base import ischema_names
from sqlalchemy.dialects.postgresql import TSVECTOR
class TSVectorType(sa.types.UserDefinedType):
class comparator_factory(sa.types.TypeEngine.Comparator):
def match_tsquery(self, other, catalog=None):
from sqlalchemy_utils.expressions import tsvector_match, to_tsquery
class TSVectorType(sa.types.TypeDecorator):
"""
.. note::
args = []
if catalog:
args.append(catalog)
elif self.type.options.get('catalog'):
args.append(self.type.options.get('catalog'))
args.append(other)
This type is PostgreSQL specific and is not supported by other
dialects.
return tsvector_match(
self.expr,
to_tsquery(*args)
Provides additional functionality for SQLAlchemy PostgreSQL dialect's
TSVECTOR_ type. This additional functionality includes:
* Vector concatenation
* A regconfig constructor parameter, which is applied to the match function if
no postgresql_regconfig parameter is given
* Provides extensible base for extensions such as SQLAlchemy-Searchable_
.. _TSVECTOR:
http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search
.. _SQLAlchemy-Searchable:
https://www.github.com/kvesteri/sqlalchemy-searchable
::
from sqlalchemy_utils import TSVectorType
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100))
search_vector = sa.Column(TSVectorType)
# Find all articles whose name matches 'finland'
session.query(Article).filter(Article.search_vector.match('finland'))
TSVectorType also supports vector concatenation.
::
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100))
name_vector = sa.Column(TSVectorType)
content = sa.Column(sa.String)
content_vector = sa.Column(TSVectorType)
# Find all articles whose name or content matches 'finland'
session.query(Article).filter(
(Article.name_vector | Article.content_vector).match('finland')
)
You can configure TSVectorType to use a specific regconfig.
::
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100))
search_vector = sa.Column(
TSVectorType(regconfig='pg_catalog.simple')
)
Now an expression such as::
Article.search_vector.match('finland')
Would be equivalent to SQL::
search_vector @@ to_tsquery('pg_catalog.simple', 'finland')
"""
impl = TSVECTOR
class comparator_factory(TSVECTOR.Comparator):
def match(self, other, **kwargs):
if 'postgresql_regconfig' not in kwargs:
if 'regconfig' in self.type.options:
kwargs['postgresql_regconfig'] = (
self.type.options['regconfig']
)
return TSVECTOR.Comparator.match(self, other, **kwargs)
def __or__(self, other):
return self.op('||')(other)
def __init__(self, *args, **kwargs):
"""
Initializes new TSVectorType
@@ -28,12 +104,4 @@ class TSVectorType(sa.types.UserDefinedType):
"""
self.columns = args
self.options = kwargs
"""
Text search vector type for postgresql.
"""
def get_col_spec(self):
return 'tsvector'
ischema_names['tsvector'] = TSVectorType
super(TSVectorType, self).__init__()

View File

@@ -93,53 +93,3 @@ class TestExplainAnalyze(ExpressionTestCase):
dialect=postgresql.dialect()
)
).startswith('EXPLAIN (ANALYZE true) SELECT')
class TestMatchTSVector(ExpressionTestCase):
def test_raises_exception_if_less_than_2_parameters_given(self):
with raises(Exception):
str(
tsvector_match(
self.Article.search_vector,
)
)
def test_supports_postgres(self):
assert str(tsvector_match(
self.Article.search_vector,
to_tsquery('something'),
)) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)'
class TestToTSQuery(ExpressionTestCase):
def test_requires_atleast_one_parameter(self):
with raises(Exception):
str(to_tsquery())
def test_supports_postgres(self):
assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)'
class TestPlainToTSQuery(ExpressionTestCase):
def test_requires_atleast_one_parameter(self):
with raises(Exception):
str(plainto_tsquery())
def test_supports_postgres(self):
assert str(plainto_tsquery('s')) == (
'plainto_tsquery(:plainto_tsquery_1)'
)
class TestConcatTSVector(ExpressionTestCase):
def test_concatenate_search_vectors(self):
assert str(tsvector_match(
tsvector_concat(
self.Article.search_vector,
self.Article.search_vector2
),
to_tsquery('finnish', 'something'),
)) == (
'(article.search_vector || article.search_vector2) '
'@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
)

View File

@@ -1,5 +1,5 @@
import six
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy_utils import TSVectorType
from tests import TestCase
@@ -12,7 +12,9 @@ class TestTSVector(TestCase):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
search_index = sa.Column(TSVectorType())
search_index = sa.Column(
TSVectorType(name, regconfig='pg_catalog.finnish')
)
def __repr__(self):
return 'User(%r)' % self.id
@@ -22,7 +24,7 @@ class TestTSVector(TestCase):
def test_generates_table(self):
assert 'search_index' in self.User.__table__.c
def test_type_autoloading(self):
def test_type_reflection(self):
reflected_metadata = sa.schema.MetaData()
table = sa.schema.Table(
'user',
@@ -30,23 +32,39 @@ class TestTSVector(TestCase):
autoload=True,
autoload_with=self.engine
)
assert isinstance(table.c['search_index'].type, TSVectorType)
assert isinstance(table.c['search_index'].type, TSVECTOR)
def test_catalog_and_columns_as_args(self):
type_ = TSVectorType('name', 'age', catalog='pg_catalog.simple')
type_ = TSVectorType('name', 'age', regconfig='pg_catalog.simple')
assert type_.columns == ('name', 'age')
assert type_.options['catalog'] == 'pg_catalog.simple'
assert type_.options['regconfig'] == 'pg_catalog.simple'
def test_match(self):
expr = self.User.search_index.match_tsquery(u'something')
assert six.text_type(expr) == (
'("user".search_index) @@ to_tsquery(:to_tsquery_1)'
expr = self.User.search_index.match(u'something')
assert str(expr.compile(self.connection)) == (
'''"user".search_index @@ to_tsquery('pg_catalog.finnish', '''
'''%(search_index_1)s)'''
)
def test_concat(self):
assert str(self.User.search_index | self.User.search_index) == (
'"user".search_index || "user".search_index'
)
def test_match_concatenation(self):
concat = self.User.search_index | self.User.search_index
bind = self.session.bind
assert str(concat.match('something').compile(bind)) == (
'("user".search_index || "user".search_index) @@ '
"to_tsquery('pg_catalog.finnish', %(param_1)s)"
)
def test_match_with_catalog(self):
expr = self.User.search_index.match_tsquery(
u'something', catalog='pg_catalog.simple'
expr = self.User.search_index.match(
u'something',
postgresql_regconfig='pg_catalog.simple'
)
assert six.text_type(expr) == (
'("user".search_index) @@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
assert str(expr.compile(self.connection)) == (
'''"user".search_index @@ to_tsquery('pg_catalog.simple', '''
'''%(search_index_1)s)'''
)