Add support for path translation and better handle extension stripping

This commit is contained in:
Alessandro Molina 2013-12-04 12:47:25 +01:00
parent fe826630a5
commit c7ab05ec7e
7 changed files with 133 additions and 45 deletions

View File

@ -1,13 +1,14 @@
"""
This module implements the :class:`DispatchState` class
"""
from crank.util import Path
from crank.util import default_path_translator, noop_translation
try:
string_type = basestring
except NameError: # pragma: no cover
string_type = str
class DispatchState(object):
"""
This class keeps around all the pertainent info for the state
@ -26,40 +27,44 @@ class DispatchState(object):
pre-split list of path elements, will use request.pathinfo if not used
"""
def __init__(self, request, dispatcher=None, params=None, path_info=None, ignore_parameters=None):
self.request = request
def __init__(self, request, dispatcher=None, params=None, path_info=None,
ignore_parameters=None, strip_extension=True,
path_translator=default_path_translator):
path = path_info
if path is None:
path = request.path_info[1:]
path = path.split('/')
elif isinstance(path, string_type):
path = path.split('/')
try:
if not path[0]:
path = path[1:]
except IndexError:
pass
try:
while not path[-1]:
path = path[:-1]
except IndexError:
pass
if path_translator is None:
path_translator = noop_translation
self.request = request
self.extension = None
self.path_translator = path_translator
#rob the extension
if len(path) > 0 and '.' in path[-1]:
if strip_extension and len(path) > 0 and '.' in path[-1]:
end = path[-1]
end = end.split('.')
self.extension = end[-1]
path[-1] = '.'.join(end[:-1])
end, ext = end.rsplit('.', 1)
self.extension = ext
path[-1] = end
self.path = path
if params is not None:
self.params = params
else:

View File

@ -85,7 +85,7 @@ class ObjectDispatcher(Dispatcher):
obj = getattr(controller, 'im_self', controller)
security_check = getattr(obj, '_check_security', None)
if security_check:
if security_check is not None:
security_check()
def _dispatch_controller(self, current_path, controller, state, remainder):
@ -100,7 +100,7 @@ class ObjectDispatcher(Dispatcher):
"""
dispatcher = getattr(controller, '_dispatch', None)
if dispatcher:
if dispatcher is not None:
self._perform_security_check(controller)
state.add_controller(current_path, controller)
state.dispatcher = controller
@ -170,7 +170,8 @@ class ObjectDispatcher(Dispatcher):
#to see if there is a default or lookup method we can use
return self._dispatch_first_found_default_or_lookup(state, remainder)
current_path = remainder[0]
current_path = state.path_translator(remainder[0])
current_args = remainder[1:]
#an exposed method matching the path is found
@ -183,9 +184,9 @@ class ObjectDispatcher(Dispatcher):
#another controller is found
current_controller = getattr(current_controller, current_path, None)
if current_controller:
return self._dispatch_controller(
current_path, current_controller, state, current_args)
if current_controller is not None:
return self._dispatch_controller(current_path, current_controller,
state, current_args)
#dispatch not found
return self._dispatch_first_found_default_or_lookup(state, remainder)

View File

@ -20,7 +20,7 @@ class RestDispatcher(ObjectDispatcher):
if self._is_exposed(controller, method):
return getattr(controller, method)
def _handle_put_or_post(self, method, state, remainder):
def _handle_put_or_post(self, http_method, state, remainder):
current_controller = state.controller
if remainder:
current_path = remainder[0]
@ -32,17 +32,15 @@ class RestDispatcher(ObjectDispatcher):
current_controller = getattr(current_controller, current_path)
return self._dispatch_controller(current_path, current_controller, state, remainder[1:])
method_name = method
method = self._find_first_exposed(current_controller, [method,])
method = self._find_first_exposed(current_controller, [http_method])
if method and method_matches_args(method, state.params, remainder, self._use_lax_params):
state.add_method(method, remainder)
return state
return self._dispatch_first_found_default_or_lookup(state, remainder)
def _handle_delete(self, method, state, remainder):
def _handle_delete(self, http_method, state, remainder):
current_controller = state.controller
method_name = method
method = self._find_first_exposed(current_controller, ('post_delete', 'delete'))
if method and method_matches_args(method, state.params, remainder, self._use_lax_params):
@ -72,27 +70,32 @@ class RestDispatcher(ObjectDispatcher):
if hasattr(current_controller, find):
method = find
break
if method is None:
return
fixed_args, var_args, kws, kw_args = get_argspec(getattr(current_controller, method))
fixed_arg_length = len(fixed_args)
if var_args:
for i, item in enumerate(remainder):
item = state.path_translator(item)
if hasattr(current_controller, item) and self._is_controller(current_controller, item):
current_controller = getattr(current_controller, item)
state.add_routing_args(item, remainder[:i], fixed_args, var_args)
return self._dispatch_controller(item, current_controller, state, remainder[i+1:])
elif fixed_arg_length< len(remainder) and hasattr(current_controller, remainder[fixed_arg_length]):
item = remainder[fixed_arg_length]
item = state.path_translator(remainder[fixed_arg_length])
if hasattr(current_controller, item):
if self._is_controller(current_controller, item):
state.add_routing_args(item, remainder, fixed_args, var_args)
return self._dispatch_controller(item, getattr(current_controller, item), state, remainder[fixed_arg_length+1:])
return self._dispatch_controller(item, getattr(current_controller, item),
state, remainder[fixed_arg_length+1:])
def _handle_delete_edit_or_new(self, state, remainder):
method_name = remainder[-1]
if method_name not in ('new', 'edit', 'delete'):
return
if method_name == 'delete':
method_name = 'get_delete'
@ -157,15 +160,15 @@ class RestDispatcher(ObjectDispatcher):
#test for "delete", "edit" or "new"
r = self._handle_delete_edit_or_new(state, remainder)
if r:
if r is not None:
return r
#test for custom REST-like attribute
r = self._handle_custom_get(state, remainder)
if r:
if r is not None:
return r
current_path = remainder[0]
current_path = state.path_translator(remainder[0])
if self._is_exposed(current_controller, current_path):
state.add_method(getattr(current_controller, current_path), remainder[1:])
return state
@ -230,7 +233,7 @@ class RestDispatcher(ObjectDispatcher):
state.http_method = method
r = self._check_for_sub_controllers(state, remainder)
if r:
if r is not None:
return r
if state.http_method in self._handler_lookup.keys():

View File

@ -5,14 +5,17 @@ Copyright (c) Chrispther Perkins
MIT License
"""
import collections, sys
import collections, sys, string
from inspect import getargspec
__all__ = [
'get_argspec', 'get_params_with_argspec', 'remove_argspec_params_from_params', 'method_matches_args',
'Path'
'get_argspec', 'get_params_with_argspec', 'remove_argspec_params_from_params',
'method_matches_args', 'Path', 'default_path_translator'
]
from inspect import getargspec
_PY2 = bool(sys.version_info[0] == 2)
_cached_argspecs = {}
def get_argspec(func):
@ -146,7 +149,27 @@ def method_matches_args(method, params, remainder, lax_params=False):
return False
_PY3 = bool(sys.version_info[0] == 3)
if _PY2: #pragma: no cover
translation_dict = dict([(ord(c), unicode('_')) for c in unicode(string.punctuation)])
translation_string = string.maketrans(string.punctuation,
'_' * len(string.punctuation))
else: #pragma: no cover
translation_dict = None
translation_string = str.maketrans(string.punctuation,
'_' * len(string.punctuation))
def default_path_translator(path_piece):
if isinstance(path_piece, str):
return path_piece.translate(translation_string)
else: #pragma: no cover
return path_piece.translate(translation_dict)
def noop_translation(path_piece):
return path_piece
class Path(collections.deque):
def __init__(self, value=None, separator='/'):
@ -161,7 +184,7 @@ class Path(collections.deque):
separator = self.separator
self.clear()
if _PY3: # pragma: no cover
if not _PY2: # pragma: no cover
string_types = str
else: # pragma: no cover
string_types = basestring

View File

@ -92,6 +92,8 @@ class MockDispatcherWithNoDefault(ObjectDispatcher):
def index(self):
pass
sub_child = MockDispatcher()
mock_dispatcher_with_no_default = MockDispatcherWithNoDefault()
class MockDispatcherWithIndexWithArgVars(ObjectDispatcher):
@ -242,3 +244,54 @@ class TestDispatcher:
req = MockRequest('/get_here', params={'a':1})
state = DispatchState(req, mock_lookup_dispatcher_with_args)
state = mock_lookup_dispatcher_with_args._dispatch(state)
def test_path_translation(self):
req = MockRequest('/no.args.json')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
assert state.method.__name__ == 'no_args', state.method
def test_path_translation_no_extension(self):
req = MockRequest('/no.args')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index,
strip_extension=False)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
assert state.method.__name__ == 'no_args', state.method
@raises(HTTPNotFound)
def test_disabled_path_translation_no_extension(self):
req = MockRequest('/no.args')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index,
strip_extension=False, path_translator=None)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
def test_path_translation_args_skipped(self):
req = MockRequest('/with.args/para.meter1/para.meter2.json')
state = DispatchState(req, mock_dispatcher_with_no_default_or_index)
state = mock_dispatcher_with_no_default_or_index._dispatch(state)
assert state.method.__name__ == 'with_args', state.method
assert 'para.meter1' in state.remainder, state.remainder
assert 'para.meter2' in state.remainder, state.remainder
def test_path_translation_sub_controller(self):
req = MockRequest('/sub.child/with.args/para.meter1/para.meter2.json')
state = DispatchState(req, mock_dispatcher_with_no_default)
state = mock_dispatcher_with_no_default._dispatch(state)
path_pieces = [piece[0] for piece in state.controller_path]
assert 'sub_child' in path_pieces
assert state.method.__name__ == 'with_args', state.method
assert 'para.meter1' in state.remainder, state.remainder
assert 'para.meter2' in state.remainder, state.remainder
def test_path_translation_sub_controller_no_strip_extension(self):
req = MockRequest('/sub.child/with.args/para.meter1/para.meter2.json')
state = DispatchState(req, mock_dispatcher_with_no_default,
strip_extension=False)
state = mock_dispatcher_with_no_default._dispatch(state)
path_pieces = [piece[0] for piece in state.controller_path]
assert 'sub_child' in path_pieces
assert state.method.__name__ == 'with_args', state.method
assert 'para.meter1' in state.remainder, state.remainder
assert 'para.meter2.json' in state.remainder, state.remainder

View File

@ -431,11 +431,9 @@ class TestRestWithSecurity:
req = MockRequest('/direct/a')
state = DispatchState(req)
state = self.dispatcher._dispatch(state)
print state.method
@raises(MockError)
def test_check_security_with_nested_lookup(self):
req = MockRequest('/nested/withsec/a')
state = DispatchState(req)
state = self.dispatcher._dispatch(state)
print state.method

View File

@ -1,13 +1,11 @@
# encoding: utf-8
import sys
from nose.tools import raises
from crank.util import *
from crank.util import _PY2
_PY3 = bool(sys.version_info[0] == 3)
if _PY3:
def u(s): return s
else:
if _PY2:
def u(s): return s.decode('utf-8')
else:
def u(s): return s
def mock_f(self, a, b, c=None, d=50, *args, **kw):
pass
@ -185,10 +183,10 @@ def test_path_unicode():
instance = MockOb()
instance.path = case
if _PY3:
yield assert_path, instance, expected, str
else:
if _PY2:
yield assert_path, instance, expected, unicode
else:
yield assert_path, instance, expected, str
def test_path_slicing():
class MockOb(object):
@ -208,3 +206,10 @@ def test_path_comparison():
assert Path('/foo') == ['', 'foo'], 'list comparison'
assert Path('/foo') == '/foo', 'string comparison'
assert Path(u('/föö')) == u('/föö'), 'string comparison'
def test_path_translation():
translated = default_path_translator('a.b')
assert translated == 'a_b', translated
translated = default_path_translator(u('f.ö.ö'))
assert translated == u('f_ö_ö'), translated