Add HttpAccessTokenRefreshError

This adds an exception class including the HTTP status for access
token refresh requests made over HTTP.

Fixes https://github.com/google/oauth2client/issues/309
This commit is contained in:
Travis Hobrla
2015-09-21 13:46:33 -07:00
parent 29347a9135
commit 7af949f52f
3 changed files with 22 additions and 12 deletions

View File

@@ -134,6 +134,13 @@ class AccessTokenRefreshError(Error):
"""Error trying to refresh an expired access token."""
class HttpAccessTokenRefreshError(AccessTokenRefreshError):
"""Error (with HTTP status) trying to refresh an expired access token."""
def __init__(self, *args, **kwargs):
super(HttpAccessTokenRefreshError, self).__init__(*args)
self.status = kwargs.get('status')
class TokenRevokeError(Error):
"""Error trying to revoke a token."""
@@ -830,7 +837,7 @@ class OAuth2Credentials(Credentials):
refresh request.
Raises:
AccessTokenRefreshError: When the refresh fails.
HttpAccessTokenRefreshError: When the refresh fails.
"""
if not self.store:
self._do_refresh_request(http_request)
@@ -858,7 +865,7 @@ class OAuth2Credentials(Credentials):
refresh request.
Raises:
AccessTokenRefreshError: When the refresh fails.
HttpAccessTokenRefreshError: When the refresh fails.
"""
body = self._generate_refresh_request_body()
headers = self._generate_refresh_request_headers()
@@ -898,7 +905,7 @@ class OAuth2Credentials(Credentials):
self.store.locked_put(self)
except (TypeError, ValueError):
pass
raise AccessTokenRefreshError(error_msg)
raise HttpAccessTokenRefreshError(error_msg, status=resp.status)
def _revoke(self, http_request):
"""Revokes this credential and deletes the stored copy (if it exists).

View File

@@ -23,7 +23,7 @@ from six.moves import urllib
from oauth2client._helpers import _from_bytes
from oauth2client import util
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import HttpAccessTokenRefreshError
from oauth2client.client import AssertionCredentials
@@ -80,7 +80,7 @@ class AppAssertionCredentials(AssertionCredentials):
the refresh request.
Raises:
AccessTokenRefreshError: When the refresh fails.
HttpAccessTokenRefreshError: When the refresh fails.
"""
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query)
@@ -90,13 +90,14 @@ class AppAssertionCredentials(AssertionCredentials):
try:
d = json.loads(content)
except Exception as e:
raise AccessTokenRefreshError(str(e))
raise HttpAccessTokenRefreshError(str(e),
status=response.status)
self.access_token = d['accessToken']
else:
if response.status == 404:
content += (' This can occur if a VM was created'
' with no service account or scopes.')
raise AccessTokenRefreshError(content)
raise HttpAccessTokenRefreshError(content, status=response.status)
@property
def serialization_data(self):

View File

@@ -31,6 +31,7 @@ import unittest
import mock
import six
from six.moves import http_client
from six.moves import urllib
from .http_mock import CacheMock
@@ -43,7 +44,7 @@ from oauth2client import client
from oauth2client import util as oauth2client_util
from oauth2client.client import AccessTokenCredentials
from oauth2client.client import AccessTokenCredentialsError
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import HttpAccessTokenRefreshError
from oauth2client.client import ADC_HELP_MSG
from oauth2client.client import AssertionCredentials
from oauth2client.client import AUTHORIZED_USER
@@ -690,14 +691,15 @@ class BasicCredentialsTests(unittest.TestCase):
for status_code in REFRESH_STATUS_CODES:
http = HttpMockSequence([
({'status': status_code}, b''),
({'status': '400'}, b'{"error":"access_denied"}'),
({'status': http_client.BAD_REQUEST},
b'{"error":"access_denied"}'),
])
http = self.credentials.authorize(http)
try:
http.request('http://example.com')
self.fail('should raise AccessTokenRefreshError exception')
except AccessTokenRefreshError:
pass
self.fail('should raise HttpAccessTokenRefreshError exception')
except HttpAccessTokenRefreshError as e:
self.assertEqual(http_client.BAD_REQUEST, e.status)
self.assertTrue(self.credentials.access_token_expired)
self.assertEqual(None, self.credentials.token_response)