diff --git a/oauth2client/client.py b/oauth2client/client.py index b6fd369..d808ed4 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -557,14 +557,20 @@ class OAuth2Credentials(Credentials): resp, content = request_orig(uri, method, body, clean_headers(headers), redirections, connection_type) - if resp.status in REFRESH_STATUS_CODES: - logger.info('Refreshing due to a %s', resp.status) + # A stored token may expire between the time it is retrieved and the time + # the request is made, so we may need to try twice. + max_refresh_attempts = 2 + for refresh_attempt in range(max_refresh_attempts): + if resp.status not in REFRESH_STATUS_CODES: + break + logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status, + refresh_attempt + 1, max_refresh_attempts) self._refresh(request_orig) self.apply(headers) - return request_orig(uri, method, body, clean_headers(headers), - redirections, connection_type) - else: - return (resp, content) + resp, content = request_orig(uri, method, body, clean_headers(headers), + redirections, connection_type) + + return (resp, content) # Replace the request method with our own closure. http.request = new_request @@ -757,8 +763,10 @@ class OAuth2Credentials(Credentials): self.store.acquire_lock() try: new_cred = self.store.locked_get() + if (new_cred and not new_cred.invalid and - new_cred.access_token != self.access_token): + new_cred.access_token != self.access_token and + not new_cred.access_token_expired): logger.info('Updated access_token read from Storage') self._updateFromCredential(new_cred) else: diff --git a/tests/test_file.py b/tests/test_file.py index ac81bce..8efe6cf 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -24,7 +24,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' import copy import datetime -import httplib2 import json import os import pickle @@ -32,13 +31,12 @@ import stat import tempfile import unittest -from oauth2client import GOOGLE_TOKEN_URI +from .http_mock import HttpMockSequence from oauth2client import file from oauth2client import locked_file from oauth2client import multistore_file from oauth2client import util from oauth2client.client import AccessTokenCredentials -from oauth2client.client import AssertionCredentials from oauth2client.client import OAuth2Credentials try: # Python2 @@ -64,11 +62,12 @@ class OAuth2ClientFileTests(unittest.TestCase): except OSError: pass - def create_test_credentials(self, client_id='some_client_id'): + def create_test_credentials(self, client_id='some_client_id', + expiration=None): access_token = 'foo' client_secret = 'cOuDdkfjxxnv+' refresh_token = '1/0/a.df219fjls0' - token_expiry = datetime.datetime.utcnow() + token_expiry = expiration or datetime.datetime.utcnow() token_uri = 'https://www.google.com/accounts/o8/oauth2/token' user_agent = 'refresh_checker/1.0' @@ -119,8 +118,55 @@ class OAuth2ClientFileTests(unittest.TestCase): self.assertEquals(data['_class'], 'OAuth2Credentials') self.assertEquals(data['_module'], OAuth2Credentials.__module__) - def test_token_refresh(self): - credentials = self.create_test_credentials() + def test_token_refresh_store_expired(self): + expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) + + s = file.Storage(FILENAME) + s.put(credentials) + credentials = s.get() + new_cred = copy.copy(credentials) + new_cred.access_token = 'bar' + s.put(new_cred) + + access_token = '1/3w' + token_response = {'access_token': access_token, 'expires_in': 3600} + http = HttpMockSequence([ + ({'status': '200'}, json.dumps(token_response).encode('utf-8')), + ]) + + credentials._refresh(http.request) + self.assertEquals(credentials.access_token, access_token) + + def test_token_refresh_store_expires_soon(self): + # Tests the case where an access token that is valid when it is read from + # the store expires before the original request succeeds. + expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) + + s = file.Storage(FILENAME) + s.put(credentials) + credentials = s.get() + new_cred = copy.copy(credentials) + new_cred.access_token = 'bar' + s.put(new_cred) + + access_token = '1/3w' + token_response = {'access_token': access_token, 'expires_in': 3600} + http = HttpMockSequence([ + ({'status': '401'}, b'Initial token expired'), + ({'status': '401'}, b'Store token expired'), + ({'status': '200'}, json.dumps(token_response).encode('utf-8')), + ({'status': '200'}, b'Valid response to original request') + ]) + + credentials.authorize(http) + http.request('https://example.com') + self.assertEquals(credentials.access_token, access_token) + + def test_token_refresh_good_store(self): + expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) s = file.Storage(FILENAME) s.put(credentials) @@ -287,7 +333,6 @@ class OAuth2ClientFileTests(unittest.TestCase): self.assertEqual(credentials.access_token, stored_credentials.access_token) - def test_multistore_file_get_all_keys(self): # start with no keys keys = multistore_file.get_all_credential_keys(FILENAME)