Rewrite TSVectorType to cope with SA's own TSVECTOR
This commit is contained in:
@@ -4,6 +4,15 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.29.0 (2015-01-02)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Removed TSVectorType.match_tsquery (now replaced by TSVectorType.match to be compatible with SQLAlchemy)
|
||||
- Removed undocumented function tsvector_concat
|
||||
- Added support for TSVectorType concatenation through OR operator
|
||||
- Added documentation for TSVectorType (#102)
|
||||
|
||||
|
||||
0.28.3 (2014-12-17)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -103,6 +103,14 @@ TimezoneType
|
||||
.. autoclass:: TimezoneType
|
||||
|
||||
|
||||
TSVectorType
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. module:: sqlalchemy_utils.types.ts_vector
|
||||
|
||||
.. autoclass:: TSVectorType
|
||||
|
||||
|
||||
URLType
|
||||
^^^^^^^
|
||||
|
||||
|
@@ -6,7 +6,6 @@ from sqlalchemy.sql.expression import (
|
||||
_literal_as_text
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy_utils.types import TSVectorType
|
||||
|
||||
|
||||
class explain(Executable, ClauseElement):
|
||||
@@ -65,66 +64,6 @@ def pg_explain(element, compiler, **kw):
|
||||
return text
|
||||
|
||||
|
||||
class tsvector_match(expression.FunctionElement):
|
||||
type = sa.types.Unicode()
|
||||
name = 'tsvector_match'
|
||||
|
||||
|
||||
@compiles(tsvector_match)
|
||||
def compile_tsvector_match(element, compiler, **kw):
|
||||
args = list(element.clauses)
|
||||
if len(args) < 2:
|
||||
raise Exception(
|
||||
"Function 'tsvector_match' expects atleast two arguments."
|
||||
)
|
||||
return '(%s) @@ %s' % (
|
||||
compiler.process(args[0]),
|
||||
compiler.process(args[1])
|
||||
)
|
||||
|
||||
|
||||
class to_tsquery(expression.FunctionElement):
|
||||
type = sa.types.Unicode()
|
||||
name = 'to_tsquery'
|
||||
|
||||
|
||||
@compiles(to_tsquery)
|
||||
def compile_to_tsquery(element, compiler, **kw):
|
||||
if len(element.clauses) < 1:
|
||||
raise Exception(
|
||||
"Function 'to_tsquery' expects atleast one argument."
|
||||
)
|
||||
return 'to_tsquery(%s)' % (
|
||||
', '.join(map(compiler.process, element.clauses))
|
||||
)
|
||||
|
||||
|
||||
class plainto_tsquery(expression.FunctionElement):
|
||||
type = sa.types.Unicode()
|
||||
name = 'plainto_tsquery'
|
||||
|
||||
|
||||
@compiles(plainto_tsquery)
|
||||
def compile_plainto_tsquery(element, compiler, **kw):
|
||||
if len(element.clauses) < 1:
|
||||
raise Exception(
|
||||
"Function 'plainto_tsquery' expects atleast one argument."
|
||||
)
|
||||
return 'plainto_tsquery(%s)' % (
|
||||
', '.join(map(compiler.process, element.clauses))
|
||||
)
|
||||
|
||||
|
||||
class tsvector_concat(expression.FunctionElement):
|
||||
type = TSVectorType()
|
||||
name = 'tsvector_concat'
|
||||
|
||||
|
||||
@compiles(tsvector_concat)
|
||||
def compile_tsvector_concat(element, compiler, **kw):
|
||||
return ' || '.join(map(compiler.process, element.clauses))
|
||||
|
||||
|
||||
class array_get(expression.FunctionElement):
|
||||
name = 'array_get'
|
||||
|
||||
|
@@ -1,24 +1,100 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql.base import ischema_names
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
|
||||
|
||||
class TSVectorType(sa.types.UserDefinedType):
|
||||
class comparator_factory(sa.types.TypeEngine.Comparator):
|
||||
def match_tsquery(self, other, catalog=None):
|
||||
from sqlalchemy_utils.expressions import tsvector_match, to_tsquery
|
||||
class TSVectorType(sa.types.TypeDecorator):
|
||||
"""
|
||||
.. note::
|
||||
|
||||
args = []
|
||||
if catalog:
|
||||
args.append(catalog)
|
||||
elif self.type.options.get('catalog'):
|
||||
args.append(self.type.options.get('catalog'))
|
||||
args.append(other)
|
||||
This type is PostgreSQL specific and is not supported by other
|
||||
dialects.
|
||||
|
||||
return tsvector_match(
|
||||
self.expr,
|
||||
to_tsquery(*args)
|
||||
Provides additional functionality for SQLAlchemy PostgreSQL dialect's
|
||||
TSVECTOR_ type. This additional functionality includes:
|
||||
|
||||
* Vector concatenation
|
||||
* regconfig constructor parameter which is applied to match function if no
|
||||
postgresql_regconfig parameter is given
|
||||
* Provides extensible base for extensions such as SQLAlchemy-Searchable_
|
||||
|
||||
.. _TSVECTOR:
|
||||
http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search
|
||||
|
||||
.. _SQLAlchemy-Searchable:
|
||||
https://www.github.com/kvesteri/sqlalchemy-searchable
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import TSVectorType
|
||||
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(100))
|
||||
search_vector = sa.Column(TSVectorType)
|
||||
|
||||
|
||||
# Find all articles whose name matches 'finland'
|
||||
session.query(Article).filter(Article.search_vector.match('finland'))
|
||||
|
||||
|
||||
TSVectorType also supports vector concatenation.
|
||||
|
||||
::
|
||||
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(100))
|
||||
name_vector = sa.Column(TSVectorType)
|
||||
content = sa.Column(sa.String)
|
||||
content_vector = sa.Column(TSVectorType)
|
||||
|
||||
# Find all articles whose name or content matches 'finland'
|
||||
session.query(Article).filter(
|
||||
(Article.name_vector | Article.content_vector).match('finland')
|
||||
)
|
||||
|
||||
You can configure TSVectorType to use a specific regconfig.
|
||||
::
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(100))
|
||||
search_vector = sa.Column(
|
||||
TSVectorType(regconfig='pg_catalog.simple')
|
||||
)
|
||||
|
||||
|
||||
Now an expression such as::
|
||||
|
||||
|
||||
Article.search_vector.match('finland')
|
||||
|
||||
|
||||
Would be equivalent to SQL::
|
||||
|
||||
|
||||
search_vector @@ to_tsquery('pg_catalog.simple', 'finland')
|
||||
|
||||
"""
|
||||
impl = TSVECTOR
|
||||
|
||||
class comparator_factory(TSVECTOR.Comparator):
|
||||
def match(self, other, **kwargs):
|
||||
if 'postgresql_regconfig' not in kwargs:
|
||||
if 'regconfig' in self.type.options:
|
||||
kwargs['postgresql_regconfig'] = (
|
||||
self.type.options['regconfig']
|
||||
)
|
||||
return TSVECTOR.Comparator.match(self, other, **kwargs)
|
||||
|
||||
def __or__(self, other):
|
||||
return self.op('||')(other)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Initializes new TSVectorType
|
||||
@@ -28,12 +104,4 @@ class TSVectorType(sa.types.UserDefinedType):
|
||||
"""
|
||||
self.columns = args
|
||||
self.options = kwargs
|
||||
|
||||
"""
|
||||
Text search vector type for postgresql.
|
||||
"""
|
||||
def get_col_spec(self):
|
||||
return 'tsvector'
|
||||
|
||||
|
||||
ischema_names['tsvector'] = TSVectorType
|
||||
super(TSVectorType, self).__init__()
|
||||
|
@@ -93,53 +93,3 @@ class TestExplainAnalyze(ExpressionTestCase):
|
||||
dialect=postgresql.dialect()
|
||||
)
|
||||
).startswith('EXPLAIN (ANALYZE true) SELECT')
|
||||
|
||||
|
||||
class TestMatchTSVector(ExpressionTestCase):
|
||||
def test_raises_exception_if_less_than_2_parameters_given(self):
|
||||
with raises(Exception):
|
||||
str(
|
||||
tsvector_match(
|
||||
self.Article.search_vector,
|
||||
)
|
||||
)
|
||||
|
||||
def test_supports_postgres(self):
|
||||
assert str(tsvector_match(
|
||||
self.Article.search_vector,
|
||||
to_tsquery('something'),
|
||||
)) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)'
|
||||
|
||||
|
||||
class TestToTSQuery(ExpressionTestCase):
|
||||
def test_requires_atleast_one_parameter(self):
|
||||
with raises(Exception):
|
||||
str(to_tsquery())
|
||||
|
||||
def test_supports_postgres(self):
|
||||
assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)'
|
||||
|
||||
|
||||
class TestPlainToTSQuery(ExpressionTestCase):
|
||||
def test_requires_atleast_one_parameter(self):
|
||||
with raises(Exception):
|
||||
str(plainto_tsquery())
|
||||
|
||||
def test_supports_postgres(self):
|
||||
assert str(plainto_tsquery('s')) == (
|
||||
'plainto_tsquery(:plainto_tsquery_1)'
|
||||
)
|
||||
|
||||
|
||||
class TestConcatTSVector(ExpressionTestCase):
|
||||
def test_concatenate_search_vectors(self):
|
||||
assert str(tsvector_match(
|
||||
tsvector_concat(
|
||||
self.Article.search_vector,
|
||||
self.Article.search_vector2
|
||||
),
|
||||
to_tsquery('finnish', 'something'),
|
||||
)) == (
|
||||
'(article.search_vector || article.search_vector2) '
|
||||
'@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
from sqlalchemy_utils import TSVectorType
|
||||
from tests import TestCase
|
||||
|
||||
@@ -12,7 +12,9 @@ class TestTSVector(TestCase):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
search_index = sa.Column(TSVectorType())
|
||||
search_index = sa.Column(
|
||||
TSVectorType(name, regconfig='pg_catalog.finnish')
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.id
|
||||
@@ -22,7 +24,7 @@ class TestTSVector(TestCase):
|
||||
def test_generates_table(self):
|
||||
assert 'search_index' in self.User.__table__.c
|
||||
|
||||
def test_type_autoloading(self):
|
||||
def test_type_reflection(self):
|
||||
reflected_metadata = sa.schema.MetaData()
|
||||
table = sa.schema.Table(
|
||||
'user',
|
||||
@@ -30,23 +32,39 @@ class TestTSVector(TestCase):
|
||||
autoload=True,
|
||||
autoload_with=self.engine
|
||||
)
|
||||
assert isinstance(table.c['search_index'].type, TSVectorType)
|
||||
assert isinstance(table.c['search_index'].type, TSVECTOR)
|
||||
|
||||
def test_catalog_and_columns_as_args(self):
|
||||
type_ = TSVectorType('name', 'age', catalog='pg_catalog.simple')
|
||||
type_ = TSVectorType('name', 'age', regconfig='pg_catalog.simple')
|
||||
assert type_.columns == ('name', 'age')
|
||||
assert type_.options['catalog'] == 'pg_catalog.simple'
|
||||
assert type_.options['regconfig'] == 'pg_catalog.simple'
|
||||
|
||||
def test_match(self):
|
||||
expr = self.User.search_index.match_tsquery(u'something')
|
||||
assert six.text_type(expr) == (
|
||||
'("user".search_index) @@ to_tsquery(:to_tsquery_1)'
|
||||
expr = self.User.search_index.match(u'something')
|
||||
assert str(expr.compile(self.connection)) == (
|
||||
'''"user".search_index @@ to_tsquery('pg_catalog.finnish', '''
|
||||
'''%(search_index_1)s)'''
|
||||
)
|
||||
|
||||
def test_concat(self):
|
||||
assert str(self.User.search_index | self.User.search_index) == (
|
||||
'"user".search_index || "user".search_index'
|
||||
)
|
||||
|
||||
def test_match_concatenation(self):
|
||||
concat = self.User.search_index | self.User.search_index
|
||||
bind = self.session.bind
|
||||
assert str(concat.match('something').compile(bind)) == (
|
||||
'("user".search_index || "user".search_index) @@ '
|
||||
"to_tsquery('pg_catalog.finnish', %(param_1)s)"
|
||||
)
|
||||
|
||||
def test_match_with_catalog(self):
|
||||
expr = self.User.search_index.match_tsquery(
|
||||
u'something', catalog='pg_catalog.simple'
|
||||
expr = self.User.search_index.match(
|
||||
u'something',
|
||||
postgresql_regconfig='pg_catalog.simple'
|
||||
)
|
||||
assert six.text_type(expr) == (
|
||||
'("user".search_index) @@ to_tsquery(:to_tsquery_1, :to_tsquery_2)'
|
||||
assert str(expr.compile(self.connection)) == (
|
||||
'''"user".search_index @@ to_tsquery('pg_catalog.simple', '''
|
||||
'''%(search_index_1)s)'''
|
||||
)
|
||||
|
Reference in New Issue
Block a user