Add support for path translation and better handle extension stripping

This commit is contained in:
Alessandro Molina 2013-12-04 12:47:25 +01:00
parent fe826630a5
commit c7ab05ec7e
7 changed files with 133 additions and 45 deletions

View File

@ -1,13 +1,14 @@
""" """
This module implements the :class:`DispatchState` class This module implements the :class:`DispatchState` class
""" """
from crank.util import Path from crank.util import default_path_translator, noop_translation
try: try:
string_type = basestring string_type = basestring
except NameError: # pragma: no cover except NameError: # pragma: no cover
string_type = str string_type = str
class DispatchState(object): class DispatchState(object):
""" """
This class keeps around all the pertainent info for the state This class keeps around all the pertainent info for the state
@ -26,39 +27,43 @@ class DispatchState(object):
pre-split list of path elements, will use request.pathinfo if not used pre-split list of path elements, will use request.pathinfo if not used
""" """
def __init__(self, request, dispatcher=None, params=None, path_info=None, ignore_parameters=None): def __init__(self, request, dispatcher=None, params=None, path_info=None,
self.request = request ignore_parameters=None, strip_extension=True,
path_translator=default_path_translator):
path = path_info path = path_info
if path is None: if path is None:
path = request.path_info[1:] path = request.path_info[1:]
path = path.split('/') path = path.split('/')
elif isinstance(path, string_type): elif isinstance(path, string_type):
path = path.split('/') path = path.split('/')
try: try:
if not path[0]: if not path[0]:
path = path[1:] path = path[1:]
except IndexError: except IndexError:
pass pass
try: try:
while not path[-1]: while not path[-1]:
path = path[:-1] path = path[:-1]
except IndexError: except IndexError:
pass pass
if path_translator is None:
path_translator = noop_translation
self.request = request
self.extension = None self.extension = None
self.path_translator = path_translator
#rob the extension #rob the extension
if len(path) > 0 and '.' in path[-1]: if strip_extension and len(path) > 0 and '.' in path[-1]:
end = path[-1] end = path[-1]
end = end.split('.') end, ext = end.rsplit('.', 1)
self.extension = end[-1] self.extension = ext
path[-1] = '.'.join(end[:-1]) path[-1] = end
self.path = path self.path = path
if params is not None: if params is not None:
self.params = params self.params = params

View File

@ -85,7 +85,7 @@ class ObjectDispatcher(Dispatcher):
obj = getattr(controller, 'im_self', controller) obj = getattr(controller, 'im_self', controller)
security_check = getattr(obj, '_check_security', None) security_check = getattr(obj, '_check_security', None)
if security_check: if security_check is not None:
security_check() security_check()
def _dispatch_controller(self, current_path, controller, state, remainder): def _dispatch_controller(self, current_path, controller, state, remainder):
@ -100,7 +100,7 @@ class ObjectDispatcher(Dispatcher):
""" """
dispatcher = getattr(controller, '_dispatch', None) dispatcher = getattr(controller, '_dispatch', None)
if dispatcher: if dispatcher is not None:
self._perform_security_check(controller) self._perform_security_check(controller)
state.add_controller(current_path, controller) state.add_controller(current_path, controller)
state.dispatcher = controller state.dispatcher = controller
@ -170,7 +170,8 @@ class ObjectDispatcher(Dispatcher):
#to see if there is a default or lookup method we can use #to see if there is a default or lookup method we can use
return self._dispatch_first_found_default_or_lookup(state, remainder) return self._dispatch_first_found_default_or_lookup(state, remainder)
current_path = remainder[0]
current_path = state.path_translator(remainder[0])
current_args = remainder[1:] current_args = remainder[1:]
#an exposed method matching the path is found #an exposed method matching the path is found
@ -183,9 +184,9 @@ class ObjectDispatcher(Dispatcher):
#another controller is found #another controller is found
current_controller = getattr(current_controller, current_path, None) current_controller = getattr(current_controller, current_path, None)
if current_controller: if current_controller is not None:
return self._dispatch_controller( return self._dispatch_controller(current_path, current_controller,
current_path, current_controller, state, current_args) state, current_args)
#dispatch not found #dispatch not found
return self._dispatch_first_found_default_or_lookup(state, remainder) return self._dispatch_first_found_default_or_lookup(state, remainder)

View File

@ -20,7 +20,7 @@ class RestDispatcher(ObjectDispatcher):
if self._is_exposed(controller, method): if self._is_exposed(controller, method):
return getattr(controller, method) return getattr(controller, method)
def _handle_put_or_post(self, method, state, remainder): def _handle_put_or_post(self, http_method, state, remainder):
current_controller = state.controller current_controller = state.controller
if remainder: if remainder:
current_path = remainder[0] current_path = remainder[0]
@ -32,17 +32,15 @@ class RestDispatcher(ObjectDispatcher):
current_controller = getattr(current_controller, current_path) current_controller = getattr(current_controller, current_path)
return self._dispatch_controller(current_path, current_controller, state, remainder[1:]) return self._dispatch_controller(current_path, current_controller, state, remainder[1:])
method_name = method method = self._find_first_exposed(current_controller, [http_method])
method = self._find_first_exposed(current_controller, [method,])
if method and method_matches_args(method, state.params, remainder, self._use_lax_params): if method and method_matches_args(method, state.params, remainder, self._use_lax_params):
state.add_method(method, remainder) state.add_method(method, remainder)
return state return state
return self._dispatch_first_found_default_or_lookup(state, remainder) return self._dispatch_first_found_default_or_lookup(state, remainder)
def _handle_delete(self, method, state, remainder): def _handle_delete(self, http_method, state, remainder):
current_controller = state.controller current_controller = state.controller
method_name = method
method = self._find_first_exposed(current_controller, ('post_delete', 'delete')) method = self._find_first_exposed(current_controller, ('post_delete', 'delete'))
if method and method_matches_args(method, state.params, remainder, self._use_lax_params): if method and method_matches_args(method, state.params, remainder, self._use_lax_params):
@ -72,27 +70,32 @@ class RestDispatcher(ObjectDispatcher):
if hasattr(current_controller, find): if hasattr(current_controller, find):
method = find method = find
break break
if method is None: if method is None:
return return
fixed_args, var_args, kws, kw_args = get_argspec(getattr(current_controller, method)) fixed_args, var_args, kws, kw_args = get_argspec(getattr(current_controller, method))
fixed_arg_length = len(fixed_args) fixed_arg_length = len(fixed_args)
if var_args: if var_args:
for i, item in enumerate(remainder): for i, item in enumerate(remainder):
item = state.path_translator(item)
if hasattr(current_controller, item) and self._is_controller(current_controller, item): if hasattr(current_controller, item) and self._is_controller(current_controller, item):
current_controller = getattr(current_controller, item) current_controller = getattr(current_controller, item)
state.add_routing_args(item, remainder[:i], fixed_args, var_args) state.add_routing_args(item, remainder[:i], fixed_args, var_args)
return self._dispatch_controller(item, current_controller, state, remainder[i+1:]) return self._dispatch_controller(item, current_controller, state, remainder[i+1:])
elif fixed_arg_length< len(remainder) and hasattr(current_controller, remainder[fixed_arg_length]): elif fixed_arg_length< len(remainder) and hasattr(current_controller, remainder[fixed_arg_length]):
item = remainder[fixed_arg_length] item = state.path_translator(remainder[fixed_arg_length])
if hasattr(current_controller, item): if hasattr(current_controller, item):
if self._is_controller(current_controller, item): if self._is_controller(current_controller, item):
state.add_routing_args(item, remainder, fixed_args, var_args) state.add_routing_args(item, remainder, fixed_args, var_args)
return self._dispatch_controller(item, getattr(current_controller, item), state, remainder[fixed_arg_length+1:]) return self._dispatch_controller(item, getattr(current_controller, item),
state, remainder[fixed_arg_length+1:])
def _handle_delete_edit_or_new(self, state, remainder): def _handle_delete_edit_or_new(self, state, remainder):
method_name = remainder[-1] method_name = remainder[-1]
if method_name not in ('new', 'edit', 'delete'): if method_name not in ('new', 'edit', 'delete'):
return return
if method_name == 'delete': if method_name == 'delete':
method_name = 'get_delete' method_name = 'get_delete'
@ -157,15 +160,15 @@ class RestDispatcher(ObjectDispatcher):
#test for "delete", "edit" or "new" #test for "delete", "edit" or "new"
r = self._handle_delete_edit_or_new(state, remainder) r = self._handle_delete_edit_or_new(state, remainder)
if r: if r is not None:
return r return r
#test for custom REST-like attribute #test for custom REST-like attribute
r = self._handle_custom_get(state, remainder) r = self._handle_custom_get(state, remainder)
if r: if r is not None:
return r return r
current_path = remainder[0] current_path = state.path_translator(remainder[0])
if self._is_exposed(current_controller, current_path): if self._is_exposed(current_controller, current_path):
state.add_method(getattr(current_controller, current_path), remainder[1:]) state.add_method(getattr(current_controller, current_path), remainder[1:])
return state return state
@ -230,7 +233,7 @@ class RestDispatcher(ObjectDispatcher):
state.http_method = method state.http_method = method
r = self._check_for_sub_controllers(state, remainder) r = self._check_for_sub_controllers(state, remainder)
if r: if r is not None:
return r return r
if state.http_method in self._handler_lookup.keys(): if state.http_method in self._handler_lookup.keys():

View File

@ -5,14 +5,17 @@ Copyright (c) Chrispther Perkins
MIT License MIT License
""" """
import collections, sys import collections, sys, string
from inspect import getargspec
__all__ = [ __all__ = [
'get_argspec', 'get_params_with_argspec', 'remove_argspec_params_from_params', 'method_matches_args', 'get_argspec', 'get_params_with_argspec', 'remove_argspec_params_from_params',
'Path' 'method_matches_args', 'Path', 'default_path_translator'
] ]
from inspect import getargspec
_PY2 = bool(sys.version_info[0] == 2)
_cached_argspecs = {} _cached_argspecs = {}
def get_argspec(func): def get_argspec(func):
@ -146,7 +149,27 @@ def method_matches_args(method, params, remainder, lax_params=False):
return False return False
_PY3 = bool(sys.version_info[0] == 3)
if _PY2: #pragma: no cover
translation_dict = dict([(ord(c), unicode('_')) for c in unicode(string.punctuation)])
translation_string = string.maketrans(string.punctuation,
'_' * len(string.punctuation))
else: #pragma: no cover
translation_dict = None
translation_string = str.maketrans(string.punctuation,
'_' * len(string.punctuation))
def default_path_translator(path_piece):
if isinstance(path_piece, str):
return path_piece.translate(translation_string)
else: #pragma: no cover
return path_piece.translate(translation_dict)
def noop_translation(path_piece):
return path_piece
class Path(collections.deque): class Path(collections.deque):
def __init__(self, value=None, separator='/'): def __init__(self, value=None, separator='/'):
@ -161,7 +184,7 @@ class Path(collections.deque):
separator = self.separator separator = self.separator
self.clear() self.clear()
if _PY3: # pragma: no cover if not _PY2: # pragma: no cover
string_types = str string_types = str
else: # pragma: no cover else: # pragma: no cover
string_types = basestring string_types = basestring

View File

@ -92,6 +92,8 @@ class MockDispatcherWithNoDefault(ObjectDispatcher):
def index(self): def index(self):
pass pass
sub_child = MockDispatcher()
mock_dispatcher_with_no_default = MockDispatcherWithNoDefault() mock_dispatcher_with_no_default = MockDispatcherWithNoDefault()
class MockDispatcherWithIndexWithArgVars(ObjectDispatcher): class MockDispatcherWithIndexWithArgVars(ObjectDispatcher):
@ -242,3 +244,54 @@ class TestDispatcher:
req = MockRequest('/get_here', params={'a':1}) req = MockRequest('/get_here', params={'a':1})
state = DispatchState(req, mock_lookup_dispatcher_with_args) state = DispatchState(req, mock_lookup_dispatcher_with_args)
state = mock_lookup_dispatcher_with_args._dispatch(state) state = mock_lookup_dispatcher_with_args._dispatch(state)
def test_path_translation(self):
req = MockRequest('/no.args.json')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
assert state.method.__name__ == 'no_args', state.method
def test_path_translation_no_extension(self):
req = MockRequest('/no.args')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index,
strip_extension=False)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
assert state.method.__name__ == 'no_args', state.method
@raises(HTTPNotFound)
def test_disabled_path_translation_no_extension(self):
req = MockRequest('/no.args')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index,
strip_extension=False, path_translator=None)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
def test_path_translation_args_skipped(self):
req = MockRequest('/with.args/para.meter1/para.meter2.json')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
assert state.method.__name__ == 'with_args', state.method
assert 'para.meter1' in state.remainder, state.remainder
assert 'para.meter2' in state.remainder, state.remainder
def test_path_translation_sub_controller(self):
req = MockRequest('/sub.child/with.args/para.meter1/para.meter2.json')
state = DispatchState(req, mock_dispatcher_with_no_default)
state = mock_dispatcher_with_no_default._dispatch(state)
path_pieces = [piece[0] for piece in state.controller_path]
assert 'sub_child' in path_pieces
assert state.method.__name__ == 'with_args', state.method
assert 'para.meter1' in state.remainder, state.remainder
assert 'para.meter2' in state.remainder, state.remainder
def test_path_translation_sub_controller_no_strip_extension(self):
req = MockRequest('/sub.child/with.args/para.meter1/para.meter2.json')
state = DispatchState(req, mock_dispatcher_with_no_default,
strip_extension=False)
state = mock_dispatcher_with_no_default._dispatch(state)
path_pieces = [piece[0] for piece in state.controller_path]
assert 'sub_child' in path_pieces
assert state.method.__name__ == 'with_args', state.method
assert 'para.meter1' in state.remainder, state.remainder
assert 'para.meter2.json' in state.remainder, state.remainder

View File

@ -431,11 +431,9 @@ class TestRestWithSecurity:
req = MockRequest('/direct/a') req = MockRequest('/direct/a')
state = DispatchState(req) state = DispatchState(req)
state = self.dispatcher._dispatch(state) state = self.dispatcher._dispatch(state)
print state.method
@raises(MockError) @raises(MockError)
def test_check_security_with_nested_lookup(self): def test_check_security_with_nested_lookup(self):
req = MockRequest('/nested/withsec/a') req = MockRequest('/nested/withsec/a')
state = DispatchState(req) state = DispatchState(req)
state = self.dispatcher._dispatch(state) state = self.dispatcher._dispatch(state)
print state.method

View File

@ -1,13 +1,11 @@
# encoding: utf-8 # encoding: utf-8
import sys
from nose.tools import raises
from crank.util import * from crank.util import *
from crank.util import _PY2
_PY3 = bool(sys.version_info[0] == 3) if _PY2:
if _PY3:
def u(s): return s
else:
def u(s): return s.decode('utf-8') def u(s): return s.decode('utf-8')
else:
def u(s): return s
def mock_f(self, a, b, c=None, d=50, *args, **kw): def mock_f(self, a, b, c=None, d=50, *args, **kw):
pass pass
@ -185,10 +183,10 @@ def test_path_unicode():
instance = MockOb() instance = MockOb()
instance.path = case instance.path = case
if _PY3: if _PY2:
yield assert_path, instance, expected, str
else:
yield assert_path, instance, expected, unicode yield assert_path, instance, expected, unicode
else:
yield assert_path, instance, expected, str
def test_path_slicing(): def test_path_slicing():
class MockOb(object): class MockOb(object):
@ -208,3 +206,10 @@ def test_path_comparison():
assert Path('/foo') == ['', 'foo'], 'list comparison' assert Path('/foo') == ['', 'foo'], 'list comparison'
assert Path('/foo') == '/foo', 'string comparison' assert Path('/foo') == '/foo', 'string comparison'
assert Path(u('/föö')) == u('/föö'), 'string comparison' assert Path(u('/föö')) == u('/föö'), 'string comparison'
def test_path_translation():
translated = default_path_translator('a.b')
assert translated == 'a_b', translated
translated = default_path_translator(u('f.ö.ö'))
assert translated == u('f_ö_ö'), translated