Add expression parser

This commit is contained in:
Konsta Vesterinen
2014-04-22 12:59:01 +03:00
parent d541a77dc1
commit 5e42d32a84
3 changed files with 224 additions and 0 deletions

View File

@@ -2,6 +2,7 @@ from .aggregates import aggregated
from .batch import batch_fetch, with_backrefs
from .decorators import generates
from .exceptions import ImproperlyConfigured
from .expression_parser import ExpressionParser
from .functions import (
create_database,
create_mock_engine,

View File

@@ -0,0 +1,144 @@
import six
import sqlalchemy as sa
from sqlalchemy.sql.annotation import AnnotatedColumn
from sqlalchemy.sql.expression import (
BooleanClauseList,
BinaryExpression,
UnaryExpression,
BindParameter,
Cast,
)
from sqlalchemy.sql.elements import (
False_,
True_,
Grouping,
ClauseList,
Label,
Case,
Tuple,
Null
)
class ExpressionParser(object):
def expression(self, expr):
if expr is None:
return
if isinstance(expr, BinaryExpression):
return self.binary_expression(expr)
elif isinstance(expr, BooleanClauseList):
return self.boolean_expression(expr)
elif isinstance(expr, UnaryExpression):
return self.unary_expression(expr)
elif isinstance(expr, sa.Column):
return self.column(expr)
elif isinstance(expr, AnnotatedColumn):
return self.column(expr)
elif isinstance(expr, BindParameter):
return self.bind_parameter(expr)
elif isinstance(expr, False_):
return self.false(expr)
elif isinstance(expr, True_):
return self.true(expr)
elif isinstance(expr, Grouping):
return self.grouping(expr)
elif isinstance(expr, ClauseList):
return self.clause_list(expr)
elif isinstance(expr, Label):
return self.label(expr)
elif isinstance(expr, Cast):
return self.cast(expr)
elif isinstance(expr, Case):
return self.case(expr)
elif isinstance(expr, Tuple):
return self.tuple(expr)
elif isinstance(expr, Null):
return self.null(expr)
raise Exception(
'Unknown expression type %s' % expr.__class__.__name__
)
def null(self, expr):
return expr
def tuple(self, expr):
return expr.__class__(
*map(self.expression, expr.clauses),
type_=expr.type
)
def clause_list(self, expr):
return expr.__class__(
*map(self.expression, expr.clauses),
group=expr.group,
group_contents=expr.group_contents,
operator=expr.operator
)
def label(self, expr):
return expr.__class__(
name=expr.name,
element=self.expression(expr._element),
type_=expr.type
)
def cast(self, expr):
return expr.__class__(
expression=self.expression(expr.clause),
type_=expr.type
)
def case(self, expr):
return expr.__class__(
whens=[
tuple(self.expression(x) for x in when) for when in expr.whens
],
value=self.expression(expr.value),
else_=self.expression(expr.else_)
)
def grouping(self, expr):
return expr.__class__(self.expression(expr.element))
def true(self, expr):
return expr
def false(self, expr):
return expr
def process_table(self, table):
return table
def column(self, column):
table = self.process_table(column.table)
return table.c[column.name]
def unary_expression(self, expr):
return expr.operator(self.expression(expr.element))
def bind_parameter(self, expr):
# somehow bind parameters passed as unicode are converted to
# ascii strings along the way, force convert them back to avoid
# sqlalchemy unicode warnings
if isinstance(expr.type, sa.Unicode):
expr.value = six.text_type(expr.value)
return expr
def binary_expression(self, expr):
return expr.__class__(
left=self.expression(expr.left),
right=self.expression(expr.right),
operator=expr.operator,
type_=expr.type,
negate=expr.negate,
modifiers=expr.modifiers.copy()
)
def boolean_expression(self, expr):
return expr.operator(*[
self.expression(child_expr)
for child_expr in expr.get_children()
])
def __call__(self, expr):
return self.expression(expr)

View File

@@ -0,0 +1,79 @@
from sqlalchemy_utils import ExpressionParser
import sqlalchemy as sa
from sqlalchemy.sql.elements import Cast, Null
from . import TestCase
class MyExpressionParser(ExpressionParser):
def __init__(self, some_class):
self.parent = some_class
def column(self, column):
return getattr(self.parent, column.key)
class TestExpressionParser(TestCase):
def setup_method(self, method):
TestCase.setup_method(self, method)
self.parser = MyExpressionParser(self.Category)
def test_false_expression(self):
expr = self.parser(self.User.name.isnot(False))
assert str(expr) == 'category.name IS NOT 0'
def test_true_expression(self):
expr = self.parser(self.User.name.isnot(True))
assert str(expr) == 'category.name IS NOT 1'
def test_unary_expression(self):
expr = self.parser(~ self.User.name)
assert str(expr) == 'NOT category.name'
def test_in_expression(self):
expr = self.parser(self.User.name.in_([2, 3]))
assert str(expr) == 'category.name IN (:name_1, :name_2)'
def test_boolean_expression(self):
expr = self.parser(self.User.name == False)
assert str(expr) == 'category.name = 0'
def test_label(self):
expr = self.parser(self.User.name.label('some_name'))
assert str(expr) == 'category.name'
def test_like(self):
expr = self.parser(self.User.name.like(u'something'))
assert str(expr) == 'category.name LIKE :name_1'
def test_cast(self):
expr = self.parser(Cast(self.User.name, sa.UnicodeText))
assert str(expr) == 'CAST(category.name AS TEXT)'
def test_case(self):
expr = self.parser(
sa.case(
[
(self.User.name == 'wendy', 'W'),
(self.User.name == 'jack', 'J')
],
else_='E'
)
)
assert str(expr) == (
'CASE WHEN (category.name = :name_1) '
'THEN :param_1 WHEN (category.name = :name_2) '
'THEN :param_2 ELSE :param_3 END'
)
def test_tuple(self):
expr = self.parser(
sa.tuple_(self.User.name, 3).in_([(u'someone', 3)])
)
assert str(expr) == (
'(category.name, :param_1) IN ((:param_2, :param_3))'
)
def test_null(self):
expr = self.parser(self.User.name == Null())
assert str(expr) == 'category.name IS NULL'