Add default param for array_agg
This commit is contained in:
@@ -4,10 +4,11 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.30.9 (2015-06-05)
|
||||
0.30.9 (2015-06-08)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added get_type utility function
|
||||
- Added default parameter for array_agg function
|
||||
|
||||
|
||||
0.30.8 (2015-06-05)
|
||||
|
@@ -8,7 +8,7 @@ from .asserts import ( # noqa
|
||||
)
|
||||
from .exceptions import ImproperlyConfigured # noqa
|
||||
from .expression_parser import ExpressionParser # noqa
|
||||
from .expressions import array_agg, Asterisk, row_to_json # noqa
|
||||
from .expressions import Asterisk, row_to_json # noqa
|
||||
from .functions import ( # noqa
|
||||
analyze,
|
||||
create_database,
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSON
|
||||
from sqlalchemy.dialects.postgresql import array, ARRAY, JSON
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.expression import (
|
||||
_literal_as_text,
|
||||
@@ -116,14 +116,21 @@ class array_agg(GenericFunction):
|
||||
name = 'array_agg'
|
||||
type = ARRAY
|
||||
|
||||
def __init__(self, arg, **kw):
|
||||
def __init__(self, arg, default=None, **kw):
|
||||
self.type = ARRAY(arg.type)
|
||||
self.default = default
|
||||
GenericFunction.__init__(self, arg, **kw)
|
||||
|
||||
|
||||
@compiles(array_agg, 'postgresql')
|
||||
def compile_array_agg(element, compiler, **kw):
|
||||
return "%s(%s)" % (element.name, compiler.process(element.clauses))
|
||||
compiled = "%s(%s)" % (element.name, compiler.process(element.clauses))
|
||||
if element.default is None:
|
||||
return compiled
|
||||
return str(sa.func.coalesce(
|
||||
sa.text(compiled),
|
||||
sa.cast(array(element.default), element.type)
|
||||
).compile(compiler))
|
||||
|
||||
|
||||
class Asterisk(ColumnElement):
|
||||
|
@@ -2,7 +2,6 @@ import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_type
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetType(object):
|
||||
|
@@ -2,7 +2,7 @@ import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from sqlalchemy_utils import array_agg, Asterisk, row_to_json
|
||||
from sqlalchemy_utils import Asterisk, row_to_json
|
||||
from sqlalchemy_utils.expressions import explain, explain_analyze
|
||||
from tests import TestCase
|
||||
|
||||
@@ -129,10 +129,10 @@ class TestRowToJson(object):
|
||||
class TestArrayAgg(object):
|
||||
def test_compiler_with_default_dialect(self):
|
||||
with raises(sa.exc.CompileError):
|
||||
str(array_agg(sa.text('u.name')))
|
||||
str(sa.func.array_agg(sa.text('u.name')))
|
||||
|
||||
def test_compiler_with_postgresql(self):
|
||||
assert str(array_agg(sa.text('u.name')).compile(
|
||||
assert str(sa.func.array_agg(sa.text('u.name')).compile(
|
||||
dialect=postgresql.dialect()
|
||||
)) == "array_agg(u.name)"
|
||||
|
||||
@@ -141,3 +141,17 @@ class TestArrayAgg(object):
|
||||
sa.func.array_agg(sa.text('u.name')).type,
|
||||
postgresql.ARRAY
|
||||
)
|
||||
|
||||
def test_array_agg_with_default(self):
|
||||
Base = sa.ext.declarative.declarative_base()
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
assert str(sa.func.array_agg(Article.id, [1]).compile(
|
||||
dialect=postgresql.dialect()
|
||||
)) == (
|
||||
'coalesce(array_agg(article.id), CAST(ARRAY[%(param_1)s]'
|
||||
' AS INTEGER[]))'
|
||||
)
|
||||
|
Reference in New Issue
Block a user