Avoid using expired tokens from credential storage.

This covers the case where credential storage contains an expired (but different) token, as well as the case where the token read from storage is almost expired and results in a 401 when resending the original request.
This commit is contained in:
Travis Hobrla
2015-05-11 13:41:24 -07:00
parent 3a742eefd8
commit 0cef592c54
2 changed files with 68 additions and 15 deletions

View File

@@ -557,14 +557,20 @@ class OAuth2Credentials(Credentials):
resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
if resp.status in REFRESH_STATUS_CODES:
logger.info('Refreshing due to a %s', resp.status)
# A stored token may expire between the time it is retrieved and the time
# the request is made, so we may need to try twice.
max_refresh_attempts = 2
for refresh_attempt in range(max_refresh_attempts):
if resp.status not in REFRESH_STATUS_CODES:
break
logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status,
refresh_attempt + 1, max_refresh_attempts)
self._refresh(request_orig)
self.apply(headers)
return request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
else:
return (resp, content)
resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
return (resp, content)
# Replace the request method with our own closure.
http.request = new_request
@@ -757,8 +763,10 @@ class OAuth2Credentials(Credentials):
self.store.acquire_lock()
try:
new_cred = self.store.locked_get()
if (new_cred and not new_cred.invalid and
new_cred.access_token != self.access_token):
new_cred.access_token != self.access_token and
not new_cred.access_token_expired):
logger.info('Updated access_token read from Storage')
self._updateFromCredential(new_cred)
else:

View File

@@ -24,7 +24,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import copy
import datetime
import httplib2
import json
import os
import pickle
@@ -32,13 +31,12 @@ import stat
import tempfile
import unittest
from oauth2client import GOOGLE_TOKEN_URI
from .http_mock import HttpMockSequence
from oauth2client import file
from oauth2client import locked_file
from oauth2client import multistore_file
from oauth2client import util
from oauth2client.client import AccessTokenCredentials
from oauth2client.client import AssertionCredentials
from oauth2client.client import OAuth2Credentials
try:
# Python2
@@ -64,11 +62,12 @@ class OAuth2ClientFileTests(unittest.TestCase):
except OSError:
pass
def create_test_credentials(self, client_id='some_client_id'):
def create_test_credentials(self, client_id='some_client_id',
expiration=None):
access_token = 'foo'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow()
token_expiry = expiration or datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0'
@@ -119,8 +118,55 @@ class OAuth2ClientFileTests(unittest.TestCase):
self.assertEquals(data['_class'], 'OAuth2Credentials')
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
def test_token_refresh(self):
credentials = self.create_test_credentials()
def test_token_refresh_store_expired(self):
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
])
credentials._refresh(http.request)
self.assertEquals(credentials.access_token, access_token)
def test_token_refresh_store_expires_soon(self):
# Tests the case where an access token that is valid when it is read from
# the store expires before the original request succeeds.
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '401'}, 'Initial token expired'),
({'status': '401'}, 'Store token expired'),
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
({'status': '200'}, 'Valid response to original request')
])
credentials.authorize(http)
http.request('https://example.com')
self.assertEquals(credentials.access_token, access_token)
def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
@@ -287,7 +333,6 @@ class OAuth2ClientFileTests(unittest.TestCase):
self.assertEqual(credentials.access_token, stored_credentials.access_token)
def test_multistore_file_get_all_keys(self):
# start with no keys
keys = multistore_file.get_all_credential_keys(FILENAME)