Add array_agg GenericFunction

This commit is contained in:
Konsta Vesterinen
2015-06-05 13:40:09 +03:00
parent 4852ffd80e
commit 89852542f2
4 changed files with 54 additions and 12 deletions

View File

@@ -8,7 +8,8 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release.
^^^^^^^^^^^^^^^^^^^
- Added Asterisk compiler
- Added row_to_json FunctionElement
- Added row_to_json GenericFunction
- Added array_agg GenericFunction
- Made quote function accept dialect object as the first paremeter
- Made has_index work with tables without primary keys (#148)

View File

@@ -8,7 +8,7 @@ from .asserts import ( # noqa
)
from .exceptions import ImproperlyConfigured # noqa
from .expression_parser import ExpressionParser # noqa
from .expressions import Asterisk, row_to_json # noqa
from .expressions import array_agg, Asterisk, row_to_json # noqa
from .functions import ( # noqa
analyze,
create_database,

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.dialects.postgresql import ARRAY, JSON
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import (
@@ -9,6 +9,7 @@ from sqlalchemy.sql.expression import (
Executable,
FunctionElement
)
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy_utils.functions.orm import quote
@@ -92,7 +93,7 @@ def compile_array_get(element, compiler, **kw):
)
class row_to_json(FunctionElement):
class row_to_json(GenericFunction):
name = 'row_to_json'
type = JSON
@@ -102,6 +103,30 @@ def compile_row_to_json(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses))
class json_array_length(GenericFunction):
name = 'json_array_length'
type = sa.Integer
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses))
class array_agg(GenericFunction):
name = 'array_agg'
type = ARRAY
def __init__(self, arg, **kw):
self.type = ARRAY(arg.type)
GenericFunction.__init__(self, arg, **kw)
@compiles(array_agg, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
return "%s(%s)" % (element.name, compiler.process(element.clauses))
class Asterisk(ColumnElement):
def __init__(self, selectable):
self.selectable = selectable

View File

@@ -2,7 +2,7 @@ import sqlalchemy as sa
from pytest import raises
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import Asterisk, row_to_json
from sqlalchemy_utils import array_agg, Asterisk, row_to_json
from sqlalchemy_utils.expressions import explain, explain_analyze
from tests import TestCase
@@ -112,16 +112,32 @@ class TestAsterisk(object):
class TestRowToJson(object):
def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError):
assert str(row_to_json(sa.text('article.*'))) == (
'row_to_json(article.*)'
)
str(row_to_json(sa.text('article.*')))
def test_compiler_with_postgresql(self):
assert str(row_to_json(sa.text('article.*')).compile(
dialect=postgresql.dialect()
)) == (
'row_to_json(article.*)'
)
)) == 'row_to_json(article.*)'
def test_type(self):
assert row_to_json(sa.text('article.*')).type == postgresql.JSON
assert isinstance(
sa.func.row_to_json(sa.text('article.*')).type,
postgresql.JSON
)
class TestArrayAgg(object):
def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError):
str(array_agg(sa.text('u.name')))
def test_compiler_with_postgresql(self):
assert str(array_agg(sa.text('u.name')).compile(
dialect=postgresql.dialect()
)) == "array_agg(u.name)"
def test_type(self):
assert isinstance(
sa.func.array_agg(sa.text('u.name')).type,
postgresql.ARRAY
)