Factor out usage of utcnow() in client.

This is to enable better stubs in testing and eliminate
two sleep() statements in unit tests. (The philosophy
is "unit tests should be fast".)
This commit is contained in:
Danny Hermes
2016-01-05 00:02:20 -08:00
parent 24214201c6
commit 30c342a437
3 changed files with 158 additions and 49 deletions

View File

@@ -117,6 +117,10 @@ _GCE_METADATA_HOST = '169.254.169.254'
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
_DESIRED_METADATA_FLAVOR = 'Google'
# Expose utcnow() at module level to allow for
# easier testing (by replacing with a stub).
_UTCNOW = datetime.datetime.utcnow
class SETTINGS(object):
"""Settings namespace for globally defined values."""
@@ -737,7 +741,7 @@ class OAuth2Credentials(Credentials):
if not self.token_expiry:
return False
now = datetime.datetime.utcnow()
now = _UTCNOW()
if now >= self.token_expiry:
logger.info('access_token is expired. Now: %s, token_expiry: %s',
now, self.token_expiry)
@@ -780,7 +784,7 @@ class OAuth2Credentials(Credentials):
valid; we just don't know anything about it.
"""
if self.token_expiry:
now = datetime.datetime.utcnow()
now = _UTCNOW()
if self.token_expiry > now:
time_delta = self.token_expiry - now
# TODO(orestica): return time_delta.total_seconds()
@@ -881,8 +885,8 @@ class OAuth2Credentials(Credentials):
self.access_token = d['access_token']
self.refresh_token = d.get('refresh_token', self.refresh_token)
if 'expires_in' in d:
self.token_expiry = datetime.timedelta(
seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
delta = datetime.timedelta(seconds=int(d['expires_in']))
self.token_expiry = delta + _UTCNOW()
else:
self.token_expiry = None
if 'id_token' in d:
@@ -2149,9 +2153,8 @@ class OAuth2WebServerFlow(Flow):
"reauthenticating with approval_prompt='force'.")
token_expiry = None
if 'expires_in' in d:
token_expiry = (
datetime.datetime.utcnow() +
datetime.timedelta(seconds=int(d['expires_in'])))
delta = datetime.timedelta(seconds=int(d['expires_in']))
token_expiry = delta + _UTCNOW()
extracted_id_token = None
if 'id_token' in d:

View File

@@ -21,12 +21,12 @@ Unit tests for oauth2client.
import base64
import contextlib
import copy
import datetime
import json
import os
import socket
import sys
import time
import mock
import six
@@ -841,11 +841,27 @@ class BasicCredentialsTests(unittest2.TestCase):
instance = OAuth2Credentials.from_json(self.credentials.to_json())
self.assertEqual('foobar', instance.token_response)
def test_get_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token',
'expires_in': S}
@mock.patch('oauth2client.client._UTCNOW')
def test_get_access_token(self, utcnow):
# Configure the patch.
seconds = 11
NOW = datetime.datetime(1992, 12, 31, second=seconds)
utcnow.return_value = NOW
lifetime = 2 # number of seconds in which the token expires
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
second=seconds + lifetime)
token1 = u'first_token'
token_response_first = {
'access_token': token1,
'expires_in': lifetime,
}
token2 = u'second_token'
token_response_second = {
'access_token': token2,
'expires_in': lifetime,
}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode(
'utf-8')),
@@ -853,27 +869,61 @@ class BasicCredentialsTests(unittest2.TestCase):
'utf-8')),
])
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
# Use the current credentials but unset the expiry and
# the access token.
credentials = copy.deepcopy(self.credentials)
credentials.access_token = None
credentials.token_expiry = None
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
# Get Access Token, First attempt.
self.assertEqual(credentials.access_token, None)
self.assertFalse(credentials.access_token_expired)
self.assertEqual(credentials.token_expiry, None)
token = credentials.get_access_token(http=http)
self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, credentials.token_response)
# Two utcnow calls are expected:
# - get_access_token() -> _do_refresh_request (setting expires in)
# - get_access_token() -> _expires_in()
expected_utcnow_calls = [mock.call()] * 2
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
time.sleep(S + 0.5) # some margin to avoid flakiness
self.assertTrue(self.credentials.access_token_expired)
# Get Access Token, Second Attempt (not expired)
self.assertEqual(credentials.access_token, token1)
self.assertFalse(credentials.access_token_expired)
token = credentials.get_access_token(http=http)
# Make sure no refresh occurred since the token was not expired.
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, credentials.token_response)
# Three more utcnow calls are expected:
# - access_token_expired
# - get_access_token() -> access_token_expired
# - get_access_token -> _expires_in
expected_utcnow_calls = [mock.call()] * (2 + 3)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
# Get Access Token, Third Attempt (force expiration)
self.assertEqual(credentials.access_token, token1)
credentials.token_expiry = NOW # Manually force expiry.
self.assertTrue(credentials.access_token_expired)
token = credentials.get_access_token(http=http)
# Make sure refresh occurred since the token was not expired.
self.assertEqual(token2, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertFalse(credentials.access_token_expired)
self.assertEqual(token_response_second,
self.credentials.token_response)
credentials.token_response)
# Five more utcnow calls are expected:
# - access_token_expired
# - get_access_token -> access_token_expired
# - get_access_token -> _do_refresh_request
# - get_access_token -> _expires_in
# - access_token_expired
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
def test_has_scopes(self):
self.assertTrue(self.credentials.has_scopes('foo'))

View File

@@ -17,12 +17,14 @@
Unit tests for service account credentials implemented using RSA.
"""
import datetime
import json
import os
import rsa
import time
import unittest
import mock
from .http_mock import HttpMockSequence
from oauth2client.service_account import _ServiceAccountCredentials
@@ -88,11 +90,28 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
_ServiceAccountCredentials))
self.assertEqual('dummy_scope', new_credentials._scopes)
def test_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token',
'expires_in': S}
@mock.patch('oauth2client.client._UTCNOW')
@mock.patch('rsa.pkcs1.sign', return_value=b'signed-value')
def test_access_token(self, sign_func, utcnow):
# Configure the patch.
seconds = 11
NOW = datetime.datetime(1992, 12, 31, second=seconds)
utcnow.return_value = NOW
lifetime = 2 # number of seconds in which the token expires
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
second=seconds + lifetime)
token1 = u'first_token'
token_response_first = {
'access_token': token1,
'expires_in': lifetime,
}
token2 = u'second_token'
token_response_second = {
'access_token': token2,
'expires_in': lifetime,
}
http = HttpMockSequence([
({'status': '200'},
json.dumps(token_response_first).encode('utf-8')),
@@ -100,27 +119,64 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
json.dumps(token_response_second).encode('utf-8')),
])
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
# Get Access Token, First attempt.
self.assertEqual(self.credentials.access_token, None)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual(self.credentials.token_expiry, None)
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME)
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first,
self.credentials.token_response)
# Two utcnow calls are expected:
# - get_access_token() -> _do_refresh_request (setting expires in)
# - get_access_token() -> _expires_in()
expected_utcnow_calls = [mock.call()] * 2
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# One rsa.pkcs1.sign expected: Actual refresh was needed.
self.assertEqual(len(sign_func.mock_calls), 1)
time.sleep(S + 0.5) # some margin to avoid flakiness
# Get Access Token, Second Attempt (not expired)
self.assertEqual(self.credentials.access_token, token1)
self.assertFalse(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http)
# Make sure no refresh occurred since the token was not expired.
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, self.credentials.token_response)
# Three more utcnow calls are expected:
# - access_token_expired
# - get_access_token() -> access_token_expired
# - get_access_token -> _expires_in
expected_utcnow_calls = [mock.call()] * (2 + 3)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# No rsa.pkcs1.sign expected: the token was not expired.
self.assertEqual(len(sign_func.mock_calls), 1 + 0)
# Get Access Token, Third Attempt (force expiration)
self.assertEqual(self.credentials.access_token, token1)
self.credentials.token_expiry = NOW # Manually force expiry.
self.assertTrue(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
# Make sure refresh occurred since the token was not expired.
self.assertEqual(token2, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second,
self.credentials.token_response)
# Five more utcnow calls are expected:
# - access_token_expired
# - get_access_token -> access_token_expired
# - get_access_token -> _do_refresh_request
# - get_access_token -> _expires_in
# - access_token_expired
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# One more rsa.pkcs1.sign expected: Actual refresh was needed.
self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1)
self.assertEqual(self.credentials.access_token, token2)
if __name__ == '__main__': # pragma: NO COVER