From 1a76f9868d2dbc7fbd4d15e02d3a958f96feb512 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 23 Oct 2013 00:25:40 -0700 Subject: [PATCH] Make render_expression use the stack like mock_engine. --- sqlalchemy_utils/functions/__init__.py | 31 +++++++++++++++++++++----- tests/test_utility_functions.py | 7 +++--- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index adfb50d..4370e7d 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -340,7 +340,7 @@ def mock_engine(engine, stream=None): six.exec_('del __mock', frame.f_globals, frame.f_locals) -def render_expression(expression, bind, context=None): +def render_expression(expression, bind, stream=None): """Generate a SQL expression from the passed python expression. Only the global variable, `engine`, is available for use in the @@ -351,15 +351,36 @@ def render_expression(expression, bind, context=None): blindly pass user input to this function as it uses exec. :param bind: A SQLAlchemy engine or bind URL. - :param context: Dictionary of local variables for the expression. + :param stream: Render all DDL operations to the stream. """ - stream = cStringIO() + # Create a stream if not present. + + if stream is None: + stream = cStringIO() + engine = create_mock_engine(bind, stream) - six.exec_(expression, {'engine': engine}, context) + # Navigate the stack and find the calling frame that allows the + # expression to execuate. - return stream.getvalue() + for frame in inspect.stack()[1:]: + + try: + frame = frame[0] + local = dict(frame.f_locals) + local['engine'] = engine + six.exec_(expression, frame.f_globals, local) + break + + except: + pass + + else: + + raise ValueError('Not a valid python expression', engine) + + return stream def render_statement(statement, bind=None): diff --git a/tests/test_utility_functions.py b/tests/test_utility_functions.py index b0242b1..b28b046 100644 --- a/tests/test_utility_functions.py +++ b/tests/test_utility_functions.py @@ -102,9 +102,10 @@ class TestRender(TestCase): assert 'WHERE user.id = 3' in text def test_render_ddl(self): - context = {'table': self.User.__table__} - expression = 'table.create(engine)' - text = render_expression(expression, self.engine, context) + expression = 'self.User.__table__.create(engine)' + stream = render_expression(expression, self.engine) + + text = stream.getvalue() assert 'CREATE TABLE user' in text assert 'PRIMARY KEY' in text