diff --git a/ec2api/api/__init__.py b/ec2api/api/__init__.py index f3838985..61e8ff97 100644 --- a/ec2api/api/__init__.py +++ b/ec2api/api/__init__.py @@ -112,6 +112,12 @@ class RequestLogging(wsgi.Middleware): context=ctxt) +class InvalidCredentialsException(Exception): + def __init__(self, resp): + super(Exception, self).__init__() + self.resp = resp + + class EC2KeystoneAuth(wsgi.Middleware): """Authenticate an EC2 request with keystone and convert to context.""" @@ -120,10 +126,18 @@ class EC2KeystoneAuth(wsgi.Middleware): def __call__(self, req): request_id = context.generate_request_id() - if 'Signature' in req.params: - cred_dict = self._get_creds(req, request_id) - else: - cred_dict = self._get_creds_v4(req, request_id) + try: + if 'Signature' in req.params: + cred_dict = self._get_creds(req, request_id) + else: + cred_dict = self._get_creds_v4(req, request_id) + except InvalidCredentialsException as ex: + return ex.resp + except Exception: + msg = _("Invalid authorization parameters") + return faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400) + access = cred_dict['access'] token_url = CONF.keystone_url + "/ec2tokens" if "ec2" in token_url: @@ -199,13 +213,15 @@ class EC2KeystoneAuth(wsgi.Middleware): signature = req.params.get('Signature') if not signature: msg = _("Signature not provided") - return faults.ec2_error_response(request_id, "AuthFailure", msg, - status=400) + raise InvalidCredentialsException( + faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400)) access = req.params.get('AWSAccessKeyId') if not access: msg = _("Access key not provided") - return faults.ec2_error_response(request_id, "AuthFailure", msg, - status=400) + raise InvalidCredentialsException( + faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400)) # Make a copy of args for authentication and signature verification. auth_params = dict(req.params) @@ -223,25 +239,35 @@ class EC2KeystoneAuth(wsgi.Middleware): return cred_dict def _get_creds_v4(self, req, request_id): - auth = req.environ['HTTP_AUTHORIZATION'].split(',') + auth = req.environ.get('HTTP_AUTHORIZATION') + if not auth: + msg = _("Signature not provided") + raise InvalidCredentialsException( + faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400)) + + auth = auth.split(',') auth = [a.strip() for a in auth] if not auth[0].startswith('AWS4-HMAC-SHA256'): msg = _("Invalid authorization parameters") - return faults.ec2_error_response(request_id, "AuthFailure", msg, - status=400) + raise InvalidCredentialsException( + faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400)) access = auth[0].split('=')[1].split('/')[0] if not access: msg = _("Access key not provided") - return faults.ec2_error_response(request_id, "AuthFailure", msg, - status=400) + raise InvalidCredentialsException( + faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400)) for item in auth: if item.startswith('Signature'): signature = item.split('=')[1] if not signature: msg = _("Signature could not be found in request") - return faults.ec2_error_response(request_id, "AuthFailure", msg, - status=400) + raise InvalidCredentialsException( + faults.ec2_error_response(request_id, "AuthFailure", msg, + status=400)) headers = dict() for key in req.headers: @@ -284,9 +310,6 @@ class EC2KeystoneAuth(wsgi.Middleware): class Requestify(wsgi.Middleware): - def __init__(self, app): - super(Requestify, self).__init__(app) - @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): non_args = ['Action', 'Signature', 'AWSAccessKeyId', 'SignatureMethod', diff --git a/ec2api/tests/test_middleware.py b/ec2api/tests/test_middleware.py new file mode 100644 index 00000000..a55fe7ad --- /dev/null +++ b/ec2api/tests/test_middleware.py @@ -0,0 +1,161 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from lxml import etree +import mock +from oslo.config import cfg +from oslotest import base as test_base +import requests +import webob.dec +import webob.exc + +from ec2api import api as ec2 +from ec2api import context +from ec2api import exception +from ec2api import wsgi + +CONF = cfg.CONF + + +@webob.dec.wsgify +def conditional_forbid(req): + """Helper wsgi app returns 403 if param 'die' is 1.""" + if 'die' in req.params and req.params['die'] == '1': + raise webob.exc.HTTPForbidden() + return 'OK' + + +class ExecutorTestCase(test_base.BaseTestCase): + def setUp(self): + super(ExecutorTestCase, self).setUp() + self.executor = ec2.Executor() + + def _execute(self, invoke): + class Fake(object): + pass + fake_ec2_request = Fake() + fake_ec2_request.invoke = invoke + + fake_wsgi_request = Fake() + + fake_wsgi_request.environ = { + 'ec2api.context': context.get_admin_context(), + 'ec2.request': fake_ec2_request, + } + return self.executor(fake_wsgi_request) + + def _extract_message(self, result): + tree = etree.fromstring(result.body) + return tree.findall('./Errors')[0].find('Error/Message').text + + def _extract_code(self, result): + tree = etree.fromstring(result.body) + return tree.findall('./Errors')[0].find('Error/Code').text + + def test_instance_not_found(self): + def not_found(context): + raise exception.InvalidInstanceIDNotFound(id='i-01') + result = self._execute(not_found) + self.assertIn('i-01', self._extract_message(result)) + self.assertEqual('InvalidInstanceID.NotFound', + self._extract_code(result)) + + def test_instance_not_found_none(self): + def not_found(context): + raise exception.InvalidInstanceIDNotFound(id=None) + + # NOTE(mikal): we want no exception to be raised here, which was what + # was happening in bug/1080406 + result = self._execute(not_found) + self.assertIn('None', self._extract_message(result)) + self.assertEqual('InvalidInstanceID.NotFound', + self._extract_code(result)) + + def test_snapshot_not_found(self): + def not_found(context): + raise exception.InvalidSnapshotNotFound(id='snap-01') + result = self._execute(not_found) + self.assertIn('snap-01', self._extract_message(result)) + self.assertEqual('InvalidSnapshot.NotFound', + self._extract_code(result)) + + def test_volume_not_found(self): + def not_found(context): + raise exception.InvalidVolumeNotFound(id='vol-01') + result = self._execute(not_found) + self.assertIn('vol-01', self._extract_message(result)) + self.assertEqual('InvalidVolume.NotFound', self._extract_code(result)) + + +class FakeResponse(object): + reason = "Test Reason" + + def __init__(self, status_code=400): + self.status_code = status_code + + def json(self): + return {} + + +class KeystoneAuthTestCase(test_base.BaseTestCase): + def setUp(self): + super(KeystoneAuthTestCase, self).setUp() + self.kauth = ec2.EC2KeystoneAuth(conditional_forbid) + + def _validate_ec2_error(self, response, http_status, ec2_code): + self.assertEqual(response.status_code, http_status, + 'Expected HTTP status %s' % http_status) + root_e = etree.XML(response.body) + self.assertEqual(root_e.tag, 'Response', + "Top element must be Response.") + errors_e = root_e.find('Errors') + error_e = errors_e[0] + code_e = error_e.find('Code') + self.assertIsNotNone(code_e, "Code element must be present.") + self.assertEqual(code_e.text, ec2_code) + + def test_no_signature(self): + req = wsgi.Request.blank('/test') + resp = self.kauth(req) + self._validate_ec2_error(resp, 400, 'AuthFailure') + + def test_no_key_id(self): + req = wsgi.Request.blank('/test') + req.GET['Signature'] = 'test-signature' + resp = self.kauth(req) + self._validate_ec2_error(resp, 400, 'AuthFailure') + + @mock.patch.object(requests, 'request', return_value=FakeResponse()) + def test_communication_failure(self, mock_request): + req = wsgi.Request.blank('/test') + req.GET['Signature'] = 'test-signature' + req.GET['AWSAccessKeyId'] = 'test-key-id' + resp = self.kauth(req) + self._validate_ec2_error(resp, 400, 'AuthFailure') + mock_request.assert_called_with('POST', + CONF.keystone_url + '/ec2tokens', + data=mock.ANY, headers=mock.ANY) + + @mock.patch.object(requests, 'request', return_value=FakeResponse(200)) + def test_no_result_data(self, mock_request): + req = wsgi.Request.blank('/test') + req.GET['Signature'] = 'test-signature' + req.GET['AWSAccessKeyId'] = 'test-key-id' + resp = self.kauth(req) + self._validate_ec2_error(resp, 400, 'AuthFailure') + mock_request.assert_called_with('POST', + CONF.keystone_url + '/ec2tokens', + data=mock.ANY, headers=mock.ANY)