Merge pull request #371 from dhermes/factor-out-utcnow

Factor out usage of utcnow() in client.
This commit is contained in:
Danny Hermes
2016-01-05 13:51:58 -08:00
3 changed files with 158 additions and 49 deletions

View File

@@ -117,6 +117,10 @@ _GCE_METADATA_HOST = '169.254.169.254'
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
_DESIRED_METADATA_FLAVOR = 'Google'
# Expose utcnow() at module level to allow for
# easier testing (by replacing with a stub).
_UTCNOW = datetime.datetime.utcnow
class SETTINGS(object):
"""Settings namespace for globally defined values."""
@@ -737,7 +741,7 @@ class OAuth2Credentials(Credentials):
if not self.token_expiry:
return False
now = datetime.datetime.utcnow()
now = _UTCNOW()
if now >= self.token_expiry:
logger.info('access_token is expired. Now: %s, token_expiry: %s',
now, self.token_expiry)
@@ -780,7 +784,7 @@ class OAuth2Credentials(Credentials):
valid; we just don't know anything about it.
"""
if self.token_expiry:
now = datetime.datetime.utcnow()
now = _UTCNOW()
if self.token_expiry > now:
time_delta = self.token_expiry - now
# TODO(orestica): return time_delta.total_seconds()
@@ -881,8 +885,8 @@ class OAuth2Credentials(Credentials):
self.access_token = d['access_token']
self.refresh_token = d.get('refresh_token', self.refresh_token)
if 'expires_in' in d:
self.token_expiry = datetime.timedelta(
seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
delta = datetime.timedelta(seconds=int(d['expires_in']))
self.token_expiry = delta + _UTCNOW()
else:
self.token_expiry = None
if 'id_token' in d:
@@ -2149,9 +2153,8 @@ class OAuth2WebServerFlow(Flow):
"reauthenticating with approval_prompt='force'.")
token_expiry = None
if 'expires_in' in d:
token_expiry = (
datetime.datetime.utcnow() +
datetime.timedelta(seconds=int(d['expires_in'])))
delta = datetime.timedelta(seconds=int(d['expires_in']))
token_expiry = delta + _UTCNOW()
extracted_id_token = None
if 'id_token' in d:

View File

@@ -21,12 +21,12 @@ Unit tests for oauth2client.
import base64
import contextlib
import copy
import datetime
import json
import os
import socket
import sys
import time
import mock
import six
@@ -841,11 +841,27 @@ class BasicCredentialsTests(unittest2.TestCase):
instance = OAuth2Credentials.from_json(self.credentials.to_json())
self.assertEqual('foobar', instance.token_response)
def test_get_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token',
'expires_in': S}
@mock.patch('oauth2client.client._UTCNOW')
def test_get_access_token(self, utcnow):
# Configure the patch.
seconds = 11
NOW = datetime.datetime(1992, 12, 31, second=seconds)
utcnow.return_value = NOW
lifetime = 2 # number of seconds in which the token expires
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
second=seconds + lifetime)
token1 = u'first_token'
token_response_first = {
'access_token': token1,
'expires_in': lifetime,
}
token2 = u'second_token'
token_response_second = {
'access_token': token2,
'expires_in': lifetime,
}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode(
'utf-8')),
@@ -853,27 +869,61 @@ class BasicCredentialsTests(unittest2.TestCase):
'utf-8')),
])
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
# Use the current credentials but unset the expiry and
# the access token.
credentials = copy.deepcopy(self.credentials)
credentials.access_token = None
credentials.token_expiry = None
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
# Get Access Token, First attempt.
self.assertEqual(credentials.access_token, None)
self.assertFalse(credentials.access_token_expired)
self.assertEqual(credentials.token_expiry, None)
token = credentials.get_access_token(http=http)
self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, credentials.token_response)
# Two utcnow calls are expected:
# - get_access_token() -> _do_refresh_request (setting expires in)
# - get_access_token() -> _expires_in()
expected_utcnow_calls = [mock.call()] * 2
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
time.sleep(S + 0.5) # some margin to avoid flakiness
self.assertTrue(self.credentials.access_token_expired)
# Get Access Token, Second Attempt (not expired)
self.assertEqual(credentials.access_token, token1)
self.assertFalse(credentials.access_token_expired)
token = credentials.get_access_token(http=http)
# Make sure no refresh occurred since the token was not expired.
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, credentials.token_response)
# Three more utcnow calls are expected:
# - access_token_expired
# - get_access_token() -> access_token_expired
# - get_access_token -> _expires_in
expected_utcnow_calls = [mock.call()] * (2 + 3)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
# Get Access Token, Third Attempt (force expiration)
self.assertEqual(credentials.access_token, token1)
credentials.token_expiry = NOW # Manually force expiry.
self.assertTrue(credentials.access_token_expired)
token = credentials.get_access_token(http=http)
# Make sure refresh occurred since the token was not expired.
self.assertEqual(token2, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertFalse(credentials.access_token_expired)
self.assertEqual(token_response_second,
self.credentials.token_response)
credentials.token_response)
# Five more utcnow calls are expected:
# - access_token_expired
# - get_access_token -> access_token_expired
# - get_access_token -> _do_refresh_request
# - get_access_token -> _expires_in
# - access_token_expired
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
def test_has_scopes(self):
self.assertTrue(self.credentials.has_scopes('foo'))

View File

@@ -17,12 +17,14 @@
Unit tests for service account credentials implemented using RSA.
"""
import datetime
import json
import os
import rsa
import time
import unittest
import mock
from .http_mock import HttpMockSequence
from oauth2client.service_account import _ServiceAccountCredentials
@@ -88,11 +90,28 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
_ServiceAccountCredentials))
self.assertEqual('dummy_scope', new_credentials._scopes)
def test_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token',
'expires_in': S}
@mock.patch('oauth2client.client._UTCNOW')
@mock.patch('rsa.pkcs1.sign', return_value=b'signed-value')
def test_access_token(self, sign_func, utcnow):
# Configure the patch.
seconds = 11
NOW = datetime.datetime(1992, 12, 31, second=seconds)
utcnow.return_value = NOW
lifetime = 2 # number of seconds in which the token expires
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
second=seconds + lifetime)
token1 = u'first_token'
token_response_first = {
'access_token': token1,
'expires_in': lifetime,
}
token2 = u'second_token'
token_response_second = {
'access_token': token2,
'expires_in': lifetime,
}
http = HttpMockSequence([
({'status': '200'},
json.dumps(token_response_first).encode('utf-8')),
@@ -100,27 +119,64 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
json.dumps(token_response_second).encode('utf-8')),
])
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
# Get Access Token, First attempt.
self.assertEqual(self.credentials.access_token, None)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual(self.credentials.token_expiry, None)
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME)
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first,
self.credentials.token_response)
# Two utcnow calls are expected:
# - get_access_token() -> _do_refresh_request (setting expires in)
# - get_access_token() -> _expires_in()
expected_utcnow_calls = [mock.call()] * 2
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# One rsa.pkcs1.sign expected: Actual refresh was needed.
self.assertEqual(len(sign_func.mock_calls), 1)
time.sleep(S + 0.5) # some margin to avoid flakiness
# Get Access Token, Second Attempt (not expired)
self.assertEqual(self.credentials.access_token, token1)
self.assertFalse(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http)
# Make sure no refresh occurred since the token was not expired.
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, self.credentials.token_response)
# Three more utcnow calls are expected:
# - access_token_expired
# - get_access_token() -> access_token_expired
# - get_access_token -> _expires_in
expected_utcnow_calls = [mock.call()] * (2 + 3)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# No rsa.pkcs1.sign expected: the token was not expired.
self.assertEqual(len(sign_func.mock_calls), 1 + 0)
# Get Access Token, Third Attempt (force expiration)
self.assertEqual(self.credentials.access_token, token1)
self.credentials.token_expiry = NOW # Manually force expiry.
self.assertTrue(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
# Make sure refresh occurred since the token was not expired.
self.assertEqual(token2, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second,
self.credentials.token_response)
# Five more utcnow calls are expected:
# - access_token_expired
# - get_access_token -> access_token_expired
# - get_access_token -> _do_refresh_request
# - get_access_token -> _expires_in
# - access_token_expired
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# One more rsa.pkcs1.sign expected: Actual refresh was needed.
self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1)
self.assertEqual(self.credentials.access_token, token2)
if __name__ == '__main__': # pragma: NO COVER