From 55c65beb4bd4ca08b4ff3776c574be7317b4845d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 8 Jun 2015 09:42:20 +0300 Subject: [PATCH] Add default param for array_agg --- CHANGES.rst | 3 ++- sqlalchemy_utils/__init__.py | 2 +- sqlalchemy_utils/expressions.py | 13 ++++++++++--- tests/functions/test_get_type.py | 1 - tests/test_expressions.py | 20 +++++++++++++++++--- 5 files changed, 30 insertions(+), 9 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index f3a4d15..b8e27f0 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,11 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. -0.30.9 (2015-06-05) +0.30.9 (2015-06-08) ^^^^^^^^^^^^^^^^^^^ - Added get_type utility function +- Added default parameter for array_agg function 0.30.8 (2015-06-05) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 8acf281..aaa425a 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -8,7 +8,7 @@ from .asserts import ( # noqa ) from .exceptions import ImproperlyConfigured # noqa from .expression_parser import ExpressionParser # noqa -from .expressions import array_agg, Asterisk, row_to_json # noqa +from .expressions import Asterisk, row_to_json # noqa from .functions import ( # noqa analyze, create_database, diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py index d92f081..4d2c112 100644 --- a/sqlalchemy_utils/expressions.py +++ b/sqlalchemy_utils/expressions.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import ARRAY, JSON +from sqlalchemy.dialects.postgresql import array, ARRAY, JSON from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import ( _literal_as_text, @@ -116,14 +116,21 @@ class array_agg(GenericFunction): name = 'array_agg' type = ARRAY - def __init__(self, arg, **kw): + def __init__(self, arg, default=None, **kw): self.type = ARRAY(arg.type) + self.default = default GenericFunction.__init__(self, arg, **kw) @compiles(array_agg, 'postgresql') def compile_array_agg(element, compiler, **kw): - return "%s(%s)" % (element.name, compiler.process(element.clauses)) + compiled = "%s(%s)" % (element.name, compiler.process(element.clauses)) + if element.default is None: + return compiled + return str(sa.func.coalesce( + sa.text(compiled), + sa.cast(array(element.default), element.type) + ).compile(compiler)) class Asterisk(ColumnElement): diff --git a/tests/functions/test_get_type.py b/tests/functions/test_get_type.py index 0287325..8990be4 100644 --- a/tests/functions/test_get_type.py +++ b/tests/functions/test_get_type.py @@ -2,7 +2,6 @@ import sqlalchemy as sa from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_type -from tests import TestCase class TestGetType(object): diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 8945820..1bcd223 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -2,7 +2,7 @@ import sqlalchemy as sa from pytest import raises from sqlalchemy.dialects import postgresql -from sqlalchemy_utils import array_agg, Asterisk, row_to_json +from sqlalchemy_utils import Asterisk, row_to_json from sqlalchemy_utils.expressions import explain, explain_analyze from tests import TestCase @@ -129,10 +129,10 @@ class TestRowToJson(object): class TestArrayAgg(object): def test_compiler_with_default_dialect(self): with raises(sa.exc.CompileError): - str(array_agg(sa.text('u.name'))) + str(sa.func.array_agg(sa.text('u.name'))) def test_compiler_with_postgresql(self): - assert str(array_agg(sa.text('u.name')).compile( + assert str(sa.func.array_agg(sa.text('u.name')).compile( dialect=postgresql.dialect() )) == "array_agg(u.name)" @@ -141,3 +141,17 @@ class TestArrayAgg(object): sa.func.array_agg(sa.text('u.name')).type, postgresql.ARRAY ) + + def test_array_agg_with_default(self): + Base = sa.ext.declarative.declarative_base() + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + assert str(sa.func.array_agg(Article.id, [1]).compile( + dialect=postgresql.dialect() + )) == ( + 'coalesce(array_agg(article.id), CAST(ARRAY[%(param_1)s]' + ' AS INTEGER[]))' + )