Ensure we always accept trailing /'s in URLs

We implement a new NormalizeURI middleware, and apply it to both the
V1 and V2 APIs. This will ensure all API endpoints accept trailing /'s
without requiring each endpoint to explicitly support or test this.

Change-Id: Ic29cb0351fd24f1316d38df3ba3c2d46622e5633
Closes-Bug: 1334935
This commit is contained in:
Kiall Mac Innes 2014-07-01 15:50:22 +01:00
parent a10c19cf20
commit b87dbe34cb
9 changed files with 131 additions and 242 deletions

View File

@ -43,31 +43,6 @@ cfg.CONF.register_opts([
], group='service:api') ], group='service:api')
class MaintenanceMiddleware(wsgi.Middleware):
def __init__(self, application):
super(MaintenanceMiddleware, self).__init__(application)
LOG.info(_LI('Starting designate maintenance middleware'))
self.enabled = cfg.CONF['service:api'].maintenance_mode
self.role = cfg.CONF['service:api'].maintenance_mode_role
def process_request(self, request):
# If maintaince mode is not enabled, pass the request on as soon as
# possible
if not self.enabled:
return None
# If the caller has the bypass role, let them through
if ('context' in request.environ
and self.role in request.environ['context'].roles):
LOG.warn(_LW('Request authorized to bypass maintenance mode'))
return None
# Otherwise, reject the request with a 503 Service Unavailable
return flask.Response(status=503, headers={'Retry-After': 60})
def auth_pipeline_factory(loader, global_conf, **local_conf): def auth_pipeline_factory(loader, global_conf, **local_conf):
""" """
A paste pipeline replica that keys off of auth_strategy. A paste pipeline replica that keys off of auth_strategy.
@ -172,6 +147,40 @@ class TestContextMiddleware(ContextMiddleware):
all_tenants=all_tenants) all_tenants=all_tenants)
class MaintenanceMiddleware(wsgi.Middleware):
def __init__(self, application):
super(MaintenanceMiddleware, self).__init__(application)
LOG.info(_LI('Starting designate maintenance middleware'))
self.enabled = cfg.CONF['service:api'].maintenance_mode
self.role = cfg.CONF['service:api'].maintenance_mode_role
def process_request(self, request):
# If maintaince mode is not enabled, pass the request on as soon as
# possible
if not self.enabled:
return None
# If the caller has the bypass role, let them through
if ('context' in request.environ
and self.role in request.environ['context'].roles):
LOG.warn(_LW('Request authorized to bypass maintenance mode'))
return None
# Otherwise, reject the request with a 503 Service Unavailable
return flask.Response(status=503, headers={'Retry-After': 60})
class NormalizeURIMiddleware(wsgi.Middleware):
@webob.dec.wsgify
def __call__(self, request):
# Remove any trailing /'s.
request.environ['PATH_INFO'] = request.environ['PATH_INFO'].rstrip('/')
return request.get_response(self.application)
class FaultWrapperMiddleware(wsgi.Middleware): class FaultWrapperMiddleware(wsgi.Middleware):
def __init__(self, application): def __init__(self, application):
super(FaultWrapperMiddleware, self).__init__(application) super(FaultWrapperMiddleware, self).__init__(application)

View File

@ -82,9 +82,6 @@ def factory(global_config, **local_conf):
# Install custom converters (URL param varidators) # Install custom converters (URL param varidators)
app.url_map.converters['uuid'] = UUIDConverter app.url_map.converters['uuid'] = UUIDConverter
# disable strict slashes. This allows trailing slashes in the URLS.
app.url_map.strict_slashes = False
# Ensure all error responses are JSON # Ensure all error responses are JSON
def _json_error(ex): def _json_error(ex):
code = ex.code if isinstance(ex, wexceptions.HTTPException) else 500 code = ex.code if isinstance(ex, wexceptions.HTTPException) else 500

View File

@ -32,6 +32,70 @@ class FakeRequest(object):
return "FakeResponse" return "FakeResponse"
class KeystoneContextMiddlewareTest(ApiTestCase):
def test_process_request(self):
app = middleware.KeystoneContextMiddleware({})
request = FakeRequest()
request.headers = {
'X-Auth-Token': 'AuthToken',
'X-User-ID': 'UserID',
'X-Tenant-ID': 'TenantID',
'X-Roles': 'admin,Member',
}
# Process the request
app.process_request(request)
self.assertIn('context', request.environ)
context = request.environ['context']
self.assertFalse(context.is_admin)
self.assertEqual('AuthToken', context.auth_token)
self.assertEqual('UserID', context.user)
self.assertEqual('TenantID', context.tenant)
self.assertEqual(['admin', 'Member'], context.roles)
def test_process_request_invalid_keystone_token(self):
app = middleware.KeystoneContextMiddleware({})
request = FakeRequest()
request.headers = {
'X-Auth-Token': 'AuthToken',
'X-User-ID': 'UserID',
'X-Tenant-ID': 'TenantID',
'X-Roles': 'admin,Member',
'X-Identity-Status': 'Invalid'
}
# Process the request
response = app(request)
self.assertEqual(response.status_code, 401)
class NoAuthContextMiddlewareTest(ApiTestCase):
def test_process_request(self):
app = middleware.NoAuthContextMiddleware({})
request = FakeRequest()
# Process the request
app.process_request(request)
self.assertIn('context', request.environ)
ctxt = request.environ['context']
self.assertIsNone(ctxt.auth_token)
self.assertEqual('noauth-user', ctxt.user)
self.assertEqual('noauth-project', ctxt.tenant)
self.assertEqual(['admin'], ctxt.roles)
class MaintenanceMiddlewareTest(ApiTestCase): class MaintenanceMiddlewareTest(ApiTestCase):
def test_process_request_disabled(self): def test_process_request_disabled(self):
self.config(maintenance_mode=False, group='service:api') self.config(maintenance_mode=False, group='service:api')
@ -104,68 +168,30 @@ class MaintenanceMiddlewareTest(ApiTestCase):
self.assertEqual(response, 'FakeResponse') self.assertEqual(response, 'FakeResponse')
class KeystoneContextMiddlewareTest(ApiTestCase): class NormalizeURIMiddlewareTest(ApiTestCase):
def test_process_request(self): def test_strip_trailing_slases(self):
app = middleware.KeystoneContextMiddleware({})
request = FakeRequest() request = FakeRequest()
request.environ['PATH_INFO'] = 'resource/'
request.headers = { app = middleware.NormalizeURIMiddleware({})
'X-Auth-Token': 'AuthToken',
'X-User-ID': 'UserID',
'X-Tenant-ID': 'TenantID',
'X-Roles': 'admin,Member',
}
# Process the request # Process the request
app.process_request(request) app(request)
self.assertIn('context', request.environ) # Ensure request's PATH_INFO had the trailing slash removed.
self.assertEqual(request.environ['PATH_INFO'], 'resource')
context = request.environ['context']
self.assertFalse(context.is_admin)
self.assertEqual('AuthToken', context.auth_token)
self.assertEqual('UserID', context.user)
self.assertEqual('TenantID', context.tenant)
self.assertEqual(['admin', 'Member'], context.roles)
def test_process_request_invalid_keystone_token(self):
app = middleware.KeystoneContextMiddleware({})
def test_strip_trailing_slases_multiple(self):
request = FakeRequest() request = FakeRequest()
request.environ['PATH_INFO'] = 'resource///'
request.headers = { app = middleware.NormalizeURIMiddleware({})
'X-Auth-Token': 'AuthToken',
'X-User-ID': 'UserID',
'X-Tenant-ID': 'TenantID',
'X-Roles': 'admin,Member',
'X-Identity-Status': 'Invalid'
}
# Process the request # Process the request
response = app(request) app(request)
self.assertEqual(response.status_code, 401) # Ensure request's PATH_INFO had the trailing slash removed.
self.assertEqual(request.environ['PATH_INFO'], 'resource')
class NoAuthContextMiddlewareTest(ApiTestCase):
def test_process_request(self):
app = middleware.NoAuthContextMiddleware({})
request = FakeRequest()
# Process the request
app.process_request(request)
self.assertIn('context', request.environ)
ctxt = request.environ['context']
self.assertIsNone(ctxt.auth_token)
self.assertEqual('noauth-user', ctxt.user)
self.assertEqual('noauth-project', ctxt.tenant)
self.assertEqual(['admin'], ctxt.roles)
class FaultMiddlewareTest(ApiTestCase): class FaultMiddlewareTest(ApiTestCase):

View File

@ -33,6 +33,10 @@ class ApiV1Test(ApiTestCase):
# Create the application # Create the application
self.app = api_v1.factory({}) self.app = api_v1.factory({})
# Inject the NormalizeURIMiddleware middleware
self.app.wsgi_app = middleware.NormalizeURIMiddleware(
self.app.wsgi_app)
# Inject the FaultWrapper middleware # Inject the FaultWrapper middleware
self.app.wsgi_app = middleware.FaultWrapperMiddleware( self.app.wsgi_app = middleware.FaultWrapperMiddleware(
self.app.wsgi_app) self.app.wsgi_app)

View File

@ -40,15 +40,6 @@ class ApiV1DomainsTest(ApiV1Test):
self.assertIn('name', response.json) self.assertIn('name', response.json)
self.assertEqual(response.json['name'], fixture['name']) self.assertEqual(response.json['name'], fixture['name'])
@patch.object(central_service.Service, 'create_domain')
def test_create_domain_trailing_slash(self, mock):
# Create a server
self.create_server()
self.post('domains/', data=self.get_domain_fixture(0))
# verify that the central service is called
self.assertTrue(mock.called)
def test_create_domain_junk(self): def test_create_domain_junk(self):
# Create a server # Create a server
self.create_server() self.create_server()
@ -181,13 +172,6 @@ class ApiV1DomainsTest(ApiV1Test):
self.assertIn('domains', response.json) self.assertIn('domains', response.json)
self.assertEqual(2, len(response.json['domains'])) self.assertEqual(2, len(response.json['domains']))
@patch.object(central_service.Service, 'find_domains')
def test_get_domains_trailing_slash(self, mock):
self.get('domains/')
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'find_domains', @patch.object(central_service.Service, 'find_domains',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_get_domains_timeout(self, _): def test_get_domains_timeout(self, _):
@ -202,16 +186,6 @@ class ApiV1DomainsTest(ApiV1Test):
self.assertIn('id', response.json) self.assertIn('id', response.json)
self.assertEqual(response.json['id'], domain['id']) self.assertEqual(response.json['id'], domain['id'])
@patch.object(central_service.Service, 'get_domain')
def test_get_domain_trailing_slash(self, mock):
# Create a domain
domain = self.create_domain()
self.get('domains/%s/' % domain['id'])
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'get_domain', @patch.object(central_service.Service, 'get_domain',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_get_domain_timeout(self, _): def test_get_domain_timeout(self, _):
@ -245,18 +219,6 @@ class ApiV1DomainsTest(ApiV1Test):
self.assertIn('email', response.json) self.assertIn('email', response.json)
self.assertEqual(response.json['email'], 'prefix-%s' % domain['email']) self.assertEqual(response.json['email'], 'prefix-%s' % domain['email'])
@patch.object(central_service.Service, 'update_domain')
def test_update_domain_trailing_slash(self, mock):
# Create a domain
domain = self.create_domain()
data = {'email': 'prefix-%s' % domain['email']}
self.put('domains/%s/' % domain['id'], data=data)
# verify that the central service is called
self.assertTrue(mock.called)
def test_update_domain_junk(self): def test_update_domain_junk(self):
# Create a domain # Create a domain
domain = self.create_domain() domain = self.create_domain()
@ -326,16 +288,6 @@ class ApiV1DomainsTest(ApiV1Test):
# Esnure we can no longer fetch the domain # Esnure we can no longer fetch the domain
self.get('domains/%s' % domain['id'], status_code=404) self.get('domains/%s' % domain['id'], status_code=404)
@patch.object(central_service.Service, 'delete_domain')
def test_delete_domain_trailing_slash(self, mock):
# Create a domain
domain = self.create_domain()
self.delete('domains/%s/' % domain['id'])
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'delete_domain', @patch.object(central_service.Service, 'delete_domain',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_delete_domain_timeout(self, _): def test_delete_domain_timeout(self, _):

View File

@ -65,21 +65,6 @@ class ApiV1RecordsTest(ApiV1Test):
self.assertIn('name', response.json) self.assertIn('name', response.json)
self.assertEqual(response.json['name'], fixture['name']) self.assertEqual(response.json['name'], fixture['name'])
def test_create_record_trailing_slash(self):
fixture = self.get_record_fixture(self.recordset['type'])
fixture.update({
'name': self.recordset['name'],
'type': self.recordset['type'],
})
# Create a record with a trailing slash
response = self.post('domains/%s/records/' % self.domain['id'],
data=fixture)
self.assertIn('id', response.json)
self.assertIn('name', response.json)
self.assertEqual(response.json['name'], fixture['name'])
def test_create_record_junk(self): def test_create_record_junk(self):
fixture = self.get_record_fixture(self.recordset['type']) fixture = self.get_record_fixture(self.recordset['type'])
fixture.update({ fixture.update({
@ -282,13 +267,6 @@ class ApiV1RecordsTest(ApiV1Test):
self.assertIn('records', response.json) self.assertIn('records', response.json)
self.assertEqual(2, len(response.json['records'])) self.assertEqual(2, len(response.json['records']))
@patch.object(central_service.Service, 'find_records')
def test_get_records_trailing_slash(self, mock):
self.get('domains/%s/records/' % self.domain['id'])
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'find_records', @patch.object(central_service.Service, 'find_records',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_get_records_timeout(self, _): def test_get_records_timeout(self, _):
@ -315,17 +293,6 @@ class ApiV1RecordsTest(ApiV1Test):
self.assertEqual(response.json['name'], self.recordset['name']) self.assertEqual(response.json['name'], self.recordset['name'])
self.assertEqual(response.json['type'], self.recordset['type']) self.assertEqual(response.json['type'], self.recordset['type'])
@patch.object(central_service.Service, 'get_recordset')
def test_get_record_trailing_slash(self, mock):
# Create a record
record = self.create_record(self.domain, self.recordset)
self.get('domains/%s/records/%s/' % (self.domain['id'],
record['id']))
# verify that the central service is called
self.assertTrue(mock.called)
def test_update_record(self): def test_update_record(self):
# Create a record # Create a record
record = self.create_record(self.domain, self.recordset) record = self.create_record(self.domain, self.recordset)
@ -360,20 +327,6 @@ class ApiV1RecordsTest(ApiV1Test):
self.assertEqual(response.json['type'], self.recordset['type']) self.assertEqual(response.json['type'], self.recordset['type'])
self.assertEqual(response.json['ttl'], 100) self.assertEqual(response.json['ttl'], 100)
@patch.object(central_service.Service, 'update_record')
def test_update_record_trailing_slash(self, mock):
# Create a record
record = self.create_record(self.domain, self.recordset)
data = {'ttl': 100}
self.put('domains/%s/records/%s/' % (self.domain['id'],
record['id']),
data=data)
# verify that the central service is called
self.assertTrue(mock.called)
def test_update_record_junk(self): def test_update_record_junk(self):
# Create a record # Create a record
record = self.create_record(self.domain, self.recordset) record = self.create_record(self.domain, self.recordset)
@ -447,17 +400,6 @@ class ApiV1RecordsTest(ApiV1Test):
record['id']), record['id']),
status_code=404) status_code=404)
@patch.object(central_service.Service, 'get_domain')
def test_delete_record_trailing_slash(self, mock):
# Create a record
record = self.create_record(self.domain, self.recordset)
self.delete('domains/%s/records/%s/' % (self.domain['id'],
record['id']))
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'get_domain', @patch.object(central_service.Service, 'get_domain',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_delete_record_timeout(self, _): def test_delete_record_timeout(self, _):

View File

@ -43,14 +43,6 @@ class ApiV1ServersTest(ApiV1Test):
self.assertIn('name', response.json) self.assertIn('name', response.json)
self.assertEqual(response.json['name'], fixture['name']) self.assertEqual(response.json['name'], fixture['name'])
@patch.object(central_service.Service, 'create_server')
def test_create_server_trailing_slash(self, mock):
# Create a server with a trailing slash
self.post('servers/', data=self.get_server_fixture(0))
# verify that the central service is called
self.assertTrue(mock.called)
def test_create_server_junk(self): def test_create_server_junk(self):
# Create a server # Create a server
fixture = self.get_server_fixture(0) fixture = self.get_server_fixture(0)
@ -99,13 +91,6 @@ class ApiV1ServersTest(ApiV1Test):
self.assertIn('servers', response.json) self.assertIn('servers', response.json)
self.assertEqual(2, len(response.json['servers'])) self.assertEqual(2, len(response.json['servers']))
@patch.object(central_service.Service, 'find_servers')
def test_get_servers_trailing_slash(self, mock):
self.get('servers/')
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'find_servers', @patch.object(central_service.Service, 'find_servers',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_get_servers_timeout(self, _): def test_get_servers_timeout(self, _):
@ -120,16 +105,6 @@ class ApiV1ServersTest(ApiV1Test):
self.assertIn('id', response.json) self.assertIn('id', response.json)
self.assertEqual(response.json['id'], server['id']) self.assertEqual(response.json['id'], server['id'])
@patch.object(central_service.Service, 'get_server')
def test_get_server_trailing_slash(self, mock):
# Create a server
server = self.create_server()
self.get('servers/%s/' % server['id'])
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'get_server', @patch.object(central_service.Service, 'get_server',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_get_server_timeout(self, _): def test_get_server_timeout(self, _):
@ -156,18 +131,6 @@ class ApiV1ServersTest(ApiV1Test):
self.assertIn('name', response.json) self.assertIn('name', response.json)
self.assertEqual(response.json['name'], 'test.example.org.') self.assertEqual(response.json['name'], 'test.example.org.')
@patch.object(central_service.Service, 'update_server')
def test_update_server_trailing_slash(self, mock):
# Create a server
server = self.create_server()
data = {'name': 'test.example.org.'}
self.put('servers/%s/' % server['id'], data=data)
# verify that the central service is called
self.assertTrue(mock.called)
def test_update_server_junk(self): def test_update_server_junk(self):
# Create a server # Create a server
server = self.create_server() server = self.create_server()
@ -218,16 +181,6 @@ class ApiV1ServersTest(ApiV1Test):
# Also, verify we cannot delete last remaining server # Also, verify we cannot delete last remaining server
self.delete('servers/%s' % server2['id'], status_code=400) self.delete('servers/%s' % server2['id'], status_code=400)
@patch.object(central_service.Service, 'delete_server')
def test_delete_server_trailing_slash(self, mock):
# Create a server
server = self.create_server()
self.delete('servers/%s/' % server['id'])
# verify that the central service is called
self.assertTrue(mock.called)
@patch.object(central_service.Service, 'delete_server', @patch.object(central_service.Service, 'delete_server',
side_effect=messaging.MessagingTimeout()) side_effect=messaging.MessagingTimeout())
def test_delete_server_timeout(self, _): def test_delete_server_timeout(self, _):

View File

@ -43,6 +43,9 @@ class ApiV2TestCase(ApiTestCase):
# Create the application # Create the application
self.app = api_v2.factory({}) self.app = api_v2.factory({})
# Inject the NormalizeURIMiddleware middleware
self.app = middleware.NormalizeURIMiddleware(self.app)
# Inject the FaultWrapper middleware # Inject the FaultWrapper middleware
self.app = middleware.FaultWrapperMiddleware(self.app) self.app = middleware.FaultWrapperMiddleware(self.app)

View File

@ -9,16 +9,16 @@ paste.app_factory = designate.api.versions:factory
[composite:osapi_dns_v1] [composite:osapi_dns_v1]
use = call:designate.api.middleware:auth_pipeline_factory use = call:designate.api.middleware:auth_pipeline_factory
noauth = request_id noauthcontext maintenance faultwrapper osapi_dns_app_v1 noauth = request_id noauthcontext maintenance faultwrapper normalizeuri osapi_dns_app_v1
keystone = request_id authtoken keystonecontext maintenance faultwrapper osapi_dns_app_v1 keystone = request_id authtoken keystonecontext maintenance faultwrapper normalizeuri osapi_dns_app_v1
[app:osapi_dns_app_v1] [app:osapi_dns_app_v1]
paste.app_factory = designate.api.v1:factory paste.app_factory = designate.api.v1:factory
[composite:osapi_dns_v2] [composite:osapi_dns_v2]
use = call:designate.api.middleware:auth_pipeline_factory use = call:designate.api.middleware:auth_pipeline_factory
noauth = request_id noauthcontext maintenance faultwrapper osapi_dns_app_v2 noauth = request_id noauthcontext maintenance faultwrapper normalizeuri osapi_dns_app_v2
keystone = request_id authtoken keystonecontext maintenance faultwrapper osapi_dns_app_v2 keystone = request_id authtoken keystonecontext maintenance faultwrapper normalizeuri osapi_dns_app_v2
[app:osapi_dns_app_v2] [app:osapi_dns_app_v2]
paste.app_factory = designate.api.v2:factory paste.app_factory = designate.api.v2:factory
@ -26,17 +26,20 @@ paste.app_factory = designate.api.v2:factory
[filter:request_id] [filter:request_id]
paste.filter_factory = designate.openstack.common.middleware.request_id:RequestIdMiddleware.factory paste.filter_factory = designate.openstack.common.middleware.request_id:RequestIdMiddleware.factory
[filter:maintenance]
paste.filter_factory = designate.api.middleware:MaintenanceMiddleware.factory
[filter:noauthcontext] [filter:noauthcontext]
paste.filter_factory = designate.api.middleware:NoAuthContextMiddleware.factory paste.filter_factory = designate.api.middleware:NoAuthContextMiddleware.factory
[filter:authtoken]
paste.filter_factory = keystoneclient.middleware.auth_token:filter_factory
[filter:keystonecontext] [filter:keystonecontext]
paste.filter_factory = designate.api.middleware:KeystoneContextMiddleware.factory paste.filter_factory = designate.api.middleware:KeystoneContextMiddleware.factory
[filter:maintenance]
paste.filter_factory = designate.api.middleware:MaintenanceMiddleware.factory
[filter:normalizeuri]
paste.filter_factory = designate.api.middleware:NormalizeURIMiddleware.factory
[filter:faultwrapper] [filter:faultwrapper]
paste.filter_factory = designate.api.middleware:FaultWrapperMiddleware.factory paste.filter_factory = designate.api.middleware:FaultWrapperMiddleware.factory
[filter:authtoken]
paste.filter_factory = keystoneclient.middleware.auth_token:filter_factory