144 lines
3.6 KiB
Python
144 lines
3.6 KiB
Python
import sqlalchemy as sa
|
|
from sqlalchemy.dialects.postgresql import array, ARRAY, JSON
|
|
from sqlalchemy.ext.compiler import compiles
|
|
from sqlalchemy.sql.expression import (
|
|
_literal_as_text,
|
|
ClauseElement,
|
|
ColumnElement,
|
|
Executable,
|
|
FunctionElement
|
|
)
|
|
from sqlalchemy.sql.functions import GenericFunction
|
|
|
|
from sqlalchemy_utils.functions.orm import quote
|
|
|
|
|
|
class explain(Executable, ClauseElement):
    """
    Executable clause element describing an EXPLAIN statement.

    The keyword arguments correspond to the options of the EXPLAIN command:

    http://www.postgresql.org/docs/devel/static/sql-explain.html

    :param stmt: statement to explain; it is passed through
        ``_literal_as_text`` and stored as ``statement``
    :param analyze: stored as ``analyze``, defaults to ``False``
    :param verbose: stored as ``verbose``, defaults to ``False``
    :param costs: stored as ``costs``, defaults to ``True``
    :param buffers: stored as ``buffers``, defaults to ``False``
    :param timing: stored as ``timing``, defaults to ``True``
    :param format: stored as ``format``, defaults to ``'text'``
    """
    def __init__(self, stmt, analyze=False, verbose=False, costs=True,
                 buffers=False, timing=True, format='text'):
        # The options are kept as plain attributes; the dialect specific
        # compilation function decides how they are rendered.
        self.format = format
        self.timing = timing
        self.buffers = buffers
        self.costs = costs
        self.verbose = verbose
        self.analyze = analyze
        self.statement = _literal_as_text(stmt)
|
|
|
|
|
|
class explain_analyze(explain):
    """
    Variant of :class:`explain` that always passes ``analyze=True``.

    :param stmt: statement to explain
    :param kwargs: remaining keyword arguments, forwarded unchanged to
        ``explain.__init__``
    """
    def __init__(self, stmt, **kwargs):
        super(explain_analyze, self).__init__(stmt, analyze=True, **kwargs)
|
|
|
|
|
|
@compiles(explain, 'postgresql')
def pg_explain(element, compiler, **kw):
    """
    Render an ``explain`` element as a PostgreSQL EXPLAIN statement.

    An option is only emitted when its value differs from the default
    declared by ``explain.__init__``; emitted options are rendered as a
    parenthesised, comma separated list right after the EXPLAIN keyword.

    NOTE: ``element.format`` is interpolated into the SQL text as-is, so
    only trusted values should be passed as ``format``.

    :param element: the ``explain`` construct being compiled
    :param compiler: the SQL compiler performing the compilation
    :param kw: compilation keyword arguments; they are forwarded to the
        compilation of the explained statement
    :returns: the EXPLAIN statement as a string
    """
    options = []
    if element.analyze:
        options.append('ANALYZE true')
    if not element.timing:
        options.append('TIMING false')
    if element.buffers:
        options.append('BUFFERS true')
    if element.format != 'text':
        options.append('FORMAT %s' % element.format)
    if element.verbose:
        options.append('VERBOSE true')
    if not element.costs:
        options.append('COSTS false')

    text = "EXPLAIN "
    if options:
        text += '(%s) ' % ', '.join(options)
    # Forward **kw so that compile flags given by the caller (for example
    # ``literal_binds``) also apply to the statement being explained.
    text += compiler.process(element.statement, **kw)
    return text
|
|
|
|
|
|
class array_get(FunctionElement):
    """
    Function element for fetching a single item of an array expression.

    It is expected to be built with two arguments, the array expression and
    an integer index; see ``compile_array_get`` for the rendered SQL.
    """
    name = 'array_get'
|
|
|
|
|
|
@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """
    Render ``array_get`` as a subscript expression ``(<array>)[<index>]``.

    The rendered subscript is the given integer plus one, presumably so
    that callers can pass zero-based indexes while SQL arrays are
    one-based.

    :param element: the ``array_get`` construct being compiled
    :param compiler: the SQL compiler performing the compilation
    :returns: the subscript expression as a string
    :raises Exception: if the construct does not hold exactly two
        arguments, or if the second argument does not carry an integer
        ``value``
    """
    operands = list(element.clauses)
    operand_count = len(operands)
    if operand_count != 2:
        raise Exception(
            "Function 'array_get' expects two arguments (%d given)." %
            operand_count
        )

    # Arguments without a ``value`` attribute yield None here and are
    # rejected by the integer check below.
    index = getattr(operands[1], 'value', None)
    if not isinstance(index, int):
        raise Exception(
            "Second argument should be an integer."
        )

    array_sql = compiler.process(operands[0])
    return '(%s)[%s]' % (array_sql, index + 1)
|
|
|
|
|
|
class row_to_json(GenericFunction):
    """
    The ``row_to_json`` SQL function; its return type is declared as JSON.
    """
    type = JSON
    name = 'row_to_json'
|
|
|
|
|
|
@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """
    Render ``row_to_json`` as ``row_to_json(<arguments>)`` on PostgreSQL.

    :param element: the ``row_to_json`` construct being compiled
    :param compiler: the SQL compiler performing the compilation
    :param kw: compilation keyword arguments, forwarded to the compilation
        of the function arguments
    :returns: the function call as a string
    """
    # Forward **kw so that compile flags such as ``literal_binds`` also
    # apply to the function arguments.
    arguments = compiler.process(element.clauses, **kw)
    return "%s(%s)" % (element.name, arguments)
|
|
|
|
|
|
class json_array_length(GenericFunction):
    """
    The ``json_array_length`` SQL function; its return type is declared as
    an integer.
    """
    type = sa.Integer
    name = 'json_array_length'
|
|
|
|
|
|
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """
    Render ``json_array_length`` as ``json_array_length(<arguments>)`` on
    PostgreSQL.

    :param element: the ``json_array_length`` construct being compiled
    :param compiler: the SQL compiler performing the compilation
    :param kw: compilation keyword arguments, forwarded to the compilation
        of the function arguments
    :returns: the function call as a string
    """
    # Forward **kw so that compile flags such as ``literal_binds`` also
    # apply to the function arguments.
    arguments = compiler.process(element.clauses, **kw)
    return "%s(%s)" % (element.name, arguments)
|
|
|
|
|
|
class array_agg(GenericFunction):
    """
    The ``array_agg`` SQL function with an optional default value.

    :param arg: expression to aggregate; the function's type becomes an
        ARRAY of ``arg.type``
    :param default: stored as ``default``; when it is not ``None`` the
        PostgreSQL compilation wraps the aggregate in a COALESCE that falls
        back to this value
    :param kw: remaining keyword arguments, forwarded to
        ``GenericFunction.__init__``
    """
    name = 'array_agg'
    type = ARRAY

    def __init__(self, arg, default=None, **kw):
        self.default = default
        # The instance level type is assigned before the base initialiser
        # runs, exactly as in the class' original statement order.
        self.type = ARRAY(arg.type)
        GenericFunction.__init__(self, arg, **kw)
|
|
|
|
|
|
@compiles(array_agg, 'postgresql')
def compile_array_agg(element, compiler, **kw):
    """
    Render ``array_agg`` on PostgreSQL, optionally with a fallback value.

    Without a default the result is ``array_agg(<arguments>)``.  With a
    default the aggregate is wrapped as
    ``coalesce(array_agg(<arguments>), <default cast to the array type>)``.

    :param element: the ``array_agg`` construct being compiled
    :param compiler: the SQL compiler performing the compilation
    :param kw: compilation keyword arguments, forwarded to every nested
        compilation
    :returns: the SQL expression as a string
    """
    compiled = "%s(%s)" % (
        element.name,
        compiler.process(element.clauses, **kw)
    )
    if element.default is None:
        return compiled

    fallback = sa.cast(array(element.default), element.type)
    # The fallback is rendered with the active compiler (not through a
    # separate ``.compile()`` call) so that bind parameters of the default
    # are registered with the enclosing statement.  The already compiled
    # aggregate is inserted as a plain string instead of being re-parsed
    # through ``sa.text()``.
    return "coalesce(%s, %s)" % (
        compiled,
        compiler.process(fallback, **kw)
    )
|
|
|
|
|
|
class Asterisk(ColumnElement):
    """
    Column element that renders as ``<selectable name>.*``.

    :param selectable: object whose ``name`` is used when the element is
        compiled; stored as ``selectable``
    """
    def __init__(self, selectable):
        self.selectable = selectable
|
|
|
|
|
|
@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """
    Render an ``Asterisk`` element as ``<quoted selectable name>.*``.

    :param element: the ``Asterisk`` construct being compiled
    :param compiler: the SQL compiler; its ``dialect`` is handed to
        ``quote`` together with the selectable's name
    :returns: the rendered expression as a string
    """
    dialect = compiler.dialect
    selectable_name = element.selectable.name
    return '%s.*' % quote(dialect, selectable_name)
|