Add default param for array_agg

This commit is contained in:
Konsta Vesterinen
2015-06-08 09:42:20 +03:00
parent c0bd2c2bca
commit 55c65beb4b
5 changed files with 30 additions and 9 deletions

View File

@@ -4,10 +4,11 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.30.9 (2015-06-05)
0.30.9 (2015-06-08)
^^^^^^^^^^^^^^^^^^^
- Added get_type utility function
- Added default parameter for array_agg function
0.30.8 (2015-06-05)

View File

@@ -8,7 +8,7 @@ from .asserts import ( # noqa
)
from .exceptions import ImproperlyConfigured # noqa
from .expression_parser import ExpressionParser # noqa
from .expressions import array_agg, Asterisk, row_to_json # noqa
from .expressions import Asterisk, row_to_json # noqa
from .functions import ( # noqa
analyze,
create_database,

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY, JSON
from sqlalchemy.dialects.postgresql import array, ARRAY, JSON
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import (
_literal_as_text,
@@ -116,14 +116,21 @@ class array_agg(GenericFunction):
name = 'array_agg'
type = ARRAY
def __init__(self, arg, **kw):
def __init__(self, arg, default=None, **kw):
self.type = ARRAY(arg.type)
self.default = default
GenericFunction.__init__(self, arg, **kw)
@compiles(array_agg, 'postgresql')
def compile_array_agg(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses))
compiled = "%s(%s)" % (element.name, compiler.process(element.clauses))
if element.default is None:
return compiled
return str(sa.func.coalesce(
sa.text(compiled),
sa.cast(array(element.default), element.type)
).compile(compiler))
class Asterisk(ColumnElement):

View File

@@ -2,7 +2,6 @@ import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_type
from tests import TestCase
class TestGetType(object):

View File

@@ -2,7 +2,7 @@ import sqlalchemy as sa
from pytest import raises
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import array_agg, Asterisk, row_to_json
from sqlalchemy_utils import Asterisk, row_to_json
from sqlalchemy_utils.expressions import explain, explain_analyze
from tests import TestCase
@@ -129,10 +129,10 @@ class TestRowToJson(object):
class TestArrayAgg(object):
def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError):
str(array_agg(sa.text('u.name')))
str(sa.func.array_agg(sa.text('u.name')))
def test_compiler_with_postgresql(self):
assert str(array_agg(sa.text('u.name')).compile(
assert str(sa.func.array_agg(sa.text('u.name')).compile(
dialect=postgresql.dialect()
)) == "array_agg(u.name)"
@@ -141,3 +141,17 @@ class TestArrayAgg(object):
sa.func.array_agg(sa.text('u.name')).type,
postgresql.ARRAY
)
def test_array_agg_with_default(self):
Base = sa.ext.declarative.declarative_base()
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
assert str(sa.func.array_agg(Article.id, [1]).compile(
dialect=postgresql.dialect()
)) == (
'coalesce(array_agg(article.id), CAST(ARRAY[%(param_1)s]'
' AS INTEGER[]))'
)