Add array_agg GenericFunction

This commit is contained in:
Konsta Vesterinen
2015-06-05 13:40:09 +03:00
parent 4852ffd80e
commit 89852542f2
4 changed files with 54 additions and 12 deletions

View File

@@ -8,7 +8,8 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release.
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
- Added Asterisk compiler - Added Asterisk compiler
- Added row_to_json FunctionElement - Added row_to_json GenericFunction
- Added array_agg GenericFunction
- Made quote function accept dialect object as the first paremeter - Made quote function accept dialect object as the first paremeter
- Made has_index work with tables without primary keys (#148) - Made has_index work with tables without primary keys (#148)

View File

@@ -8,7 +8,7 @@ from .asserts import ( # noqa
) )
from .exceptions import ImproperlyConfigured # noqa from .exceptions import ImproperlyConfigured # noqa
from .expression_parser import ExpressionParser # noqa from .expression_parser import ExpressionParser # noqa
from .expressions import Asterisk, row_to_json # noqa from .expressions import array_agg, Asterisk, row_to_json # noqa
from .functions import ( # noqa from .functions import ( # noqa
analyze, analyze,
create_database, create_database,

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.dialects.postgresql import ARRAY, JSON
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import ( from sqlalchemy.sql.expression import (
@@ -9,6 +9,7 @@ from sqlalchemy.sql.expression import (
Executable, Executable,
FunctionElement FunctionElement
) )
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy_utils.functions.orm import quote from sqlalchemy_utils.functions.orm import quote
@@ -92,7 +93,7 @@ def compile_array_get(element, compiler, **kw):
) )
class row_to_json(FunctionElement): class row_to_json(GenericFunction):
name = 'row_to_json' name = 'row_to_json'
type = JSON type = JSON
@@ -102,6 +103,30 @@ def compile_row_to_json(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses)) return "%s(%s)" % (element.name, compiler.process(element.clauses))
class json_array_length(GenericFunction):
name = 'json_array_length'
type = sa.Integer
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses))
class array_agg(GenericFunction):
name = 'array_agg'
type = ARRAY
def __init__(self, arg, **kw):
self.type = ARRAY(arg.type)
GenericFunction.__init__(self, arg, **kw)
@compiles(array_agg, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses))
class Asterisk(ColumnElement): class Asterisk(ColumnElement):
def __init__(self, selectable): def __init__(self, selectable):
self.selectable = selectable self.selectable = selectable

View File

@@ -2,7 +2,7 @@ import sqlalchemy as sa
from pytest import raises from pytest import raises
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import Asterisk, row_to_json from sqlalchemy_utils import array_agg, Asterisk, row_to_json
from sqlalchemy_utils.expressions import explain, explain_analyze from sqlalchemy_utils.expressions import explain, explain_analyze
from tests import TestCase from tests import TestCase
@@ -112,16 +112,32 @@ class TestAsterisk(object):
class TestRowToJson(object): class TestRowToJson(object):
def test_compiler_with_default_dialect(self): def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError): with raises(sa.exc.CompileError):
assert str(row_to_json(sa.text('article.*'))) == ( str(row_to_json(sa.text('article.*')))
'row_to_json(article.*)'
)
def test_compiler_with_postgresql(self): def test_compiler_with_postgresql(self):
assert str(row_to_json(sa.text('article.*')).compile( assert str(row_to_json(sa.text('article.*')).compile(
dialect=postgresql.dialect() dialect=postgresql.dialect()
)) == ( )) == 'row_to_json(article.*)'
'row_to_json(article.*)'
)
def test_type(self): def test_type(self):
assert row_to_json(sa.text('article.*')).type == postgresql.JSON assert isinstance(
sa.func.row_to_json(sa.text('article.*')).type,
postgresql.JSON
)
class TestArrayAgg(object):
def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError):
str(array_agg(sa.text('u.name')))
def test_compiler_with_postgresql(self):
assert str(array_agg(sa.text('u.name')).compile(
dialect=postgresql.dialect()
)) == "array_agg(u.name)"
def test_type(self):
assert isinstance(
sa.func.array_agg(sa.text('u.name')).type,
postgresql.ARRAY
)