diff --git a/novaclient/exceptions.py b/novaclient/exceptions.py index 7e9e0c8ae..d550b754b 100644 --- a/novaclient/exceptions.py +++ b/novaclient/exceptions.py @@ -105,6 +105,19 @@ class ClientException(Exception): return formatted_string +class RetryAfterException(ClientException): + """ + The base exception class for ClientExceptions that use Retry-After header. + """ + def __init__(self, *args, **kwargs): + try: + self.retry_after = int(kwargs.pop('retry_after')) + except (KeyError, ValueError): + self.retry_after = 0 + + super(RetryAfterException, self).__init__(*args, **kwargs) + + class BadRequest(ClientException): """ HTTP 400 - Bad request: you sent some malformed data. @@ -154,23 +167,15 @@ class Conflict(ClientException): message = "Conflict" -class OverLimit(ClientException): +class OverLimit(RetryAfterException): """ HTTP 413 - Over limit: you're over the API limits for this time period. """ http_status = 413 message = "Over limit" - def __init__(self, *args, **kwargs): - try: - self.retry_after = int(kwargs.pop('retry_after')) - except (KeyError, ValueError): - self.retry_after = 0 - super(OverLimit, self).__init__(*args, **kwargs) - - -class RateLimit(OverLimit): +class RateLimit(RetryAfterException): """ HTTP 429 - Rate limit: you've sent too many requests for this time period. """ @@ -220,6 +225,8 @@ def from_response(response, body, url, method=None): if resp.status_code != 200: raise exception_from_response(resp, rest.text) """ + cls = _code_map.get(response.status_code, ClientException) + kwargs = { 'code': response.status_code, 'method': method, @@ -230,7 +237,8 @@ def from_response(response, body, url, method=None): if response.headers: kwargs['request_id'] = response.headers.get('x-compute-request-id') - if 'retry-after' in response.headers: + if (issubclass(cls, RetryAfterException) and + 'retry-after' in response.headers): kwargs['retry_after'] = response.headers.get('retry-after') if body: @@ -245,7 +253,6 @@ def from_response(response, body, url, method=None): kwargs['message'] = message kwargs['details'] = details - cls = _code_map.get(response.status_code, ClientException) return cls(**kwargs) diff --git a/novaclient/tests/test_http.py b/novaclient/tests/test_http.py index 7d9e3ecfa..bd12ed827 100644 --- a/novaclient/tests/test_http.py +++ b/novaclient/tests/test_http.py @@ -44,6 +44,32 @@ unknown_error_response = utils.TestResponse({ }) unknown_error_mock_request = mock.Mock(return_value=unknown_error_response) +retry_after_response = utils.TestResponse({ + "status_code": 413, + "text": '', + "headers": { + "retry-after": "5" + }, +}) +retry_after_mock_request = mock.Mock(return_value=retry_after_response) + +retry_after_no_headers_response = utils.TestResponse({ + "status_code": 413, + "text": '', +}) +retry_after_no_headers_mock_request = mock.Mock( + return_value=retry_after_no_headers_response) + +retry_after_non_supporting_response = utils.TestResponse({ + "status_code": 403, + "text": '', + "headers": { + "retry-after": "5" + }, +}) +retry_after_non_supporting_mock_request = mock.Mock( + return_value=retry_after_non_supporting_response) + def get_client(): cl = client.HTTPClient("username", "password", @@ -154,3 +180,33 @@ class ClientTest(utils.TestCase): self.assertIn('Unknown Error', six.text_type(exc)) else: self.fail('Expected exceptions.ClientException') + + @mock.patch.object(requests, "request", retry_after_mock_request) + def test_retry_after_request(self): + cl = get_client() + + try: + cl.get("/hi") + except exceptions.OverLimit as exc: + self.assertEqual(5, exc.retry_after) + else: + self.fail('Expected exceptions.OverLimit') + + @mock.patch.object(requests, "request", + retry_after_no_headers_mock_request) + def test_retry_after_request_no_headers(self): + cl = get_client() + + try: + cl.get("/hi") + except exceptions.OverLimit as exc: + self.assertEqual(0, exc.retry_after) + else: + self.fail('Expected exceptions.OverLimit') + + @mock.patch.object(requests, "request", + retry_after_non_supporting_mock_request) + def test_retry_after_request_non_supporting_exc(self): + cl = get_client() + + self.assertRaises(exceptions.Forbidden, cl.get, "/hi")