Merge pull request #173 from thobrla/master
Avoid using expired tokens from credential storage.
This commit is contained in:
@@ -557,13 +557,19 @@ class OAuth2Credentials(Credentials):
|
|||||||
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
||||||
redirections, connection_type)
|
redirections, connection_type)
|
||||||
|
|
||||||
if resp.status in REFRESH_STATUS_CODES:
|
# A stored token may expire between the time it is retrieved and the time
|
||||||
logger.info('Refreshing due to a %s', resp.status)
|
# the request is made, so we may need to try twice.
|
||||||
|
max_refresh_attempts = 2
|
||||||
|
for refresh_attempt in range(max_refresh_attempts):
|
||||||
|
if resp.status not in REFRESH_STATUS_CODES:
|
||||||
|
break
|
||||||
|
logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status,
|
||||||
|
refresh_attempt + 1, max_refresh_attempts)
|
||||||
self._refresh(request_orig)
|
self._refresh(request_orig)
|
||||||
self.apply(headers)
|
self.apply(headers)
|
||||||
return request_orig(uri, method, body, clean_headers(headers),
|
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
||||||
redirections, connection_type)
|
redirections, connection_type)
|
||||||
else:
|
|
||||||
return (resp, content)
|
return (resp, content)
|
||||||
|
|
||||||
# Replace the request method with our own closure.
|
# Replace the request method with our own closure.
|
||||||
@@ -757,8 +763,10 @@ class OAuth2Credentials(Credentials):
|
|||||||
self.store.acquire_lock()
|
self.store.acquire_lock()
|
||||||
try:
|
try:
|
||||||
new_cred = self.store.locked_get()
|
new_cred = self.store.locked_get()
|
||||||
|
|
||||||
if (new_cred and not new_cred.invalid and
|
if (new_cred and not new_cred.invalid and
|
||||||
new_cred.access_token != self.access_token):
|
new_cred.access_token != self.access_token and
|
||||||
|
not new_cred.access_token_expired):
|
||||||
logger.info('Updated access_token read from Storage')
|
logger.info('Updated access_token read from Storage')
|
||||||
self._updateFromCredential(new_cred)
|
self._updateFromCredential(new_cred)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import httplib2
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -32,13 +31,12 @@ import stat
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from oauth2client import GOOGLE_TOKEN_URI
|
from .http_mock import HttpMockSequence
|
||||||
from oauth2client import file
|
from oauth2client import file
|
||||||
from oauth2client import locked_file
|
from oauth2client import locked_file
|
||||||
from oauth2client import multistore_file
|
from oauth2client import multistore_file
|
||||||
from oauth2client import util
|
from oauth2client import util
|
||||||
from oauth2client.client import AccessTokenCredentials
|
from oauth2client.client import AccessTokenCredentials
|
||||||
from oauth2client.client import AssertionCredentials
|
|
||||||
from oauth2client.client import OAuth2Credentials
|
from oauth2client.client import OAuth2Credentials
|
||||||
try:
|
try:
|
||||||
# Python2
|
# Python2
|
||||||
@@ -64,11 +62,12 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
|||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_test_credentials(self, client_id='some_client_id'):
|
def create_test_credentials(self, client_id='some_client_id',
|
||||||
|
expiration=None):
|
||||||
access_token = 'foo'
|
access_token = 'foo'
|
||||||
client_secret = 'cOuDdkfjxxnv+'
|
client_secret = 'cOuDdkfjxxnv+'
|
||||||
refresh_token = '1/0/a.df219fjls0'
|
refresh_token = '1/0/a.df219fjls0'
|
||||||
token_expiry = datetime.datetime.utcnow()
|
token_expiry = expiration or datetime.datetime.utcnow()
|
||||||
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
|
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
|
||||||
user_agent = 'refresh_checker/1.0'
|
user_agent = 'refresh_checker/1.0'
|
||||||
|
|
||||||
@@ -119,8 +118,55 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
|||||||
self.assertEquals(data['_class'], 'OAuth2Credentials')
|
self.assertEquals(data['_class'], 'OAuth2Credentials')
|
||||||
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
|
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
|
||||||
|
|
||||||
def test_token_refresh(self):
|
def test_token_refresh_store_expired(self):
|
||||||
credentials = self.create_test_credentials()
|
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
|
||||||
|
credentials = self.create_test_credentials(expiration=expiration)
|
||||||
|
|
||||||
|
s = file.Storage(FILENAME)
|
||||||
|
s.put(credentials)
|
||||||
|
credentials = s.get()
|
||||||
|
new_cred = copy.copy(credentials)
|
||||||
|
new_cred.access_token = 'bar'
|
||||||
|
s.put(new_cred)
|
||||||
|
|
||||||
|
access_token = '1/3w'
|
||||||
|
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||||
|
http = HttpMockSequence([
|
||||||
|
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
|
||||||
|
])
|
||||||
|
|
||||||
|
credentials._refresh(http.request)
|
||||||
|
self.assertEquals(credentials.access_token, access_token)
|
||||||
|
|
||||||
|
def test_token_refresh_store_expires_soon(self):
|
||||||
|
# Tests the case where an access token that is valid when it is read from
|
||||||
|
# the store expires before the original request succeeds.
|
||||||
|
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||||
|
credentials = self.create_test_credentials(expiration=expiration)
|
||||||
|
|
||||||
|
s = file.Storage(FILENAME)
|
||||||
|
s.put(credentials)
|
||||||
|
credentials = s.get()
|
||||||
|
new_cred = copy.copy(credentials)
|
||||||
|
new_cred.access_token = 'bar'
|
||||||
|
s.put(new_cred)
|
||||||
|
|
||||||
|
access_token = '1/3w'
|
||||||
|
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||||
|
http = HttpMockSequence([
|
||||||
|
({'status': '401'}, b'Initial token expired'),
|
||||||
|
({'status': '401'}, b'Store token expired'),
|
||||||
|
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
|
||||||
|
({'status': '200'}, b'Valid response to original request')
|
||||||
|
])
|
||||||
|
|
||||||
|
credentials.authorize(http)
|
||||||
|
http.request('https://example.com')
|
||||||
|
self.assertEquals(credentials.access_token, access_token)
|
||||||
|
|
||||||
|
def test_token_refresh_good_store(self):
|
||||||
|
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||||
|
credentials = self.create_test_credentials(expiration=expiration)
|
||||||
|
|
||||||
s = file.Storage(FILENAME)
|
s = file.Storage(FILENAME)
|
||||||
s.put(credentials)
|
s.put(credentials)
|
||||||
@@ -287,7 +333,6 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||||
|
|
||||||
|
|
||||||
def test_multistore_file_get_all_keys(self):
|
def test_multistore_file_get_all_keys(self):
|
||||||
# start with no keys
|
# start with no keys
|
||||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||||
|
|||||||
Reference in New Issue
Block a user