From 823ad79d9b3c494bb65a93afbdf48b80e0506a1d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 7 May 2014 15:19:28 +0300 Subject: [PATCH] Make expression parser support instrumented attrs --- sqlalchemy_utils/expression_parser.py | 61 ++++++++++++++------------- tests/test_expression_parser.py | 5 ++- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sqlalchemy_utils/expression_parser.py b/sqlalchemy_utils/expression_parser.py index 95cce70..5bac9d3 100644 --- a/sqlalchemy_utils/expression_parser.py +++ b/sqlalchemy_utils/expression_parser.py @@ -1,5 +1,11 @@ +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + import six import sqlalchemy as sa +from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql.annotation import AnnotatedColumn from sqlalchemy.sql.expression import ( BooleanClauseList, @@ -21,43 +27,38 @@ from sqlalchemy.sql.elements import ( class ExpressionParser(object): + parsers = OrderedDict(( + (BinaryExpression, 'binary_expression'), + (BooleanClauseList, 'boolean_expression'), + (UnaryExpression, 'unary_expression'), + (sa.Column, 'column'), + (AnnotatedColumn, 'column'), + (BindParameter, 'bind_parameter'), + (False_, 'false'), + (True_, 'true'), + (Grouping, 'grouping'), + (ClauseList, 'clause_list'), + (Label, 'label'), + (Cast, 'cast'), + (Case, 'case'), + (Tuple, 'tuple'), + (Null, 'null'), + (InstrumentedAttribute, 'instrumented_attribute') + )) + def expression(self, expr): if expr is None: return - if isinstance(expr, BinaryExpression): - return self.binary_expression(expr) - elif isinstance(expr, BooleanClauseList): - return self.boolean_expression(expr) - elif isinstance(expr, UnaryExpression): - return self.unary_expression(expr) - elif isinstance(expr, sa.Column): - return self.column(expr) - elif isinstance(expr, AnnotatedColumn): - return self.column(expr) - elif isinstance(expr, BindParameter): - return self.bind_parameter(expr) - elif isinstance(expr, False_): - return self.false(expr) - elif isinstance(expr, True_): - return self.true(expr) - elif isinstance(expr, Grouping): - return self.grouping(expr) - elif isinstance(expr, ClauseList): - return self.clause_list(expr) - elif isinstance(expr, Label): - return self.label(expr) - elif isinstance(expr, Cast): - return self.cast(expr) - elif isinstance(expr, Case): - return self.case(expr) - elif isinstance(expr, Tuple): - return self.tuple(expr) - elif isinstance(expr, Null): - return self.null(expr) + for class_, parser in self.parsers.items(): + if isinstance(expr, class_): + return getattr(self, parser)(expr) raise Exception( 'Unknown expression type %s' % expr.__class__.__name__ ) + def instrumented_attribute(self, expr): + return expr + def null(self, expr): return expr diff --git a/tests/test_expression_parser.py b/tests/test_expression_parser.py index 6510818..672de92 100644 --- a/tests/test_expression_parser.py +++ b/tests/test_expression_parser.py @@ -12,6 +12,9 @@ class MyExpressionParser(ExpressionParser): def column(self, column): return getattr(self.parent, column.key) + def instrumented_attribute(self, column): + return getattr(self.parent, column.key) + class TestExpressionParser(TestCase): create_tables = False @@ -82,4 +85,4 @@ class TestExpressionParser(TestCase): def test_instrumented_attribute(self): expr = self.parser(self.User.name) - assert str(expr) == 'category.name' + assert str(expr) == 'Category.name'