Avoid using expired tokens from credential storage.

This covers the case where credential storage contains an expired (but different) token, as well as the case where the token read from storage is almost expired and results in a 401 when resending the original request.
This commit is contained in:
Travis Hobrla
2015-05-11 13:41:24 -07:00
parent 3a742eefd8
commit 0cef592c54
2 changed files with 68 additions and 15 deletions

View File

@@ -557,14 +557,20 @@ class OAuth2Credentials(Credentials):
resp, content = request_orig(uri, method, body, clean_headers(headers), resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type) redirections, connection_type)
if resp.status in REFRESH_STATUS_CODES: # A stored token may expire between the time it is retrieved and the time
logger.info('Refreshing due to a %s', resp.status) # the request is made, so we may need to try twice.
max_refresh_attempts = 2
for refresh_attempt in range(max_refresh_attempts):
if resp.status not in REFRESH_STATUS_CODES:
break
logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status,
refresh_attempt + 1, max_refresh_attempts)
self._refresh(request_orig) self._refresh(request_orig)
self.apply(headers) self.apply(headers)
return request_orig(uri, method, body, clean_headers(headers), resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type) redirections, connection_type)
else:
return (resp, content) return (resp, content)
# Replace the request method with our own closure. # Replace the request method with our own closure.
http.request = new_request http.request = new_request
@@ -757,8 +763,10 @@ class OAuth2Credentials(Credentials):
self.store.acquire_lock() self.store.acquire_lock()
try: try:
new_cred = self.store.locked_get() new_cred = self.store.locked_get()
if (new_cred and not new_cred.invalid and if (new_cred and not new_cred.invalid and
new_cred.access_token != self.access_token): new_cred.access_token != self.access_token and
not new_cred.access_token_expired):
logger.info('Updated access_token read from Storage') logger.info('Updated access_token read from Storage')
self._updateFromCredential(new_cred) self._updateFromCredential(new_cred)
else: else:

View File

@@ -24,7 +24,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import copy import copy
import datetime import datetime
import httplib2
import json import json
import os import os
import pickle import pickle
@@ -32,13 +31,12 @@ import stat
import tempfile import tempfile
import unittest import unittest
from oauth2client import GOOGLE_TOKEN_URI from .http_mock import HttpMockSequence
from oauth2client import file from oauth2client import file
from oauth2client import locked_file from oauth2client import locked_file
from oauth2client import multistore_file from oauth2client import multistore_file
from oauth2client import util from oauth2client import util
from oauth2client.client import AccessTokenCredentials from oauth2client.client import AccessTokenCredentials
from oauth2client.client import AssertionCredentials
from oauth2client.client import OAuth2Credentials from oauth2client.client import OAuth2Credentials
try: try:
# Python2 # Python2
@@ -64,11 +62,12 @@ class OAuth2ClientFileTests(unittest.TestCase):
except OSError: except OSError:
pass pass
def create_test_credentials(self, client_id='some_client_id'): def create_test_credentials(self, client_id='some_client_id',
expiration=None):
access_token = 'foo' access_token = 'foo'
client_secret = 'cOuDdkfjxxnv+' client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0' refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow() token_expiry = expiration or datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token' token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0' user_agent = 'refresh_checker/1.0'
@@ -119,8 +118,55 @@ class OAuth2ClientFileTests(unittest.TestCase):
self.assertEquals(data['_class'], 'OAuth2Credentials') self.assertEquals(data['_class'], 'OAuth2Credentials')
self.assertEquals(data['_module'], OAuth2Credentials.__module__) self.assertEquals(data['_module'], OAuth2Credentials.__module__)
def test_token_refresh(self): def test_token_refresh_store_expired(self):
credentials = self.create_test_credentials() expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
])
credentials._refresh(http.request)
self.assertEquals(credentials.access_token, access_token)
def test_token_refresh_store_expires_soon(self):
# Tests the case where an access token that is valid when it is read from
# the store expires before the original request succeeds.
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '401'}, 'Initial token expired'),
({'status': '401'}, 'Store token expired'),
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
({'status': '200'}, 'Valid response to original request')
])
credentials.authorize(http)
http.request('https://example.com')
self.assertEquals(credentials.access_token, access_token)
def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
s.put(credentials) s.put(credentials)
@@ -287,7 +333,6 @@ class OAuth2ClientFileTests(unittest.TestCase):
self.assertEqual(credentials.access_token, stored_credentials.access_token) self.assertEqual(credentials.access_token, stored_credentials.access_token)
def test_multistore_file_get_all_keys(self): def test_multistore_file_get_all_keys(self):
# start with no keys # start with no keys
keys = multistore_file.get_all_credential_keys(FILENAME) keys = multistore_file.get_all_credential_keys(FILENAME)