Merge pull request #371 from dhermes/factor-out-utcnow
Factor out usage of utcnow() in client.
This commit is contained in:
@@ -117,6 +117,10 @@ _GCE_METADATA_HOST = '169.254.169.254'
|
||||
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
|
||||
_DESIRED_METADATA_FLAVOR = 'Google'
|
||||
|
||||
# Expose utcnow() at module level to allow for
|
||||
# easier testing (by replacing with a stub).
|
||||
_UTCNOW = datetime.datetime.utcnow
|
||||
|
||||
|
||||
class SETTINGS(object):
|
||||
"""Settings namespace for globally defined values."""
|
||||
@@ -737,7 +741,7 @@ class OAuth2Credentials(Credentials):
|
||||
if not self.token_expiry:
|
||||
return False
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
now = _UTCNOW()
|
||||
if now >= self.token_expiry:
|
||||
logger.info('access_token is expired. Now: %s, token_expiry: %s',
|
||||
now, self.token_expiry)
|
||||
@@ -780,7 +784,7 @@ class OAuth2Credentials(Credentials):
|
||||
valid; we just don't know anything about it.
|
||||
"""
|
||||
if self.token_expiry:
|
||||
now = datetime.datetime.utcnow()
|
||||
now = _UTCNOW()
|
||||
if self.token_expiry > now:
|
||||
time_delta = self.token_expiry - now
|
||||
# TODO(orestica): return time_delta.total_seconds()
|
||||
@@ -881,8 +885,8 @@ class OAuth2Credentials(Credentials):
|
||||
self.access_token = d['access_token']
|
||||
self.refresh_token = d.get('refresh_token', self.refresh_token)
|
||||
if 'expires_in' in d:
|
||||
self.token_expiry = datetime.timedelta(
|
||||
seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
|
||||
delta = datetime.timedelta(seconds=int(d['expires_in']))
|
||||
self.token_expiry = delta + _UTCNOW()
|
||||
else:
|
||||
self.token_expiry = None
|
||||
if 'id_token' in d:
|
||||
@@ -2149,9 +2153,8 @@ class OAuth2WebServerFlow(Flow):
|
||||
"reauthenticating with approval_prompt='force'.")
|
||||
token_expiry = None
|
||||
if 'expires_in' in d:
|
||||
token_expiry = (
|
||||
datetime.datetime.utcnow() +
|
||||
datetime.timedelta(seconds=int(d['expires_in'])))
|
||||
delta = datetime.timedelta(seconds=int(d['expires_in']))
|
||||
token_expiry = delta + _UTCNOW()
|
||||
|
||||
extracted_id_token = None
|
||||
if 'id_token' in d:
|
||||
|
||||
@@ -21,12 +21,12 @@ Unit tests for oauth2client.
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
import mock
|
||||
import six
|
||||
@@ -841,11 +841,27 @@ class BasicCredentialsTests(unittest2.TestCase):
|
||||
instance = OAuth2Credentials.from_json(self.credentials.to_json())
|
||||
self.assertEqual('foobar', instance.token_response)
|
||||
|
||||
def test_get_access_token(self):
|
||||
S = 2 # number of seconds in which the token expires
|
||||
token_response_first = {'access_token': 'first_token', 'expires_in': S}
|
||||
token_response_second = {'access_token': 'second_token',
|
||||
'expires_in': S}
|
||||
@mock.patch('oauth2client.client._UTCNOW')
|
||||
def test_get_access_token(self, utcnow):
|
||||
# Configure the patch.
|
||||
seconds = 11
|
||||
NOW = datetime.datetime(1992, 12, 31, second=seconds)
|
||||
utcnow.return_value = NOW
|
||||
|
||||
lifetime = 2 # number of seconds in which the token expires
|
||||
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
|
||||
second=seconds + lifetime)
|
||||
|
||||
token1 = u'first_token'
|
||||
token_response_first = {
|
||||
'access_token': token1,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
token2 = u'second_token'
|
||||
token_response_second = {
|
||||
'access_token': token2,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, json.dumps(token_response_first).encode(
|
||||
'utf-8')),
|
||||
@@ -853,27 +869,61 @@ class BasicCredentialsTests(unittest2.TestCase):
|
||||
'utf-8')),
|
||||
])
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
# Use the current credentials but unset the expiry and
|
||||
# the access token.
|
||||
credentials = copy.deepcopy(self.credentials)
|
||||
credentials.access_token = None
|
||||
credentials.token_expiry = None
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
# Get Access Token, First attempt.
|
||||
self.assertEqual(credentials.access_token, None)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
self.assertEqual(credentials.token_expiry, None)
|
||||
token = credentials.get_access_token(http=http)
|
||||
self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, credentials.token_response)
|
||||
# Two utcnow calls are expected:
|
||||
# - get_access_token() -> _do_refresh_request (setting expires in)
|
||||
# - get_access_token() -> _expires_in()
|
||||
expected_utcnow_calls = [mock.call()] * 2
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
|
||||
time.sleep(S + 0.5) # some margin to avoid flakiness
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
# Get Access Token, Second Attempt (not expired)
|
||||
self.assertEqual(credentials.access_token, token1)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
token = credentials.get_access_token(http=http)
|
||||
# Make sure no refresh occurred since the token was not expired.
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, credentials.token_response)
|
||||
# Three more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token() -> access_token_expired
|
||||
# - get_access_token -> _expires_in
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('second_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
# Get Access Token, Third Attempt (force expiration)
|
||||
self.assertEqual(credentials.access_token, token1)
|
||||
credentials.token_expiry = NOW # Manually force expiry.
|
||||
self.assertTrue(credentials.access_token_expired)
|
||||
token = credentials.get_access_token(http=http)
|
||||
# Make sure refresh occurred since the token was not expired.
|
||||
self.assertEqual(token2, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second,
|
||||
self.credentials.token_response)
|
||||
credentials.token_response)
|
||||
# Five more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token -> access_token_expired
|
||||
# - get_access_token -> _do_refresh_request
|
||||
# - get_access_token -> _expires_in
|
||||
# - access_token_expired
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
|
||||
def test_has_scopes(self):
|
||||
self.assertTrue(self.credentials.has_scopes('foo'))
|
||||
|
||||
@@ -17,12 +17,14 @@
|
||||
Unit tests for service account credentials implemented using RSA.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import rsa
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
|
||||
from .http_mock import HttpMockSequence
|
||||
from oauth2client.service_account import _ServiceAccountCredentials
|
||||
|
||||
@@ -88,11 +90,28 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
_ServiceAccountCredentials))
|
||||
self.assertEqual('dummy_scope', new_credentials._scopes)
|
||||
|
||||
def test_access_token(self):
|
||||
S = 2 # number of seconds in which the token expires
|
||||
token_response_first = {'access_token': 'first_token', 'expires_in': S}
|
||||
token_response_second = {'access_token': 'second_token',
|
||||
'expires_in': S}
|
||||
@mock.patch('oauth2client.client._UTCNOW')
|
||||
@mock.patch('rsa.pkcs1.sign', return_value=b'signed-value')
|
||||
def test_access_token(self, sign_func, utcnow):
|
||||
# Configure the patch.
|
||||
seconds = 11
|
||||
NOW = datetime.datetime(1992, 12, 31, second=seconds)
|
||||
utcnow.return_value = NOW
|
||||
|
||||
lifetime = 2 # number of seconds in which the token expires
|
||||
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
|
||||
second=seconds + lifetime)
|
||||
|
||||
token1 = u'first_token'
|
||||
token_response_first = {
|
||||
'access_token': token1,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
token2 = u'second_token'
|
||||
token_response_second = {
|
||||
'access_token': token2,
|
||||
'expires_in': lifetime,
|
||||
}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'},
|
||||
json.dumps(token_response_first).encode('utf-8')),
|
||||
@@ -100,27 +119,64 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
json.dumps(token_response_second).encode('utf-8')),
|
||||
])
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
# Get Access Token, First attempt.
|
||||
self.assertEqual(self.credentials.access_token, None)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
|
||||
self.assertEqual(self.credentials.token_expiry, None)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME)
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first,
|
||||
self.credentials.token_response)
|
||||
# Two utcnow calls are expected:
|
||||
# - get_access_token() -> _do_refresh_request (setting expires in)
|
||||
# - get_access_token() -> _expires_in()
|
||||
expected_utcnow_calls = [mock.call()] * 2
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# One rsa.pkcs1.sign expected: Actual refresh was needed.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1)
|
||||
|
||||
time.sleep(S + 0.5) # some margin to avoid flakiness
|
||||
# Get Access Token, Second Attempt (not expired)
|
||||
self.assertEqual(self.credentials.access_token, token1)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
# Make sure no refresh occurred since the token was not expired.
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
# Three more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token() -> access_token_expired
|
||||
# - get_access_token -> _expires_in
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# No rsa.pkcs1.sign expected: the token was not expired.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1 + 0)
|
||||
|
||||
# Get Access Token, Third Attempt (force expiration)
|
||||
self.assertEqual(self.credentials.access_token, token1)
|
||||
self.credentials.token_expiry = NOW # Manually force expiry.
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('second_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
# Make sure refresh occurred since the token was not expired.
|
||||
self.assertEqual(token2, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second,
|
||||
self.credentials.token_response)
|
||||
# Five more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token -> access_token_expired
|
||||
# - get_access_token -> _do_refresh_request
|
||||
# - get_access_token -> _expires_in
|
||||
# - access_token_expired
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# One more rsa.pkcs1.sign expected: Actual refresh was needed.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1)
|
||||
|
||||
self.assertEqual(self.credentials.access_token, token2)
|
||||
|
||||
|
||||
if __name__ == '__main__': # pragma: NO COVER
|
||||
|
||||
Reference in New Issue
Block a user