From 5e42d32a8492a52e6bb9e77a03e8f8e0c343fbde Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 22 Apr 2014 12:59:01 +0300 Subject: [PATCH] Add expression parser --- sqlalchemy_utils/__init__.py | 1 + sqlalchemy_utils/expression_parser.py | 144 ++++++++++++++++++++++++++ tests/test_expression_parser.py | 79 ++++++++++++++ 3 files changed, 224 insertions(+) create mode 100644 sqlalchemy_utils/expression_parser.py create mode 100644 tests/test_expression_parser.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index ea20e17..d302a87 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -2,6 +2,7 @@ from .aggregates import aggregated from .batch import batch_fetch, with_backrefs from .decorators import generates from .exceptions import ImproperlyConfigured +from .expression_parser import ExpressionParser from .functions import ( create_database, create_mock_engine, diff --git a/sqlalchemy_utils/expression_parser.py b/sqlalchemy_utils/expression_parser.py new file mode 100644 index 0000000..95cce70 --- /dev/null +++ b/sqlalchemy_utils/expression_parser.py @@ -0,0 +1,144 @@ +import six +import sqlalchemy as sa +from sqlalchemy.sql.annotation import AnnotatedColumn +from sqlalchemy.sql.expression import ( + BooleanClauseList, + BinaryExpression, + UnaryExpression, + BindParameter, + Cast, +) +from sqlalchemy.sql.elements import ( + False_, + True_, + Grouping, + ClauseList, + Label, + Case, + Tuple, + Null +) + + +class ExpressionParser(object): + def expression(self, expr): + if expr is None: + return + if isinstance(expr, BinaryExpression): + return self.binary_expression(expr) + elif isinstance(expr, BooleanClauseList): + return self.boolean_expression(expr) + elif isinstance(expr, UnaryExpression): + return self.unary_expression(expr) + elif isinstance(expr, sa.Column): + return self.column(expr) + elif isinstance(expr, AnnotatedColumn): + return self.column(expr) + elif isinstance(expr, BindParameter): + return self.bind_parameter(expr) + elif isinstance(expr, False_): + return self.false(expr) + elif isinstance(expr, True_): + return self.true(expr) + elif isinstance(expr, Grouping): + return self.grouping(expr) + elif isinstance(expr, ClauseList): + return self.clause_list(expr) + elif isinstance(expr, Label): + return self.label(expr) + elif isinstance(expr, Cast): + return self.cast(expr) + elif isinstance(expr, Case): + return self.case(expr) + elif isinstance(expr, Tuple): + return self.tuple(expr) + elif isinstance(expr, Null): + return self.null(expr) + raise Exception( + 'Unknown expression type %s' % expr.__class__.__name__ + ) + + def null(self, expr): + return expr + + def tuple(self, expr): + return expr.__class__( + *map(self.expression, expr.clauses), + type_=expr.type + ) + + def clause_list(self, expr): + return expr.__class__( + *map(self.expression, expr.clauses), + group=expr.group, + group_contents=expr.group_contents, + operator=expr.operator + ) + + def label(self, expr): + return expr.__class__( + name=expr.name, + element=self.expression(expr._element), + type_=expr.type + ) + + def cast(self, expr): + return expr.__class__( + expression=self.expression(expr.clause), + type_=expr.type + ) + + def case(self, expr): + return expr.__class__( + whens=[ + tuple(self.expression(x) for x in when) for when in expr.whens + ], + value=self.expression(expr.value), + else_=self.expression(expr.else_) + ) + + def grouping(self, expr): + return expr.__class__(self.expression(expr.element)) + + def true(self, expr): + return expr + + def false(self, expr): + return expr + + def process_table(self, table): + return table + + def column(self, column): + table = self.process_table(column.table) + return table.c[column.name] + + def unary_expression(self, expr): + return expr.operator(self.expression(expr.element)) + + def bind_parameter(self, expr): + # somehow bind parameters passed as unicode are converted to + # ascii strings along the way, force convert them back to avoid + # sqlalchemy unicode warnings + if isinstance(expr.type, sa.Unicode): + expr.value = six.text_type(expr.value) + return expr + + def binary_expression(self, expr): + return expr.__class__( + left=self.expression(expr.left), + right=self.expression(expr.right), + operator=expr.operator, + type_=expr.type, + negate=expr.negate, + modifiers=expr.modifiers.copy() + ) + + def boolean_expression(self, expr): + return expr.operator(*[ + self.expression(child_expr) + for child_expr in expr.get_children() + ]) + + def __call__(self, expr): + return self.expression(expr) diff --git a/tests/test_expression_parser.py b/tests/test_expression_parser.py new file mode 100644 index 0000000..dffe922 --- /dev/null +++ b/tests/test_expression_parser.py @@ -0,0 +1,79 @@ +from sqlalchemy_utils import ExpressionParser +import sqlalchemy as sa +from sqlalchemy.sql.elements import Cast, Null + +from . import TestCase + + +class MyExpressionParser(ExpressionParser): + def __init__(self, some_class): + self.parent = some_class + + def column(self, column): + return getattr(self.parent, column.key) + + +class TestExpressionParser(TestCase): + def setup_method(self, method): + TestCase.setup_method(self, method) + self.parser = MyExpressionParser(self.Category) + + def test_false_expression(self): + expr = self.parser(self.User.name.isnot(False)) + assert str(expr) == 'category.name IS NOT 0' + + def test_true_expression(self): + expr = self.parser(self.User.name.isnot(True)) + assert str(expr) == 'category.name IS NOT 1' + + def test_unary_expression(self): + expr = self.parser(~ self.User.name) + assert str(expr) == 'NOT category.name' + + def test_in_expression(self): + expr = self.parser(self.User.name.in_([2, 3])) + assert str(expr) == 'category.name IN (:name_1, :name_2)' + + def test_boolean_expression(self): + expr = self.parser(self.User.name == False) + assert str(expr) == 'category.name = 0' + + def test_label(self): + expr = self.parser(self.User.name.label('some_name')) + assert str(expr) == 'category.name' + + def test_like(self): + expr = self.parser(self.User.name.like(u'something')) + assert str(expr) == 'category.name LIKE :name_1' + + def test_cast(self): + expr = self.parser(Cast(self.User.name, sa.UnicodeText)) + assert str(expr) == 'CAST(category.name AS TEXT)' + + def test_case(self): + expr = self.parser( + sa.case( + [ + (self.User.name == 'wendy', 'W'), + (self.User.name == 'jack', 'J') + ], + else_='E' + ) + ) + assert str(expr) == ( + 'CASE WHEN (category.name = :name_1) ' + 'THEN :param_1 WHEN (category.name = :name_2) ' + 'THEN :param_2 ELSE :param_3 END' + ) + + def test_tuple(self): + expr = self.parser( + sa.tuple_(self.User.name, 3).in_([(u'someone', 3)]) + ) + assert str(expr) == ( + '(category.name, :param_1) IN ((:param_2, :param_3))' + ) + + def test_null(self): + expr = self.parser(self.User.name == Null()) + assert str(expr) == 'category.name IS NULL'