From 30c342a437a137c82444364073f4e4afebeb891a Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Tue, 5 Jan 2016 00:02:20 -0800 Subject: [PATCH] Factor out usage of utcnow() in client. This is to enable better stubs in testing and eliminate two sleep() statements in unit tests. (The philosophy is "unit tests should be fast".) --- oauth2client/client.py | 17 ++++--- tests/test_client.py | 96 ++++++++++++++++++++++++++--------- tests/test_service_account.py | 94 +++++++++++++++++++++++++++------- 3 files changed, 158 insertions(+), 49 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index 20a59c4..155d321 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -117,6 +117,10 @@ _GCE_METADATA_HOST = '169.254.169.254' _METADATA_FLAVOR_HEADER = 'Metadata-Flavor' _DESIRED_METADATA_FLAVOR = 'Google' +# Expose utcnow() at module level to allow for +# easier testing (by replacing with a stub). +_UTCNOW = datetime.datetime.utcnow + class SETTINGS(object): """Settings namespace for globally defined values.""" @@ -737,7 +741,7 @@ class OAuth2Credentials(Credentials): if not self.token_expiry: return False - now = datetime.datetime.utcnow() + now = _UTCNOW() if now >= self.token_expiry: logger.info('access_token is expired. Now: %s, token_expiry: %s', now, self.token_expiry) @@ -780,7 +784,7 @@ class OAuth2Credentials(Credentials): valid; we just don't know anything about it. """ if self.token_expiry: - now = datetime.datetime.utcnow() + now = _UTCNOW() if self.token_expiry > now: time_delta = self.token_expiry - now # TODO(orestica): return time_delta.total_seconds() @@ -881,8 +885,8 @@ class OAuth2Credentials(Credentials): self.access_token = d['access_token'] self.refresh_token = d.get('refresh_token', self.refresh_token) if 'expires_in' in d: - self.token_expiry = datetime.timedelta( - seconds=int(d['expires_in'])) + datetime.datetime.utcnow() + delta = datetime.timedelta(seconds=int(d['expires_in'])) + self.token_expiry = delta + _UTCNOW() else: self.token_expiry = None if 'id_token' in d: @@ -2149,9 +2153,8 @@ class OAuth2WebServerFlow(Flow): "reauthenticating with approval_prompt='force'.") token_expiry = None if 'expires_in' in d: - token_expiry = ( - datetime.datetime.utcnow() + - datetime.timedelta(seconds=int(d['expires_in']))) + delta = datetime.timedelta(seconds=int(d['expires_in'])) + token_expiry = delta + _UTCNOW() extracted_id_token = None if 'id_token' in d: diff --git a/tests/test_client.py b/tests/test_client.py index c3cbdab..ef9d633 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,12 +21,12 @@ Unit tests for oauth2client. import base64 import contextlib +import copy import datetime import json import os import socket import sys -import time import mock import six @@ -841,11 +841,27 @@ class BasicCredentialsTests(unittest2.TestCase): instance = OAuth2Credentials.from_json(self.credentials.to_json()) self.assertEqual('foobar', instance.token_response) - def test_get_access_token(self): - S = 2 # number of seconds in which the token expires - token_response_first = {'access_token': 'first_token', 'expires_in': S} - token_response_second = {'access_token': 'second_token', - 'expires_in': S} + @mock.patch('oauth2client.client._UTCNOW') + def test_get_access_token(self, utcnow): + # Configure the patch. + seconds = 11 + NOW = datetime.datetime(1992, 12, 31, second=seconds) + utcnow.return_value = NOW + + lifetime = 2 # number of seconds in which the token expires + EXPIRY_TIME = datetime.datetime(1992, 12, 31, + second=seconds + lifetime) + + token1 = u'first_token' + token_response_first = { + 'access_token': token1, + 'expires_in': lifetime, + } + token2 = u'second_token' + token_response_second = { + 'access_token': token2, + 'expires_in': lifetime, + } http = HttpMockSequence([ ({'status': '200'}, json.dumps(token_response_first).encode( 'utf-8')), @@ -853,27 +869,61 @@ class BasicCredentialsTests(unittest2.TestCase): 'utf-8')), ]) - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) + # Use the current credentials but unset the expiry and + # the access token. + credentials = copy.deepcopy(self.credentials) + credentials.access_token = None + credentials.token_expiry = None - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) + # Get Access Token, First attempt. + self.assertEqual(credentials.access_token, None) + self.assertFalse(credentials.access_token_expired) + self.assertEqual(credentials.token_expiry, None) + token = credentials.get_access_token(http=http) + self.assertEqual(credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, credentials.token_response) + # Two utcnow calls are expected: + # - get_access_token() -> _do_refresh_request (setting expires in) + # - get_access_token() -> _expires_in() + expected_utcnow_calls = [mock.call()] * 2 + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - time.sleep(S + 0.5) # some margin to avoid flakiness - self.assertTrue(self.credentials.access_token_expired) + # Get Access Token, Second Attempt (not expired) + self.assertEqual(credentials.access_token, token1) + self.assertFalse(credentials.access_token_expired) + token = credentials.get_access_token(http=http) + # Make sure no refresh occurred since the token was not expired. + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, credentials.token_response) + # Three more utcnow calls are expected: + # - access_token_expired + # - get_access_token() -> access_token_expired + # - get_access_token -> _expires_in + expected_utcnow_calls = [mock.call()] * (2 + 3) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - token = self.credentials.get_access_token(http=http) - self.assertEqual('second_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) + # Get Access Token, Third Attempt (force expiration) + self.assertEqual(credentials.access_token, token1) + credentials.token_expiry = NOW # Manually force expiry. + self.assertTrue(credentials.access_token_expired) + token = credentials.get_access_token(http=http) + # Make sure refresh occurred since the token was not expired. + self.assertEqual(token2, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertFalse(credentials.access_token_expired) self.assertEqual(token_response_second, - self.credentials.token_response) + credentials.token_response) + # Five more utcnow calls are expected: + # - access_token_expired + # - get_access_token -> access_token_expired + # - get_access_token -> _do_refresh_request + # - get_access_token -> _expires_in + # - access_token_expired + expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) def test_has_scopes(self): self.assertTrue(self.credentials.has_scopes('foo')) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 1ba0cae..09d6234 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -17,12 +17,14 @@ Unit tests for service account credentials implemented using RSA. """ +import datetime import json import os import rsa -import time import unittest +import mock + from .http_mock import HttpMockSequence from oauth2client.service_account import _ServiceAccountCredentials @@ -88,11 +90,28 @@ class ServiceAccountCredentialsTests(unittest.TestCase): _ServiceAccountCredentials)) self.assertEqual('dummy_scope', new_credentials._scopes) - def test_access_token(self): - S = 2 # number of seconds in which the token expires - token_response_first = {'access_token': 'first_token', 'expires_in': S} - token_response_second = {'access_token': 'second_token', - 'expires_in': S} + @mock.patch('oauth2client.client._UTCNOW') + @mock.patch('rsa.pkcs1.sign', return_value=b'signed-value') + def test_access_token(self, sign_func, utcnow): + # Configure the patch. + seconds = 11 + NOW = datetime.datetime(1992, 12, 31, second=seconds) + utcnow.return_value = NOW + + lifetime = 2 # number of seconds in which the token expires + EXPIRY_TIME = datetime.datetime(1992, 12, 31, + second=seconds + lifetime) + + token1 = u'first_token' + token_response_first = { + 'access_token': token1, + 'expires_in': lifetime, + } + token2 = u'second_token' + token_response_second = { + 'access_token': token2, + 'expires_in': lifetime, + } http = HttpMockSequence([ ({'status': '200'}, json.dumps(token_response_first).encode('utf-8')), @@ -100,27 +119,64 @@ class ServiceAccountCredentialsTests(unittest.TestCase): json.dumps(token_response_second).encode('utf-8')), ]) - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + # Get Access Token, First attempt. + self.assertEqual(self.credentials.access_token, None) self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) - + self.assertEqual(self.credentials.token_expiry, None) token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) + self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, + self.credentials.token_response) + # Two utcnow calls are expected: + # - get_access_token() -> _do_refresh_request (setting expires in) + # - get_access_token() -> _expires_in() + expected_utcnow_calls = [mock.call()] * 2 + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # One rsa.pkcs1.sign expected: Actual refresh was needed. + self.assertEqual(len(sign_func.mock_calls), 1) - time.sleep(S + 0.5) # some margin to avoid flakiness + # Get Access Token, Second Attempt (not expired) + self.assertEqual(self.credentials.access_token, token1) + self.assertFalse(self.credentials.access_token_expired) + token = self.credentials.get_access_token(http=http) + # Make sure no refresh occurred since the token was not expired. + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, self.credentials.token_response) + # Three more utcnow calls are expected: + # - access_token_expired + # - get_access_token() -> access_token_expired + # - get_access_token -> _expires_in + expected_utcnow_calls = [mock.call()] * (2 + 3) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # No rsa.pkcs1.sign expected: the token was not expired. + self.assertEqual(len(sign_func.mock_calls), 1 + 0) + + # Get Access Token, Third Attempt (force expiration) + self.assertEqual(self.credentials.access_token, token1) + self.credentials.token_expiry = NOW # Manually force expiry. self.assertTrue(self.credentials.access_token_expired) - token = self.credentials.get_access_token(http=http) - self.assertEqual('second_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + # Make sure refresh occurred since the token was not expired. + self.assertEqual(token2, token.access_token) + self.assertEqual(lifetime, token.expires_in) self.assertFalse(self.credentials.access_token_expired) self.assertEqual(token_response_second, self.credentials.token_response) + # Five more utcnow calls are expected: + # - access_token_expired + # - get_access_token -> access_token_expired + # - get_access_token -> _do_refresh_request + # - get_access_token -> _expires_in + # - access_token_expired + expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # One more rsa.pkcs1.sign expected: Actual refresh was needed. + self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1) + + self.assertEqual(self.credentials.access_token, token2) if __name__ == '__main__': # pragma: NO COVER