Catch an edge case in RestController routing and properly return an HTTP 400.

Change-Id: I0e57cf719a2c3372ebc61efe71a832b0482a0a3e
This commit is contained in:
Ryan Petrello
2014-08-22 10:14:18 -04:00
parent 7a086c3713
commit e07e75e6b7
2 changed files with 136 additions and 15 deletions

View File

@@ -43,6 +43,19 @@ class RestController(object):
return argspec.args[3:]
return argspec.args[1:]
def _handle_bad_rest_arguments(self, controller, remainder, request):
"""
Ensure that the argspec for a discovered controller actually matched
the positional arguments in the request path. If not, raise
a webob.exc.HTTPBadRequest.
"""
argspec = self._get_args_for_controller(controller)
fixed_args = len(argspec) - len(
request.pecan.get('routing_args', [])
)
if len(remainder) < fixed_args:
abort(400)
@expose()
def _route(self, args, request=None):
'''
@@ -89,10 +102,10 @@ class RestController(object):
_lookup_result = self._handle_lookup(args, request)
if _lookup_result:
return _lookup_result
except exc.HTTPNotFound:
except (exc.HTTPClientError, exc.HTTPNotFound):
#
# If the matching handler results in a 404, attempt to handle
# a _lookup method (if it exists)
# If the matching handler results in a 400 or 404, attempt to
# handle a _lookup method (if it exists)
#
_lookup_result = self._handle_lookup(args, request)
if _lookup_result:
@@ -201,14 +214,10 @@ class RestController(object):
# route to a get_all or get if no additional parts are available
if not remainder or remainder == ['']:
remainder = list(six.moves.filter(bool, remainder))
controller = self._find_controller('get_all', 'get')
if controller:
argspec = self._get_args_for_controller(controller)
fixed_args = len(argspec) - len(
request.pecan.get('routing_args', [])
)
if len(remainder) < fixed_args:
abort(404)
self._handle_bad_rest_arguments(controller, remainder, request)
return controller, []
abort(404)
@@ -232,6 +241,7 @@ class RestController(object):
# finally, check for the regular get_one/get requests
controller = self._find_controller('get_one', 'get')
if controller:
self._handle_bad_rest_arguments(controller, remainder, request)
return controller, remainder
abort(404)

View File

@@ -7,7 +7,7 @@ except:
from six import b as b_
from pecan import abort, expose, make_app, response
from pecan import abort, expose, make_app, response, redirect
from pecan.rest import RestController
from pecan.tests import PecanTestCase
@@ -681,6 +681,117 @@ class TestRestController(PecanTestCase):
assert r.status_int == 200
assert len(loads(r.body.decode())['items']) == 1
def test_nested_get_all(self):
class BarsController(RestController):
@expose()
def get_one(self, foo_id, id):
return '4'
@expose()
def get_all(self, foo_id):
return '3'
class FoosController(RestController):
bars = BarsController()
@expose()
def get_one(self, id):
return '2'
@expose()
def get_all(self):
return '1'
class RootController(object):
foos = FoosController()
# create the app
app = TestApp(make_app(RootController()))
r = app.get('/foos/')
assert r.status_int == 200
assert r.body == b_('1')
r = app.get('/foos/1/')
assert r.status_int == 200
assert r.body == b_('2')
r = app.get('/foos/1/bars/')
assert r.status_int == 200
assert r.body == b_('3')
r = app.get('/foos/1/bars/2/')
assert r.status_int == 200
assert r.body == b_('4')
r = app.get('/foos/bars/', status=400)
assert r.status_int == 400
r = app.get('/foos/bars/1', status=400)
assert r.status_int == 400
def test_nested_get_all_with_lookup(self):
class BarsController(RestController):
@expose()
def get_one(self, foo_id, id):
return '4'
@expose()
def get_all(self, foo_id):
return '3'
@expose('json')
def _lookup(self, id, *remainder):
redirect('/lookup-hit/')
class FoosController(RestController):
bars = BarsController()
@expose()
def get_one(self, id):
return '2'
@expose()
def get_all(self):
return '1'
class RootController(object):
foos = FoosController()
# create the app
app = TestApp(make_app(RootController()))
r = app.get('/foos/')
assert r.status_int == 200
assert r.body == b_('1')
r = app.get('/foos/1/')
assert r.status_int == 200
assert r.body == b_('2')
r = app.get('/foos/1/bars/')
assert r.status_int == 200
assert r.body == b_('3')
r = app.get('/foos/1/bars/2/')
assert r.status_int == 200
assert r.body == b_('4')
r = app.get('/foos/bars/', status=400)
assert r.status_int == 400
r = app.get('/foos/bars/', status=400)
r = app.get('/foos/bars/1')
assert r.status_int == 302
assert r.headers['Location'].endswith('/lookup-hit/')
def test_bad_rest(self):
class ThingsController(RestController):
@@ -773,16 +884,16 @@ class TestRestController(PecanTestCase):
# test get_all
r = app.get('/foos')
assert r.status_int == 200
assert r.body == b_(dumps(dict(items=FoosController.data)))
self.assertEqual(r.status_int, 200)
self.assertEqual(r.body, b_(dumps(dict(items=FoosController.data))))
# test nested get_all
r = app.get('/foos/1/bars')
assert r.status_int == 200
assert r.body == b_(dumps(dict(items=BarsController.data[1])))
self.assertEqual(r.status_int, 200)
self.assertEqual(r.body, b_(dumps(dict(items=BarsController.data[1]))))
r = app.get('/foos/bars', expect_errors=True)
assert r.status_int == 404
self.assertEqual(r.status_int, 400)
def test_custom_with_trailing_slash(self):