From 0cef592c54a9060d554687d913036a7524f10abf Mon Sep 17 00:00:00 2001 From: Travis Hobrla Date: Mon, 11 May 2015 13:41:24 -0700 Subject: [PATCH] Avoid using expired tokens from credential storage. This covers the case where credential storage contains an expired (but different) token, as well as the case where the token read from storage is almost expired and results in a 401 when resending the original request. --- oauth2client/client.py | 22 ++++++++++----- tests/test_file.py | 61 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index b6fd369..d808ed4 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -557,14 +557,20 @@ class OAuth2Credentials(Credentials): resp, content = request_orig(uri, method, body, clean_headers(headers), redirections, connection_type) - if resp.status in REFRESH_STATUS_CODES: - logger.info('Refreshing due to a %s', resp.status) + # A stored token may expire between the time it is retrieved and the time + # the request is made, so we may need to try twice. + max_refresh_attempts = 2 + for refresh_attempt in range(max_refresh_attempts): + if resp.status not in REFRESH_STATUS_CODES: + break + logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status, + refresh_attempt + 1, max_refresh_attempts) self._refresh(request_orig) self.apply(headers) - return request_orig(uri, method, body, clean_headers(headers), - redirections, connection_type) - else: - return (resp, content) + resp, content = request_orig(uri, method, body, clean_headers(headers), + redirections, connection_type) + + return (resp, content) # Replace the request method with our own closure. http.request = new_request @@ -757,8 +763,10 @@ class OAuth2Credentials(Credentials): self.store.acquire_lock() try: new_cred = self.store.locked_get() + if (new_cred and not new_cred.invalid and - new_cred.access_token != self.access_token): + new_cred.access_token != self.access_token and + not new_cred.access_token_expired): logger.info('Updated access_token read from Storage') self._updateFromCredential(new_cred) else: diff --git a/tests/test_file.py b/tests/test_file.py index ac81bce..020cedb 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -24,7 +24,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' import copy import datetime -import httplib2 import json import os import pickle @@ -32,13 +31,12 @@ import stat import tempfile import unittest -from oauth2client import GOOGLE_TOKEN_URI +from .http_mock import HttpMockSequence from oauth2client import file from oauth2client import locked_file from oauth2client import multistore_file from oauth2client import util from oauth2client.client import AccessTokenCredentials -from oauth2client.client import AssertionCredentials from oauth2client.client import OAuth2Credentials try: # Python2 @@ -64,11 +62,12 @@ class OAuth2ClientFileTests(unittest.TestCase): except OSError: pass - def create_test_credentials(self, client_id='some_client_id'): + def create_test_credentials(self, client_id='some_client_id', + expiration=None): access_token = 'foo' client_secret = 'cOuDdkfjxxnv+' refresh_token = '1/0/a.df219fjls0' - token_expiry = datetime.datetime.utcnow() + token_expiry = expiration or datetime.datetime.utcnow() token_uri = 'https://www.google.com/accounts/o8/oauth2/token' user_agent = 'refresh_checker/1.0' @@ -119,8 +118,55 @@ class OAuth2ClientFileTests(unittest.TestCase): self.assertEquals(data['_class'], 'OAuth2Credentials') self.assertEquals(data['_module'], OAuth2Credentials.__module__) - def test_token_refresh(self): - credentials = self.create_test_credentials() + def test_token_refresh_store_expired(self): + expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) + + s = file.Storage(FILENAME) + s.put(credentials) + credentials = s.get() + new_cred = copy.copy(credentials) + new_cred.access_token = 'bar' + s.put(new_cred) + + access_token = '1/3w' + token_response = {'access_token': access_token, 'expires_in': 3600} + http = HttpMockSequence([ + ({'status': '200'}, json.dumps(token_response).encode('utf-8')), + ]) + + credentials._refresh(http.request) + self.assertEquals(credentials.access_token, access_token) + + def test_token_refresh_store_expires_soon(self): + # Tests the case where an access token that is valid when it is read from + # the store expires before the original request succeeds. + expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) + + s = file.Storage(FILENAME) + s.put(credentials) + credentials = s.get() + new_cred = copy.copy(credentials) + new_cred.access_token = 'bar' + s.put(new_cred) + + access_token = '1/3w' + token_response = {'access_token': access_token, 'expires_in': 3600} + http = HttpMockSequence([ + ({'status': '401'}, 'Initial token expired'), + ({'status': '401'}, 'Store token expired'), + ({'status': '200'}, json.dumps(token_response).encode('utf-8')), + ({'status': '200'}, 'Valid response to original request') + ]) + + credentials.authorize(http) + http.request('https://example.com') + self.assertEquals(credentials.access_token, access_token) + + def test_token_refresh_good_store(self): + expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) s = file.Storage(FILENAME) s.put(credentials) @@ -287,7 +333,6 @@ class OAuth2ClientFileTests(unittest.TestCase): self.assertEqual(credentials.access_token, stored_credentials.access_token) - def test_multistore_file_get_all_keys(self): # start with no keys keys = multistore_file.get_all_credential_keys(FILENAME)