feat(CompiledRouter): Add an intermediate AST step to the compiler (#1040)

Rather than compiling the routing tree directly to Python code,
first generate an AST and then use it to produce the code. This
provides several benefits, including:

    * It makes the compilation process easier to reason about.
    * It makes it easier to keep track of indentation and whitespace.
    * It sets us up for being able to make transformations that we
      will need to do to support URI template filters, etc. in the
      future.
This commit is contained in:
Kurt Griffiths
2017-05-17 09:47:10 -06:00
committed by John Vrbanac
parent e0abfcfc64
commit f1597d77e6
3 changed files with 229 additions and 56 deletions

View File

@@ -18,7 +18,7 @@ import keyword
import re import re
_FIELD_REGEX = re.compile('{([^}]*)}') _FIELD_PATTERN = re.compile('{([^}]*)}')
_TAB_STR = ' ' * 4 _TAB_STR = ' ' * 4
@@ -35,14 +35,26 @@ class CompiledRouter(object):
processing quite fast. processing quite fast.
""" """
__slots__ = (
'_ast',
'_find',
'_finder_src',
'_patterns',
'_return_values',
'_roots',
)
def __init__(self): def __init__(self):
self._roots = [] self._roots = []
self._find = self._compile() self._find = self._compile()
self._code_lines = None self._finder_src = None
self._src = None self._patterns = None
self._expressions = None
self._return_values = None self._return_values = None
@property
def finder_src(self):
return self._finder_src
def add_route(self, uri_template, method_map, resource): def add_route(self, uri_template, method_map, resource):
"""Adds a route between a URI path template and a resource. """Adds a route between a URI path template and a resource.
@@ -66,7 +78,7 @@ class CompiledRouter(object):
# values from more shallow nodes. # values from more shallow nodes.
# 2. For complex nodes, re.compile() raises a nasty error # 2. For complex nodes, re.compile() raises a nasty error
# #
fields = _FIELD_REGEX.findall(uri_template) fields = _FIELD_PATTERN.findall(uri_template)
used_names = set() used_names = set()
for name in fields: for name in fields:
is_identifier = re.match('[A-Za-z_][A-Za-z0-9_]*$', name) is_identifier = re.match('[A-Za-z_][A-Za-z0-9_]*$', name)
@@ -145,28 +157,28 @@ class CompiledRouter(object):
path = uri.lstrip('/').split('/') path = uri.lstrip('/').split('/')
params = {} params = {}
node = self._find(path, self._return_values, self._expressions, params) node = self._find(path, self._return_values, self._patterns, params)
if node is not None: if node is not None:
return node.resource, node.method_map, params, node.uri_template return node.resource, node.method_map, params, node.uri_template
else: else:
return None return None
def _compile_tree(self, nodes, indent=1, level=0, fast_return=True): # -----------------------------------------------------------------
"""Generates Python code for a routing tree or subtree.""" # Private
# -----------------------------------------------------------------
def line(text, indent_offset=0): def _generate_ast(self, nodes, parent, return_values, patterns, level=0, fast_return=True):
pad = _TAB_STR * (indent + indent_offset) """Generates a coarse AST for the router."""
self._code_lines.append(pad + text)
# NOTE(kgriffs): Base case # NOTE(kgriffs): Base case
if not nodes: if not nodes:
return return
line('if path_len > %d:' % level) outer_parent = _CxIfPathLength('>', level)
indent += 1 parent.append(outer_parent)
parent = outer_parent
level_indent = indent
found_simple = False found_simple = False
# NOTE(kgriffs & philiptzou): Sort nodes in this sequence: # NOTE(kgriffs & philiptzou): Sort nodes in this sequence:
@@ -195,20 +207,18 @@ class CompiledRouter(object):
# contain anything more than a single literal or variable, # contain anything more than a single literal or variable,
# and they need to be checked using a pre-compiled regular # and they need to be checked using a pre-compiled regular
# expression. # expression.
expression_idx = len(self._expressions) pattern_idx = len(patterns)
self._expressions.append(node.var_regex) patterns.append(node.var_pattern)
line('match = expressions[%d].match(path[%d]) # %s' % ( construct = _CxIfPathSegmentPattern(level, pattern_idx,
expression_idx, level, node.var_regex.pattern)) node.var_pattern.pattern)
parent.append(construct)
line('if match is not None:') parent = construct
indent += 1
line('params.update(match.groupdict())')
else: else:
# NOTE(kgriffs): Simple nodes just capture the entire path # NOTE(kgriffs): Simple nodes just capture the entire path
# segment as the value for the param. # segment as the value for the param.
line('params["%s"] = path[%d]' % (node.var_name, level)) parent.append(_CxSetParam(node.var_name, level))
# NOTE(kgriffs): We don't allow multiple simple var nodes # NOTE(kgriffs): We don't allow multiple simple var nodes
# to exist at the same level, e.g.: # to exist at the same level, e.g.:
@@ -222,59 +232,78 @@ class CompiledRouter(object):
else: else:
# NOTE(kgriffs): Not a param, so must match exactly # NOTE(kgriffs): Not a param, so must match exactly
line('if path[%d] == "%s":' % (level, node.raw_segment)) construct = _CxIfPathSegmentLiteral(level, node.raw_segment)
indent += 1 parent.append(construct)
parent = construct
if node.resource is not None: if node.resource is not None:
# NOTE(kgriffs): This is a valid route, so we will want to # NOTE(kgriffs): This is a valid route, so we will want to
# return the relevant information. # return the relevant information.
resource_idx = len(self._return_values) resource_idx = len(return_values)
self._return_values.append(node) return_values.append(node)
self._compile_tree(node.children, indent, level + 1, fast_return) self._generate_ast(
node.children,
parent,
return_values,
patterns,
level + 1,
fast_return
)
if node.resource is None: if node.resource is None:
if fast_return: if fast_return:
line('return None') parent.append(_CxReturnNone())
else: else:
# NOTE(kgriffs): Make sure that we have consumed all of # NOTE(kgriffs): Make sure that we have consumed all of
# the segments for the requested route; otherwise we could # the segments for the requested route; otherwise we could
# mistakenly match "/foo/23/bar" against "/foo/{id}". # mistakenly match "/foo/23/bar" against "/foo/{id}".
line('if path_len == %d:' % (level + 1)) construct = _CxIfPathLength('==', level + 1)
line('return return_values[%d]' % resource_idx, 1) construct.append(_CxReturnValue(resource_idx))
parent.append(construct)
if fast_return: if fast_return:
line('return None') parent.append(_CxReturnNone())
indent = level_indent parent = outer_parent
if not found_simple and fast_return: if not found_simple and fast_return:
line('return None') parent.append(_CxReturnNone())
def _compile(self): def _compile(self):
"""Generates Python code for entire routing tree. """Generates Python code for the entire routing tree.
The generated code is compiled and the resulting Python method is The generated code is compiled and the resulting Python method
returned. is returned.
""" """
self._return_values = []
self._expressions = [] src_lines = [
self._code_lines = [ 'def find(path, return_values, patterns, params):',
'def find(path, return_values, expressions, params):',
_TAB_STR + 'path_len = len(path)', _TAB_STR + 'path_len = len(path)',
] ]
self._compile_tree(self._roots) self._return_values = []
self._patterns = []
self._code_lines.append( self._ast = _CxParent()
self._generate_ast(
self._roots,
self._ast,
self._return_values,
self._patterns
)
src_lines.append(self._ast.src(0))
src_lines.append(
# PERF(kgriffs): Explicit return of None is faster than implicit # PERF(kgriffs): Explicit return of None is faster than implicit
_TAB_STR + 'return None' _TAB_STR + 'return None'
) )
self._src = '\n'.join(self._code_lines) self._finder_src = '\n'.join(src_lines)
scope = {} scope = {}
exec(compile(self._src, '<string>', 'exec'), scope) exec(compile(self._finder_src, '<string>', 'exec'), scope)
return scope['find'] return scope['find']
@@ -294,11 +323,12 @@ class CompiledRouterNode(object):
self.is_var = False self.is_var = False
self.is_complex = False self.is_complex = False
self.var_name = None self.var_name = None
self.var_pattern = None
# NOTE(kgriffs): CompiledRouter.add_route validates field names, # NOTE(kgriffs): CompiledRouter.add_route validates field names,
# so here we can just assume they are OK and use the simple # so here we can just assume they are OK and use the simple
# _FIELD_REGEX to match them. # _FIELD_PATTERN to match them.
matches = list(_FIELD_REGEX.finditer(raw_segment)) matches = list(_FIELD_PATTERN.finditer(raw_segment))
if not matches: if not matches:
self.is_var = False self.is_var = False
@@ -334,11 +364,11 @@ class CompiledRouterNode(object):
# trick the parser into doing the right thing. # trick the parser into doing the right thing.
escaped_segment = re.sub(r'[\.\(\)\[\]\?\$\*\+\^\|]', r'\\\g<0>', raw_segment) escaped_segment = re.sub(r'[\.\(\)\[\]\?\$\*\+\^\|]', r'\\\g<0>', raw_segment)
seg_pattern = _FIELD_REGEX.sub(r'(?P<\1>.+)', escaped_segment) pattern_text = _FIELD_PATTERN.sub(r'(?P<\1>.+)', escaped_segment)
seg_pattern = '^' + seg_pattern + '$' pattern_text = '^' + pattern_text + '$'
self.is_complex = True self.is_complex = True
self.var_regex = re.compile(seg_pattern) self.var_pattern = re.compile(pattern_text)
def matches(self, segment): def matches(self, segment):
"""Returns True if this node matches the supplied template segment.""" """Returns True if this node matches the supplied template segment."""
@@ -360,7 +390,7 @@ class CompiledRouterNode(object):
# simple, complex ==> False # simple, complex ==> False
# simple, string ==> False # simple, string ==> False
# complex, simple ==> False # complex, simple ==> False
# complex, complex ==> (Depend) # complex, complex ==> (Maybe)
# complex, string ==> False # complex, string ==> False
# string, simple ==> False # string, simple ==> False
# string, complex ==> False # string, complex ==> False
@@ -389,8 +419,8 @@ class CompiledRouterNode(object):
# #
if self.is_complex: if self.is_complex:
if other.is_complex: if other.is_complex:
return (_FIELD_REGEX.sub('v', self.raw_segment) == return (_FIELD_PATTERN.sub('v', self.raw_segment) ==
_FIELD_REGEX.sub('v', segment)) _FIELD_PATTERN.sub('v', segment))
return False return False
else: else:
@@ -399,3 +429,125 @@ class CompiledRouterNode(object):
# NOTE(kgriffs): If self is a static string match, then all the cases # NOTE(kgriffs): If self is a static string match, then all the cases
# for other are False, so no need to check. # for other are False, so no need to check.
return False return False
# --------------------------------------------------------------------
# AST Constructs
#
# NOTE(kgriffs): These constructs are used to create a very coarse
# AST that can then be used to generate Python source code for the
# router. Using an AST like this makes it easier to reason about
# the compilation process, and affords syntactical transformations
# that would otherwise be at best confusing and at worst extremely
# tedious and error-prone if they were to be attempted directly
# against the Python source code.
# --------------------------------------------------------------------
class _CxParent(object):
def __init__(self):
self._children = []
def append(self, construct):
self._children.append(construct)
def src(self, indentation):
return self._children_src(indentation + 1)
def _children_src(self, indentation):
src_lines = [
child.src(indentation)
for child in self._children
]
return '\n'.join(src_lines)
class _CxIfPathLength(_CxParent):
def __init__(self, comparison, length):
super(_CxIfPathLength, self).__init__()
self._comparison = comparison
self._length = length
def src(self, indentation):
template = '{0}if path_len {1} {2}:\n{3}'
return template.format(
_TAB_STR * indentation,
self._comparison,
self._length,
self._children_src(indentation + 1)
)
class _CxIfPathSegmentLiteral(_CxParent):
def __init__(self, segment_idx, literal):
super(_CxIfPathSegmentLiteral, self).__init__()
self._segment_idx = segment_idx
self._literal = literal
def src(self, indentation):
template = "{0}if path[{1}] == '{2}':\n{3}"
return template.format(
_TAB_STR * indentation,
self._segment_idx,
self._literal,
self._children_src(indentation + 1)
)
class _CxIfPathSegmentPattern(_CxParent):
def __init__(self, segment_idx, pattern_idx, pattern_text):
super(_CxIfPathSegmentPattern, self).__init__()
self._segment_idx = segment_idx
self._pattern_idx = pattern_idx
self._pattern_text = pattern_text
def src(self, indentation):
lines = []
lines.append(
'{0}match = patterns[{1}].match(path[{2}]) # {3}'.format(
_TAB_STR * indentation,
self._pattern_idx,
self._segment_idx,
self._pattern_text,
)
)
lines.append('{0}if match is not None:'.format(_TAB_STR * indentation))
lines.append('{0}params.update(match.groupdict())'.format(
_TAB_STR * (indentation + 1)
))
lines.append(self._children_src(indentation + 1))
return '\n'.join(lines)
class _CxReturnNone(object):
def src(self, indentation):
return '{0}return None'.format(_TAB_STR * indentation)
class _CxReturnValue(object):
def __init__(self, value_idx):
self._value_idx = value_idx
def src(self, indentation):
return '{0}return return_values[{1}]'.format(
_TAB_STR * indentation,
self._value_idx
)
class _CxSetParam(object):
def __init__(self, param_name, segment_idx):
self._param_name = param_name
self._segment_idx = segment_idx
def src(self, indentation):
return "{0}params['{1}'] = path[{2}]".format(
_TAB_STR * indentation,
self._param_name,
self._segment_idx,
)

View File

@@ -5,7 +5,6 @@ from falcon import testing
class TestCustomRouter(testing.TestBase): class TestCustomRouter(testing.TestBase):
def test_custom_router_add_route_should_be_used(self): def test_custom_router_add_route_should_be_used(self):
check = [] check = []
class CustomRouter(object): class CustomRouter(object):

View File

@@ -1,3 +1,5 @@
import textwrap
import pytest import pytest
import falcon import falcon
@@ -224,6 +226,20 @@ def test_root_path():
resource, __, __, __ = router.find('/') resource, __, __, __ = router.find('/')
assert resource.resource_id == 42 assert resource.resource_id == 42
expected_src = textwrap.dedent("""
def find(path, return_values, patterns, params):
path_len = len(path)
if path_len > 0:
if path[0] == '':
if path_len == 1:
return return_values[0]
return None
return None
return None
""").strip()
assert router.finder_src == expected_src
@pytest.mark.parametrize('uri_template', [ @pytest.mark.parametrize('uri_template', [
'/{field}{field}', '/{field}{field}',
@@ -309,8 +325,14 @@ def test_invalid_field_name(router, uri_template):
router.add_route(uri_template, {}, ResourceWithId(-1)) router.add_route(uri_template, {}, ResourceWithId(-1))
def test_dump(router): def test_print_src(router):
print(router._src) """Diagnostic test that simply prints the router's find() source code.
Example:
$ tox -e py27_debug -- -k test_print_src -s
"""
print('\n\n' + router.finder_src + '\n')
def test_override(router): def test_override(router):