Factor out usage of utcnow() in client.
This is to enable better stubs in testing and eliminate two sleep() statements in unit tests. (The philosophy is "unit tests should be fast".)
This commit is contained in:
@@ -117,6 +117,10 @@ _GCE_METADATA_HOST = '169.254.169.254'
|
||||
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
|
||||
_DESIRED_METADATA_FLAVOR = 'Google'
|
||||
|
||||
# Expose utcnow() at module level to allow for
|
||||
# easier testing (by replacing with a stub).
|
||||
_UTCNOW = datetime.datetime.utcnow
|
||||
|
||||
|
||||
class SETTINGS(object):
|
||||
"""Settings namespace for globally defined values."""
|
||||
@@ -737,7 +741,7 @@ class OAuth2Credentials(Credentials):
|
||||
if not self.token_expiry:
|
||||
return False
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
now = _UTCNOW()
|
||||
if now >= self.token_expiry:
|
||||
logger.info('access_token is expired. Now: %s, token_expiry: %s',
|
||||
now, self.token_expiry)
|
||||
@@ -780,7 +784,7 @@ class OAuth2Credentials(Credentials):
|
||||
valid; we just don't know anything about it.
|
||||
"""
|
||||
if self.token_expiry:
|
||||
now = datetime.datetime.utcnow()
|
||||
now = _UTCNOW()
|
||||
if self.token_expiry > now:
|
||||
time_delta = self.token_expiry - now
|
||||
# TODO(orestica): return time_delta.total_seconds()
|
||||
@@ -881,8 +885,8 @@ class OAuth2Credentials(Credentials):
|
||||
self.access_token = d['access_token']
|
||||
self.refresh_token = d.get('refresh_token', self.refresh_token)
|
||||
if 'expires_in' in d:
|
||||
self.token_expiry = datetime.timedelta(
|
||||
seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
|
||||
delta = datetime.timedelta(seconds=int(d['expires_in']))
|
||||
self.token_expiry = delta + _UTCNOW()
|
||||
else:
|
||||
self.token_expiry = None
|
||||
if 'id_token' in d:
|
||||
@@ -2149,9 +2153,8 @@ class OAuth2WebServerFlow(Flow):
|
||||
"reauthenticating with approval_prompt='force'.")
|
||||
token_expiry = None
|
||||
if 'expires_in' in d:
|
||||
token_expiry = (
|
||||
datetime.datetime.utcnow() +
|
||||
datetime.timedelta(seconds=int(d['expires_in'])))
|
||||
delta = datetime.timedelta(seconds=int(d['expires_in']))
|
||||
token_expiry = delta + _UTCNOW()
|
||||
|
||||
extracted_id_token = None
|
||||
if 'id_token' in d:
|
||||
|
||||
@@ -21,12 +21,12 @@ Unit tests for oauth2client.
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
import mock
|
||||
import six
|
||||
@@ -841,11 +841,27 @@ class BasicCredentialsTests(unittest2.TestCase):
|
||||
instance = OAuth2Credentials.from_json(self.credentials.to_json())
|
||||
self.assertEqual('foobar', instance.token_response)
|
||||
|
||||
def test_get_access_token(self):
|
||||
S = 2 # number of seconds in which the token expires
|
||||
token_response_first = {'access_token': 'first_token', 'expires_in': S}
|
||||
token_response_second = {'access_token': 'second_token',
|
||||
'expires_in': S}
|
||||
@mock.patch('oauth2client.client._UTCNOW')
|
||||
def test_get_access_token(self, utcnow):
|
||||
# Configure the patch.
|
||||
seconds = 11
|
||||
NOW = datetime.datetime(1992, 12, 31, second=seconds)
|
||||
utcnow.return_value = NOW
|
||||
|
||||
lifetime = 2 # number of seconds in which the token expires
|
||||
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
|
||||
second=seconds + lifetime)
|
||||
|
||||
token1 = u'first_token'
|
||||
token_response_first = {
|
||||
'access_token': token1,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
token2 = u'second_token'
|
||||
token_response_second = {
|
||||
'access_token': token2,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, json.dumps(token_response_first).encode(
|
||||
'utf-8')),
|
||||
@@ -853,27 +869,61 @@ class BasicCredentialsTests(unittest2.TestCase):
|
||||
'utf-8')),
|
||||
])
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
# Use the current credentials but unset the expiry and
|
||||
# the access token.
|
||||
credentials = copy.deepcopy(self.credentials)
|
||||
credentials.access_token = None
|
||||
credentials.token_expiry = None
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
# Get Access Token, First attempt.
|
||||
self.assertEqual(credentials.access_token, None)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
self.assertEqual(credentials.token_expiry, None)
|
||||
token = credentials.get_access_token(http=http)
|
||||
self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, credentials.token_response)
|
||||
# Two utcnow calls are expected:
|
||||
# - get_access_token() -> _do_refresh_request (setting expires in)
|
||||
# - get_access_token() -> _expires_in()
|
||||
expected_utcnow_calls = [mock.call()] * 2
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
|
||||
time.sleep(S + 0.5) # some margin to avoid flakiness
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
# Get Access Token, Second Attempt (not expired)
|
||||
self.assertEqual(credentials.access_token, token1)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
token = credentials.get_access_token(http=http)
|
||||
# Make sure no refresh occurred since the token was not expired.
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, credentials.token_response)
|
||||
# Three more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token() -> access_token_expired
|
||||
# - get_access_token -> _expires_in
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('second_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
# Get Access Token, Third Attempt (force expiration)
|
||||
self.assertEqual(credentials.access_token, token1)
|
||||
credentials.token_expiry = NOW # Manually force expiry.
|
||||
self.assertTrue(credentials.access_token_expired)
|
||||
token = credentials.get_access_token(http=http)
|
||||
# Make sure refresh occurred since the token was not expired.
|
||||
self.assertEqual(token2, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second,
|
||||
self.credentials.token_response)
|
||||
credentials.token_response)
|
||||
# Five more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token -> access_token_expired
|
||||
# - get_access_token -> _do_refresh_request
|
||||
# - get_access_token -> _expires_in
|
||||
# - access_token_expired
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
|
||||
def test_has_scopes(self):
|
||||
self.assertTrue(self.credentials.has_scopes('foo'))
|
||||
|
||||
@@ -17,12 +17,14 @@
|
||||
Unit tests for service account credentials implemented using RSA.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import rsa
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
|
||||
from .http_mock import HttpMockSequence
|
||||
from oauth2client.service_account import _ServiceAccountCredentials
|
||||
|
||||
@@ -88,11 +90,28 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
_ServiceAccountCredentials))
|
||||
self.assertEqual('dummy_scope', new_credentials._scopes)
|
||||
|
||||
def test_access_token(self):
|
||||
S = 2 # number of seconds in which the token expires
|
||||
token_response_first = {'access_token': 'first_token', 'expires_in': S}
|
||||
token_response_second = {'access_token': 'second_token',
|
||||
'expires_in': S}
|
||||
@mock.patch('oauth2client.client._UTCNOW')
|
||||
@mock.patch('rsa.pkcs1.sign', return_value=b'signed-value')
|
||||
def test_access_token(self, sign_func, utcnow):
|
||||
# Configure the patch.
|
||||
seconds = 11
|
||||
NOW = datetime.datetime(1992, 12, 31, second=seconds)
|
||||
utcnow.return_value = NOW
|
||||
|
||||
lifetime = 2 # number of seconds in which the token expires
|
||||
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
|
||||
second=seconds + lifetime)
|
||||
|
||||
token1 = u'first_token'
|
||||
token_response_first = {
|
||||
'access_token': token1,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
token2 = u'second_token'
|
||||
token_response_second = {
|
||||
'access_token': token2,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'},
|
||||
json.dumps(token_response_first).encode('utf-8')),
|
||||
@@ -100,27 +119,64 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
json.dumps(token_response_second).encode('utf-8')),
|
||||
])
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
# Get Access Token, First attempt.
|
||||
self.assertEqual(self.credentials.access_token, None)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
|
||||
self.assertEqual(self.credentials.token_expiry, None)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME)
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first,
|
||||
self.credentials.token_response)
|
||||
# Two utcnow calls are expected:
|
||||
# - get_access_token() -> _do_refresh_request (setting expires in)
|
||||
# - get_access_token() -> _expires_in()
|
||||
expected_utcnow_calls = [mock.call()] * 2
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# One rsa.pkcs1.sign expected: Actual refresh was needed.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1)
|
||||
|
||||
time.sleep(S + 0.5) # some margin to avoid flakiness
|
||||
# Get Access Token, Second Attempt (not expired)
|
||||
self.assertEqual(self.credentials.access_token, token1)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
# Make sure no refresh occurred since the token was not expired.
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
# Three more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token() -> access_token_expired
|
||||
# - get_access_token -> _expires_in
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# No rsa.pkcs1.sign expected: the token was not expired.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1 + 0)
|
||||
|
||||
# Get Access Token, Third Attempt (force expiration)
|
||||
self.assertEqual(self.credentials.access_token, token1)
|
||||
self.credentials.token_expiry = NOW # Manually force expiry.
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('second_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
# Make sure refresh occurred since the token was not expired.
|
||||
self.assertEqual(token2, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second,
|
||||
self.credentials.token_response)
|
||||
# Five more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token -> access_token_expired
|
||||
# - get_access_token -> _do_refresh_request
|
||||
# - get_access_token -> _expires_in
|
||||
# - access_token_expired
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# One more rsa.pkcs1.sign expected: Actual refresh was needed.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1)
|
||||
|
||||
self.assertEqual(self.credentials.access_token, token2)
|
||||
|
||||
|
||||
if __name__ == '__main__': # pragma: NO COVER
|
||||
|
||||
Reference in New Issue
Block a user