diff --git a/CHANGES.rst b/CHANGES.rst index c874177..a5ae7da 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -8,7 +8,8 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release. ^^^^^^^^^^^^^^^^^^^ - Added Asterisk compiler -- Added row_to_json FunctionElement +- Added row_to_json GenericFunction +- Added array_agg GenericFunction - Made quote function accept dialect object as the first paremeter - Made has_index work with tables without primary keys (#148) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index bdbfb46..8b85372 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -8,7 +8,7 @@ from .asserts import ( # noqa ) from .exceptions import ImproperlyConfigured # noqa from .expression_parser import ExpressionParser # noqa -from .expressions import Asterisk, row_to_json # noqa +from .expressions import array_agg, Asterisk, row_to_json # noqa from .functions import ( # noqa analyze, create_database, diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py index c0f5fcf..12a6bcd 100644 --- a/sqlalchemy_utils/expressions.py +++ b/sqlalchemy_utils/expressions.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy.dialects.postgresql import ARRAY, JSON from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.expression import ( @@ -9,6 +9,7 @@ from sqlalchemy.sql.expression import ( Executable, FunctionElement ) +from sqlalchemy.sql.functions import GenericFunction from sqlalchemy_utils.functions.orm import quote @@ -92,7 +93,7 @@ def compile_array_get(element, compiler, **kw): ) -class row_to_json(FunctionElement): +class row_to_json(GenericFunction): name = 'row_to_json' type = JSON @@ -102,6 +103,30 @@ def compile_row_to_json(element, compiler, **kw): return "%s(%s)" % (element.name, compiler.process(element.clauses)) +class json_array_length(GenericFunction): + name = 'json_array_length' + type = sa.Integer + + +@compiles(json_array_length, 'postgresql') +def compile_json_array_length(element, compiler, **kw): + return "%s(%s)" % (element.name, compiler.process(element.clauses)) + + +class array_agg(GenericFunction): + name = 'array_agg' + type = ARRAY + + def __init__(self, arg, **kw): + self.type = ARRAY(arg.type) + GenericFunction.__init__(self, arg, **kw) + + +@compiles(array_agg, 'postgresql') +def compile_json_array_length(element, compiler, **kw): + return "%s(%s)" % (element.name, compiler.process(element.clauses)) + + class Asterisk(ColumnElement): def __init__(self, selectable): self.selectable = selectable diff --git a/tests/test_expressions.py b/tests/test_expressions.py index d73ad02..b6de6b6 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -2,7 +2,7 @@ import sqlalchemy as sa from pytest import raises from sqlalchemy.dialects import postgresql -from sqlalchemy_utils import Asterisk, row_to_json +from sqlalchemy_utils import array_agg, Asterisk, row_to_json from sqlalchemy_utils.expressions import explain, explain_analyze from tests import TestCase @@ -112,16 +112,32 @@ class TestAsterisk(object): class TestRowToJson(object): def test_compiler_with_default_dialect(self): with raises(sa.exc.CompileError): - assert str(row_to_json(sa.text('article.*'))) == ( - 'row_to_json(article.*)' - ) + str(row_to_json(sa.text('article.*'))) def test_compiler_with_postgresql(self): assert str(row_to_json(sa.text('article.*')).compile( dialect=postgresql.dialect() - )) == ( - 'row_to_json(article.*)' - ) + )) == 'row_to_json(article.*)' def test_type(self): - assert row_to_json(sa.text('article.*')).type == postgresql.JSON + assert isinstance( + sa.func.row_to_json(sa.text('article.*')).type, + postgresql.JSON + ) + + +class TestArrayAgg(object): + def test_compiler_with_default_dialect(self): + with raises(sa.exc.CompileError): + str(array_agg(sa.text('u.name'))) + + def test_compiler_with_postgresql(self): + assert str(array_agg(sa.text('u.name')).compile( + dialect=postgresql.dialect() + )) == "array_agg(u.name)" + + def test_type(self): + assert isinstance( + sa.func.array_agg(sa.text('u.name')).type, + postgresql.ARRAY + )