diff --git a/oauth2client/client.py b/oauth2client/client.py index 20a59c4..155d321 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -117,6 +117,10 @@ _GCE_METADATA_HOST = '169.254.169.254' _METADATA_FLAVOR_HEADER = 'Metadata-Flavor' _DESIRED_METADATA_FLAVOR = 'Google' +# Expose utcnow() at module level to allow for +# easier testing (by replacing with a stub). +_UTCNOW = datetime.datetime.utcnow + class SETTINGS(object): """Settings namespace for globally defined values.""" @@ -737,7 +741,7 @@ class OAuth2Credentials(Credentials): if not self.token_expiry: return False - now = datetime.datetime.utcnow() + now = _UTCNOW() if now >= self.token_expiry: logger.info('access_token is expired. Now: %s, token_expiry: %s', now, self.token_expiry) @@ -780,7 +784,7 @@ class OAuth2Credentials(Credentials): valid; we just don't know anything about it. """ if self.token_expiry: - now = datetime.datetime.utcnow() + now = _UTCNOW() if self.token_expiry > now: time_delta = self.token_expiry - now # TODO(orestica): return time_delta.total_seconds() @@ -881,8 +885,8 @@ class OAuth2Credentials(Credentials): self.access_token = d['access_token'] self.refresh_token = d.get('refresh_token', self.refresh_token) if 'expires_in' in d: - self.token_expiry = datetime.timedelta( - seconds=int(d['expires_in'])) + datetime.datetime.utcnow() + delta = datetime.timedelta(seconds=int(d['expires_in'])) + self.token_expiry = delta + _UTCNOW() else: self.token_expiry = None if 'id_token' in d: @@ -2149,9 +2153,8 @@ class OAuth2WebServerFlow(Flow): "reauthenticating with approval_prompt='force'.") token_expiry = None if 'expires_in' in d: - token_expiry = ( - datetime.datetime.utcnow() + - datetime.timedelta(seconds=int(d['expires_in']))) + delta = datetime.timedelta(seconds=int(d['expires_in'])) + token_expiry = delta + _UTCNOW() extracted_id_token = None if 'id_token' in d: diff --git a/tests/test_client.py b/tests/test_client.py index c3cbdab..ef9d633 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,12 +21,12 @@ Unit tests for oauth2client. import base64 import contextlib +import copy import datetime import json import os import socket import sys -import time import mock import six @@ -841,11 +841,27 @@ class BasicCredentialsTests(unittest2.TestCase): instance = OAuth2Credentials.from_json(self.credentials.to_json()) self.assertEqual('foobar', instance.token_response) - def test_get_access_token(self): - S = 2 # number of seconds in which the token expires - token_response_first = {'access_token': 'first_token', 'expires_in': S} - token_response_second = {'access_token': 'second_token', - 'expires_in': S} + @mock.patch('oauth2client.client._UTCNOW') + def test_get_access_token(self, utcnow): + # Configure the patch. + seconds = 11 + NOW = datetime.datetime(1992, 12, 31, second=seconds) + utcnow.return_value = NOW + + lifetime = 2 # number of seconds in which the token expires + EXPIRY_TIME = datetime.datetime(1992, 12, 31, + second=seconds + lifetime) + + token1 = u'first_token' + token_response_first = { + 'access_token': token1, + 'expires_in': lifetime, + } + token2 = u'second_token' + token_response_second = { + 'access_token': token2, + 'expires_in': lifetime, + } http = HttpMockSequence([ ({'status': '200'}, json.dumps(token_response_first).encode( 'utf-8')), @@ -853,27 +869,61 @@ class BasicCredentialsTests(unittest2.TestCase): 'utf-8')), ]) - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) + # Use the current credentials but unset the expiry and + # the access token. + credentials = copy.deepcopy(self.credentials) + credentials.access_token = None + credentials.token_expiry = None - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) + # Get Access Token, First attempt. + self.assertEqual(credentials.access_token, None) + self.assertFalse(credentials.access_token_expired) + self.assertEqual(credentials.token_expiry, None) + token = credentials.get_access_token(http=http) + self.assertEqual(credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, credentials.token_response) + # Two utcnow calls are expected: + # - get_access_token() -> _do_refresh_request (setting expires in) + # - get_access_token() -> _expires_in() + expected_utcnow_calls = [mock.call()] * 2 + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - time.sleep(S + 0.5) # some margin to avoid flakiness - self.assertTrue(self.credentials.access_token_expired) + # Get Access Token, Second Attempt (not expired) + self.assertEqual(credentials.access_token, token1) + self.assertFalse(credentials.access_token_expired) + token = credentials.get_access_token(http=http) + # Make sure no refresh occurred since the token was not expired. + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, credentials.token_response) + # Three more utcnow calls are expected: + # - access_token_expired + # - get_access_token() -> access_token_expired + # - get_access_token -> _expires_in + expected_utcnow_calls = [mock.call()] * (2 + 3) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - token = self.credentials.get_access_token(http=http) - self.assertEqual('second_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) + # Get Access Token, Third Attempt (force expiration) + self.assertEqual(credentials.access_token, token1) + credentials.token_expiry = NOW # Manually force expiry. + self.assertTrue(credentials.access_token_expired) + token = credentials.get_access_token(http=http) + # Make sure refresh occurred since the token was not expired. + self.assertEqual(token2, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertFalse(credentials.access_token_expired) self.assertEqual(token_response_second, - self.credentials.token_response) + credentials.token_response) + # Five more utcnow calls are expected: + # - access_token_expired + # - get_access_token -> access_token_expired + # - get_access_token -> _do_refresh_request + # - get_access_token -> _expires_in + # - access_token_expired + expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) def test_has_scopes(self): self.assertTrue(self.credentials.has_scopes('foo')) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 1ba0cae..09d6234 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -17,12 +17,14 @@ Unit tests for service account credentials implemented using RSA. """ +import datetime import json import os import rsa -import time import unittest +import mock + from .http_mock import HttpMockSequence from oauth2client.service_account import _ServiceAccountCredentials @@ -88,11 +90,28 @@ class ServiceAccountCredentialsTests(unittest.TestCase): _ServiceAccountCredentials)) self.assertEqual('dummy_scope', new_credentials._scopes) - def test_access_token(self): - S = 2 # number of seconds in which the token expires - token_response_first = {'access_token': 'first_token', 'expires_in': S} - token_response_second = {'access_token': 'second_token', - 'expires_in': S} + @mock.patch('oauth2client.client._UTCNOW') + @mock.patch('rsa.pkcs1.sign', return_value=b'signed-value') + def test_access_token(self, sign_func, utcnow): + # Configure the patch. + seconds = 11 + NOW = datetime.datetime(1992, 12, 31, second=seconds) + utcnow.return_value = NOW + + lifetime = 2 # number of seconds in which the token expires + EXPIRY_TIME = datetime.datetime(1992, 12, 31, + second=seconds + lifetime) + + token1 = u'first_token' + token_response_first = { + 'access_token': token1, + 'expires_in': lifetime, + } + token2 = u'second_token' + token_response_second = { + 'access_token': token2, + 'expires_in': lifetime, + } http = HttpMockSequence([ ({'status': '200'}, json.dumps(token_response_first).encode('utf-8')), @@ -100,27 +119,64 @@ class ServiceAccountCredentialsTests(unittest.TestCase): json.dumps(token_response_second).encode('utf-8')), ]) - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + # Get Access Token, First attempt. + self.assertEqual(self.credentials.access_token, None) self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) - + self.assertEqual(self.credentials.token_expiry, None) token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) + self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, + self.credentials.token_response) + # Two utcnow calls are expected: + # - get_access_token() -> _do_refresh_request (setting expires in) + # - get_access_token() -> _expires_in() + expected_utcnow_calls = [mock.call()] * 2 + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # One rsa.pkcs1.sign expected: Actual refresh was needed. + self.assertEqual(len(sign_func.mock_calls), 1) - time.sleep(S + 0.5) # some margin to avoid flakiness + # Get Access Token, Second Attempt (not expired) + self.assertEqual(self.credentials.access_token, token1) + self.assertFalse(self.credentials.access_token_expired) + token = self.credentials.get_access_token(http=http) + # Make sure no refresh occurred since the token was not expired. + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, self.credentials.token_response) + # Three more utcnow calls are expected: + # - access_token_expired + # - get_access_token() -> access_token_expired + # - get_access_token -> _expires_in + expected_utcnow_calls = [mock.call()] * (2 + 3) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # No rsa.pkcs1.sign expected: the token was not expired. + self.assertEqual(len(sign_func.mock_calls), 1 + 0) + + # Get Access Token, Third Attempt (force expiration) + self.assertEqual(self.credentials.access_token, token1) + self.credentials.token_expiry = NOW # Manually force expiry. self.assertTrue(self.credentials.access_token_expired) - token = self.credentials.get_access_token(http=http) - self.assertEqual('second_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + # Make sure refresh occurred since the token was not expired. + self.assertEqual(token2, token.access_token) + self.assertEqual(lifetime, token.expires_in) self.assertFalse(self.credentials.access_token_expired) self.assertEqual(token_response_second, self.credentials.token_response) + # Five more utcnow calls are expected: + # - access_token_expired + # - get_access_token -> access_token_expired + # - get_access_token -> _do_refresh_request + # - get_access_token -> _expires_in + # - access_token_expired + expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # One more rsa.pkcs1.sign expected: Actual refresh was needed. + self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1) + + self.assertEqual(self.credentials.access_token, token2) if __name__ == '__main__': # pragma: NO COVER