From 20ea950a4fa2150387725987d377bb0f59d92ce6 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 2 Jan 2015 18:35:45 +0200 Subject: [PATCH] Rewrite TSVectorType to cope with SA's own TSVECTOR --- CHANGES.rst | 9 +++ docs/data_types.rst | 8 ++ sqlalchemy_utils/expressions.py | 61 --------------- sqlalchemy_utils/types/ts_vector.py | 114 ++++++++++++++++++++++------ tests/test_expressions.py | 50 ------------ tests/types/test_tsvector.py | 44 +++++++---- 6 files changed, 139 insertions(+), 147 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index bf4853e..5dda050 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,15 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.29.0 (2015-01-02) +^^^^^^^^^^^^^^^^^^^ + +- Removed TSVectorType.match_tsquery (now replaced by TSVectorType.match to be compatible with SQLAlchemy) +- Removed undocumented function tsvector_concat +- Added support for TSVectorType concatenation through OR operator +- Added documentation for TSVectorType (#102) + + 0.28.3 (2014-12-17) ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/data_types.rst b/docs/data_types.rst index 1281a65..9501570 100644 --- a/docs/data_types.rst +++ b/docs/data_types.rst @@ -103,6 +103,14 @@ TimezoneType .. autoclass:: TimezoneType +TSVectorType +^^^^^^^^^^^^ + +.. module:: sqlalchemy_utils.types.ts_vector + +.. autoclass:: TSVectorType + + URLType ^^^^^^^ diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py index e69226f..9eadac4 100644 --- a/sqlalchemy_utils/expressions.py +++ b/sqlalchemy_utils/expressions.py @@ -6,7 +6,6 @@ from sqlalchemy.sql.expression import ( _literal_as_text ) from sqlalchemy.ext.compiler import compiles -from sqlalchemy_utils.types import TSVectorType class explain(Executable, ClauseElement): @@ -65,66 +64,6 @@ def pg_explain(element, compiler, **kw): return text -class tsvector_match(expression.FunctionElement): - type = sa.types.Unicode() - name = 'tsvector_match' - - -@compiles(tsvector_match) -def compile_tsvector_match(element, compiler, **kw): - args = list(element.clauses) - if len(args) < 2: - raise Exception( - "Function 'tsvector_match' expects atleast two arguments." - ) - return '(%s) @@ %s' % ( - compiler.process(args[0]), - compiler.process(args[1]) - ) - - -class to_tsquery(expression.FunctionElement): - type = sa.types.Unicode() - name = 'to_tsquery' - - -@compiles(to_tsquery) -def compile_to_tsquery(element, compiler, **kw): - if len(element.clauses) < 1: - raise Exception( - "Function 'to_tsquery' expects atleast one argument." - ) - return 'to_tsquery(%s)' % ( - ', '.join(map(compiler.process, element.clauses)) - ) - - -class plainto_tsquery(expression.FunctionElement): - type = sa.types.Unicode() - name = 'plainto_tsquery' - - -@compiles(plainto_tsquery) -def compile_plainto_tsquery(element, compiler, **kw): - if len(element.clauses) < 1: - raise Exception( - "Function 'plainto_tsquery' expects atleast one argument." - ) - return 'plainto_tsquery(%s)' % ( - ', '.join(map(compiler.process, element.clauses)) - ) - - -class tsvector_concat(expression.FunctionElement): - type = TSVectorType() - name = 'tsvector_concat' - - -@compiles(tsvector_concat) -def compile_tsvector_concat(element, compiler, **kw): - return ' || '.join(map(compiler.process, element.clauses)) - - class array_get(expression.FunctionElement): name = 'array_get' diff --git a/sqlalchemy_utils/types/ts_vector.py b/sqlalchemy_utils/types/ts_vector.py index 8734910..6c4a9b2 100644 --- a/sqlalchemy_utils/types/ts_vector.py +++ b/sqlalchemy_utils/types/ts_vector.py @@ -1,24 +1,100 @@ import sqlalchemy as sa -from sqlalchemy.dialects.postgresql.base import ischema_names +from sqlalchemy.dialects.postgresql import TSVECTOR -class TSVectorType(sa.types.UserDefinedType): - class comparator_factory(sa.types.TypeEngine.Comparator): - def match_tsquery(self, other, catalog=None): - from sqlalchemy_utils.expressions import tsvector_match, to_tsquery +class TSVectorType(sa.types.TypeDecorator): + """ + .. note:: - args = [] - if catalog: - args.append(catalog) - elif self.type.options.get('catalog'): - args.append(self.type.options.get('catalog')) - args.append(other) + This type is PostgreSQL specific and is not supported by other + dialects. - return tsvector_match( - self.expr, - to_tsquery(*args) + Provides additional functionality for SQLAlchemy PostgreSQL dialect's + TSVECTOR_ type. This additional functionality includes: + + * Vector concatenation + * regconfig constructor parameter which is applied to match function if no + postgresql_regconfig parameter is given + * Provides extensible base for extensions such as SQLAlchemy-Searchable_ + + .. _TSVECTOR: + http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search + + .. _SQLAlchemy-Searchable: + https://www.github.com/kvesteri/sqlalchemy-searchable + + :: + + from sqlalchemy_utils import TSVectorType + + + class Article(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100)) + search_vector = sa.Column(TSVectorType) + + + # Find all articles whose name matches 'finland' + session.query(Article).filter(Article.search_vector.match('finland')) + + + TSVectorType also supports vector concatenation. + + :: + + + class Article(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100)) + name_vector = sa.Column(TSVectorType) + content = sa.Column(sa.String) + content_vector = sa.Column(TSVectorType) + + # Find all articles whose name or content matches 'finland' + session.query(Article).filter( + (Article.name_vector | Article.content_vector).match('finland') + ) + + You can configure TSVectorType to use a specific regconfig. + :: + + class Article(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100)) + search_vector = sa.Column( + TSVectorType(regconfig='pg_catalog.simple') ) + + Now expression such as:: + + + Article.search_vector.match('finland') + + + Would be equivalent to SQL:: + + + search_vector @@ to_tsquery('pg_catalog.simgle', 'finland') + + """ + impl = TSVECTOR + + class comparator_factory(TSVECTOR.Comparator): + def match(self, other, **kwargs): + if 'postgresql_regconfig' not in kwargs: + if 'regconfig' in self.type.options: + kwargs['postgresql_regconfig'] = ( + self.type.options['regconfig'] + ) + return TSVECTOR.Comparator.match(self, other, **kwargs) + + def __or__(self, other): + return self.op('||')(other) + def __init__(self, *args, **kwargs): """ Initializes new TSVectorType @@ -28,12 +104,4 @@ class TSVectorType(sa.types.UserDefinedType): """ self.columns = args self.options = kwargs - - """ - Text search vector type for postgresql. - """ - def get_col_spec(self): - return 'tsvector' - - -ischema_names['tsvector'] = TSVectorType + super(TSVectorType, self).__init__() diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 27c11fc..5bf4884 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -93,53 +93,3 @@ class TestExplainAnalyze(ExpressionTestCase): dialect=postgresql.dialect() ) ).startswith('EXPLAIN (ANALYZE true) SELECT') - - -class TestMatchTSVector(ExpressionTestCase): - def test_raises_exception_if_less_than_2_parameters_given(self): - with raises(Exception): - str( - tsvector_match( - self.Article.search_vector, - ) - ) - - def test_supports_postgres(self): - assert str(tsvector_match( - self.Article.search_vector, - to_tsquery('something'), - )) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)' - - -class TestToTSQuery(ExpressionTestCase): - def test_requires_atleast_one_parameter(self): - with raises(Exception): - str(to_tsquery()) - - def test_supports_postgres(self): - assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)' - - -class TestPlainToTSQuery(ExpressionTestCase): - def test_requires_atleast_one_parameter(self): - with raises(Exception): - str(plainto_tsquery()) - - def test_supports_postgres(self): - assert str(plainto_tsquery('s')) == ( - 'plainto_tsquery(:plainto_tsquery_1)' - ) - - -class TestConcatTSVector(ExpressionTestCase): - def test_concatenate_search_vectors(self): - assert str(tsvector_match( - tsvector_concat( - self.Article.search_vector, - self.Article.search_vector2 - ), - to_tsquery('finnish', 'something'), - )) == ( - '(article.search_vector || article.search_vector2) ' - '@@ to_tsquery(:to_tsquery_1, :to_tsquery_2)' - ) diff --git a/tests/types/test_tsvector.py b/tests/types/test_tsvector.py index 58b22e6..bb5fda9 100644 --- a/tests/types/test_tsvector.py +++ b/tests/types/test_tsvector.py @@ -1,5 +1,5 @@ -import six import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy_utils import TSVectorType from tests import TestCase @@ -12,7 +12,9 @@ class TestTSVector(TestCase): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - search_index = sa.Column(TSVectorType()) + search_index = sa.Column( + TSVectorType(name, regconfig='pg_catalog.finnish') + ) def __repr__(self): return 'User(%r)' % self.id @@ -22,7 +24,7 @@ class TestTSVector(TestCase): def test_generates_table(self): assert 'search_index' in self.User.__table__.c - def test_type_autoloading(self): + def test_type_reflection(self): reflected_metadata = sa.schema.MetaData() table = sa.schema.Table( 'user', @@ -30,23 +32,39 @@ class TestTSVector(TestCase): autoload=True, autoload_with=self.engine ) - assert isinstance(table.c['search_index'].type, TSVectorType) + assert isinstance(table.c['search_index'].type, TSVECTOR) def test_catalog_and_columns_as_args(self): - type_ = TSVectorType('name', 'age', catalog='pg_catalog.simple') + type_ = TSVectorType('name', 'age', regconfig='pg_catalog.simple') assert type_.columns == ('name', 'age') - assert type_.options['catalog'] == 'pg_catalog.simple' + assert type_.options['regconfig'] == 'pg_catalog.simple' def test_match(self): - expr = self.User.search_index.match_tsquery(u'something') - assert six.text_type(expr) == ( - '("user".search_index) @@ to_tsquery(:to_tsquery_1)' + expr = self.User.search_index.match(u'something') + assert str(expr.compile(self.connection)) == ( + '''"user".search_index @@ to_tsquery('pg_catalog.finnish', ''' + '''%(search_index_1)s)''' + ) + + def test_concat(self): + assert str(self.User.search_index | self.User.search_index) == ( + '"user".search_index || "user".search_index' + ) + + def test_match_concatenation(self): + concat = self.User.search_index | self.User.search_index + bind = self.session.bind + assert str(concat.match('something').compile(bind)) == ( + '("user".search_index || "user".search_index) @@ ' + "to_tsquery('pg_catalog.finnish', %(param_1)s)" ) def test_match_with_catalog(self): - expr = self.User.search_index.match_tsquery( - u'something', catalog='pg_catalog.simple' + expr = self.User.search_index.match( + u'something', + postgresql_regconfig='pg_catalog.simple' ) - assert six.text_type(expr) == ( - '("user".search_index) @@ to_tsquery(:to_tsquery_1, :to_tsquery_2)' + assert str(expr.compile(self.connection)) == ( + '''"user".search_index @@ to_tsquery('pg_catalog.simple', ''' + '''%(search_index_1)s)''' )