Use transport.request in tests.

In the process

- "spring clean" the modules that were touched
- use HttpMock when HttpMockSequence not needed
- add some verifications on new HttpMock's
This commit is contained in:
Danny Hermes
2016-08-10 16:53:46 -07:00
parent b7f3eca135
commit 1a0c4dbf92
6 changed files with 237 additions and 119 deletions

View File

@@ -24,12 +24,13 @@ import unittest
import fasteners import fasteners
import mock import mock
from six import StringIO import six
from six.moves import urllib_parse
from oauth2client import client from oauth2client import client
from oauth2client.contrib import multiprocess_file_storage from oauth2client.contrib import multiprocess_file_storage
from ..http_mock import HttpMockSequence from .. import http_mock
@contextlib.contextmanager @contextlib.contextmanager
@@ -68,11 +69,7 @@ def _generate_token_response_http(new_token='new_token'):
'access_token': new_token, 'access_token': new_token,
'expires_in': '3600', 'expires_in': '3600',
}) })
http = HttpMockSequence([ return http_mock.HttpMock(data=token_response)
({'status': '200'}, token_response),
])
return http
class MultiprocessStorageBehaviorTests(unittest.TestCase): class MultiprocessStorageBehaviorTests(unittest.TestCase):
@@ -115,6 +112,23 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase):
self.assertIsNone(credentials) self.assertIsNone(credentials)
def _verify_refresh_payload(self, http, credentials):
self.assertEqual(http.requests, 1)
self.assertEqual(http.uri, credentials.token_uri)
self.assertEqual(http.method, 'POST')
expected_body = {
'grant_type': ['refresh_token'],
'client_id': [credentials.client_id],
'client_secret': [credentials.client_secret],
'refresh_token': [credentials.refresh_token],
}
self.assertEqual(urllib_parse.parse_qs(http.body), expected_body)
expected_headers = {
'content-type': 'application/x-www-form-urlencoded',
'user-agent': credentials.user_agent,
}
self.assertEqual(http.headers, expected_headers)
def test_single_process_refresh(self): def test_single_process_refresh(self):
store = multiprocess_file_storage.MultiprocessFileStorage( store = multiprocess_file_storage.MultiprocessFileStorage(
self.filename, 'single-process') self.filename, 'single-process')
@@ -128,6 +142,9 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase):
retrieved = store.get() retrieved = store.get()
self.assertEqual(retrieved.access_token, 'new_token') self.assertEqual(retrieved.access_token, 'new_token')
# Verify mocks.
self._verify_refresh_payload(http, credentials)
def test_multi_process_refresh(self): def test_multi_process_refresh(self):
# This will test that two processes attempting to refresh credentials # This will test that two processes attempting to refresh credentials
# will only refresh once. # will only refresh once.
@@ -136,6 +153,7 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase):
credentials = _create_test_credentials() credentials = _create_test_credentials()
credentials.set_store(store) credentials.set_store(store)
store.put(credentials) store.put(credentials)
actual_token = 'b'
def child_process_func( def child_process_func(
die_event, ready_event, check_event): # pragma: NO COVER die_event, ready_event, check_event): # pragma: NO COVER
@@ -156,10 +174,12 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase):
credentials.store.acquire_lock = replacement_acquire_lock credentials.store.acquire_lock = replacement_acquire_lock
http = _generate_token_response_http('b') http = _generate_token_response_http(actual_token)
credentials.refresh(http) credentials.refresh(http)
self.assertEqual(credentials.access_token, actual_token)
self.assertEqual(credentials.access_token, 'b') # Verify mock http.
self._verify_refresh_payload(http, credentials)
check_event = multiprocessing.Event() check_event = multiprocessing.Event()
with scoped_child_process(child_process_func, check_event=check_event): with scoped_child_process(child_process_func, check_event=check_event):
@@ -168,15 +188,17 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase):
store._backend._process_lock.acquire(blocking=False)) store._backend._process_lock.acquire(blocking=False))
check_event.set() check_event.set()
# The child process will refresh first, so we should end up http = _generate_token_response_http('not ' + actual_token)
# with 'b' as the token.
http = mock.Mock()
credentials.refresh(http=http) credentials.refresh(http=http)
self.assertEqual(credentials.access_token, 'b') # The child process will refresh first, so we should end up
self.assertFalse(http.request.called) # with `actual_token`' as the token.
self.assertEqual(credentials.access_token, actual_token)
# Make sure the refresh did not make a request.
self.assertEqual(http.requests, 0)
retrieved = store.get() retrieved = store.get()
self.assertEqual(retrieved.access_token, 'b') self.assertEqual(retrieved.access_token, actual_token)
def test_read_only_file_fail_lock(self): def test_read_only_file_fail_lock(self):
credentials = _create_test_credentials() credentials = _create_test_credentials()
@@ -233,7 +255,7 @@ class MultiprocessStorageUnitTests(unittest.TestCase):
def test__read_write_credentials_file(self): def test__read_write_credentials_file(self):
credentials = _create_test_credentials() credentials = _create_test_credentials()
contents = StringIO() contents = six.StringIO()
multiprocess_file_storage._write_credentials_file( multiprocess_file_storage._write_credentials_file(
contents, {'key': credentials}) contents, {'key': credentials})
@@ -253,23 +275,23 @@ class MultiprocessStorageUnitTests(unittest.TestCase):
# the invalid one but still load the valid one. # the invalid one but still load the valid one.
data['credentials']['invalid'] = '123' data['credentials']['invalid'] = '123'
results = multiprocess_file_storage._load_credentials_file( results = multiprocess_file_storage._load_credentials_file(
StringIO(json.dumps(data))) six.StringIO(json.dumps(data)))
self.assertNotIn('invalid', results) self.assertNotIn('invalid', results)
self.assertEqual( self.assertEqual(
results['key'].access_token, credentials.access_token) results['key'].access_token, credentials.access_token)
def test__load_credentials_file_invalid_json(self): def test__load_credentials_file_invalid_json(self):
contents = StringIO('{[') contents = six.StringIO('{[')
self.assertEqual( self.assertEqual(
multiprocess_file_storage._load_credentials_file(contents), {}) multiprocess_file_storage._load_credentials_file(contents), {})
def test__load_credentials_file_no_file_version(self): def test__load_credentials_file_no_file_version(self):
contents = StringIO('{}') contents = six.StringIO('{}')
self.assertEqual( self.assertEqual(
multiprocess_file_storage._load_credentials_file(contents), {}) multiprocess_file_storage._load_credentials_file(contents), {})
def test__load_credentials_file_bad_file_version(self): def test__load_credentials_file_bad_file_version(self):
contents = StringIO(json.dumps({'file_version': 1})) contents = six.StringIO(json.dumps({'file_version': 1}))
self.assertEqual( self.assertEqual(
multiprocess_file_storage._load_credentials_file(contents), {}) multiprocess_file_storage._load_credentials_file(contents), {})

View File

@@ -75,7 +75,7 @@ class HttpMockSequence(object):
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
resp, content = http.request("http://examples.com") resp, content = http.request('http://examples.com')
There are special values you can pass in for content to trigger There are special values you can pass in for content to trigger
behavours that are helpful in testing. behavours that are helpful in testing.

View File

@@ -35,6 +35,7 @@ from oauth2client import _helpers
from oauth2client import client from oauth2client import client
from oauth2client import clientsecrets from oauth2client import clientsecrets
from oauth2client import service_account from oauth2client import service_account
from oauth2client import transport
from . import http_mock from . import http_mock
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
@@ -899,7 +900,7 @@ class BasicCredentialsTests(unittest.TestCase):
({'status': http_client.OK}, 'echo_request_headers'), ({'status': http_client.OK}, 'echo_request_headers'),
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response, self.credentials.token_response) self.assertEqual(token_response, self.credentials.token_response)
@@ -918,7 +919,7 @@ class BasicCredentialsTests(unittest.TestCase):
http = http_mock.HttpMock(data=encoded_response) http = http_mock.HttpMock(data=encoded_response)
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
http.request('http://example.com') transport.request(http, 'http://example.com')
def test_token_refresh_failure(self): def test_token_refresh_failure(self):
for status_code in client.REFRESH_STATUS_CODES: for status_code in client.REFRESH_STATUS_CODES:
@@ -930,7 +931,7 @@ class BasicCredentialsTests(unittest.TestCase):
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
with self.assertRaises( with self.assertRaises(
client.HttpAccessTokenRefreshError) as exc_manager: client.HttpAccessTokenRefreshError) as exc_manager:
http.request('http://example.com') transport.request(http, 'http://example.com')
self.assertEqual(http_client.BAD_REQUEST, self.assertEqual(http_client.BAD_REQUEST,
exc_manager.exception.status) exc_manager.exception.status)
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
@@ -957,7 +958,7 @@ class BasicCredentialsTests(unittest.TestCase):
def test_non_401_error_response(self): def test_non_401_error_response(self):
http = http_mock.HttpMock(headers={'status': http_client.BAD_REQUEST}) http = http_mock.HttpMock(headers={'status': http_client.BAD_REQUEST})
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
self.assertEqual(http_client.BAD_REQUEST, resp.status) self.assertEqual(http_client.BAD_REQUEST, resp.status)
self.assertEqual(None, self.credentials.token_response) self.assertEqual(None, self.credentials.token_response)
@@ -1001,7 +1002,8 @@ class BasicCredentialsTests(unittest.TestCase):
http = credentials.authorize(http_mock.HttpMock()) http = credentials.authorize(http_mock.HttpMock())
headers = {u'foo': 3, b'bar': True, 'baz': b'abc'} headers = {u'foo': 3, b'bar': True, 'baz': b'abc'}
cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'} cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'}
http.request(u'http://example.com', method=u'GET', headers=headers) transport.request(
http, u'http://example.com', method=u'GET', headers=headers)
for k, v in cleaned_headers.items(): for k, v in cleaned_headers.items():
self.assertTrue(k in http.headers) self.assertTrue(k in http.headers)
self.assertEqual(v, http.headers[k]) self.assertEqual(v, http.headers[k])
@@ -1009,8 +1011,9 @@ class BasicCredentialsTests(unittest.TestCase):
# Next, test that we do fail on unicode. # Next, test that we do fail on unicode.
unicode_str = six.unichr(40960) + 'abcd' unicode_str = six.unichr(40960) + 'abcd'
with self.assertRaises(client.NonAsciiHeaderError): with self.assertRaises(client.NonAsciiHeaderError):
http.request(u'http://example.com', method=u'GET', transport.request(
headers={u'foo': unicode_str}) http, u'http://example.com', method=u'GET',
headers={u'foo': unicode_str})
def test_no_unicode_in_request_params(self): def test_no_unicode_in_request_params(self):
access_token = u'foo' access_token = u'foo'
@@ -1027,8 +1030,9 @@ class BasicCredentialsTests(unittest.TestCase):
http = http_mock.HttpMock() http = http_mock.HttpMock()
http = credentials.authorize(http) http = credentials.authorize(http)
http.request(u'http://example.com', method=u'GET', transport.request(
headers={u'foo': u'bar'}) http, u'http://example.com', method=u'GET',
headers={u'foo': u'bar'})
for k, v in six.iteritems(http.headers): for k, v in six.iteritems(http.headers):
self.assertIsInstance(k, six.binary_type) self.assertIsInstance(k, six.binary_type)
self.assertIsInstance(v, six.binary_type) self.assertIsInstance(v, six.binary_type)
@@ -1036,8 +1040,8 @@ class BasicCredentialsTests(unittest.TestCase):
# Test again with unicode strings that can't simply be converted # Test again with unicode strings that can't simply be converted
# to ASCII. # to ASCII.
with self.assertRaises(client.NonAsciiHeaderError): with self.assertRaises(client.NonAsciiHeaderError):
http.request( transport.request(
u'http://example.com', method=u'GET', http, u'http://example.com', method=u'GET',
headers={u'foo': u'\N{COMET}'}) headers={u'foo': u'\N{COMET}'})
self.credentials.token_response = 'foobar' self.credentials.token_response = 'foobar'
@@ -1473,7 +1477,7 @@ class BasicCredentialsTests(unittest.TestCase):
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
self.assertEqual(self.credentials.id_token, body) self.assertEqual(self.credentials.id_token, body)
@@ -1492,7 +1496,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
headers={'status': status_code}, data=b'') headers={'status': status_code}, data=b'')
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
with self.assertRaises(client.AccessTokenCredentialsError): with self.assertRaises(client.AccessTokenCredentialsError):
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
def test_token_revoke_success(self): def test_token_revoke_success(self):
_token_revoke_test_helper( _token_revoke_test_helper(
@@ -1507,7 +1511,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
def test_non_401_error_response(self): def test_non_401_error_response(self):
http = http_mock.HttpMock(headers={'status': http_client.BAD_REQUEST}) http = http_mock.HttpMock(headers={'status': http_client.BAD_REQUEST})
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
self.assertEqual(http_client.BAD_REQUEST, resp.status) self.assertEqual(http_client.BAD_REQUEST, resp.status)
def test_auth_header_sent(self): def test_auth_header_sent(self):
@@ -1515,7 +1519,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
self.assertEqual(b'Bearer foo', content[b'Authorization']) self.assertEqual(b'Bearer foo', content[b'Authorization'])
@@ -1551,7 +1555,7 @@ class TestAssertionCredentials(unittest.TestCase):
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = transport.request(http, 'http://example.com')
self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_token_revoke_success(self): def test_token_revoke_success(self):
@@ -1742,16 +1746,17 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
device_code, user_code, None, ver_url, None) device_code, user_code, None, ver_url, None)
self.assertEqual(result, expected) self.assertEqual(result, expected)
self.assertEqual(len(http.requests), 1) self.assertEqual(len(http.requests), 1)
self.assertEqual( info = http.requests[0]
http.requests[0]['uri'], oauth2client.GOOGLE_DEVICE_URI) self.assertEqual(info['uri'], oauth2client.GOOGLE_DEVICE_URI)
body = http.requests[0]['body'] expected_body = {
self.assertEqual(urllib.parse.parse_qs(body), 'client_id': [flow.client_id],
{'client_id': [flow.client_id], 'scope': [flow.scope],
'scope': [flow.scope]}) }
self.assertEqual(urllib.parse.parse_qs(info['body']), expected_body)
headers = {'content-type': 'application/x-www-form-urlencoded'} headers = {'content-type': 'application/x-www-form-urlencoded'}
if extra_headers is not None: if extra_headers is not None:
headers.update(extra_headers) headers.update(extra_headers)
self.assertEqual(http.requests[0]['headers'], headers) self.assertEqual(info['headers'], headers)
def test_step1_get_device_and_user_codes(self): def test_step1_get_device_and_user_codes(self):
self._step1_get_device_and_user_codes_helper() self._step1_get_device_and_user_codes_helper()

View File

@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Oauth2client.file tests """Unit tests for oauth2client.file."""
Unit tests for oauth2client.file
"""
import copy import copy
import datetime import datetime
@@ -30,11 +27,13 @@ import warnings
import mock import mock
import six import six
from six.moves import http_client from six.moves import http_client
from six.moves import urllib_parse
from oauth2client import _helpers from oauth2client import _helpers
from oauth2client import client from oauth2client import client
from oauth2client import file from oauth2client import file as file_module
from .http_mock import HttpMockSequence from oauth2client import transport
from . import http_mock
try: try:
# Python2 # Python2
@@ -80,14 +79,14 @@ class OAuth2ClientFileTests(unittest.TestCase):
@mock.patch('warnings.warn') @mock.patch('warnings.warn')
def test_non_existent_file_storage(self, warn_mock): def test_non_existent_file_storage(self, warn_mock):
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
credentials = storage.get() credentials = storage.get()
warn_mock.assert_called_with( warn_mock.assert_called_with(
_helpers._MISSING_FILE_MESSAGE.format(FILENAME)) _helpers._MISSING_FILE_MESSAGE.format(FILENAME))
self.assertIsNone(credentials) self.assertIsNone(credentials)
def test_directory_file_storage(self): def test_directory_file_storage(self):
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
os.mkdir(FILENAME) os.mkdir(FILENAME)
try: try:
with self.assertRaises(IOError): with self.assertRaises(IOError):
@@ -99,7 +98,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
def test_no_sym_link_credentials(self): def test_no_sym_link_credentials(self):
SYMFILENAME = FILENAME + '.sym' SYMFILENAME = FILENAME + '.sym'
os.symlink(FILENAME, SYMFILENAME) os.symlink(FILENAME, SYMFILENAME)
storage = file.Storage(SYMFILENAME) storage = file_module.Storage(SYMFILENAME)
try: try:
with self.assertRaises(IOError): with self.assertRaises(IOError):
storage.get() storage.get()
@@ -116,7 +115,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
# Storage should be not be able to read that object, as the capability # Storage should be not be able to read that object, as the capability
# to read and write credentials as pickled objects has been removed. # to read and write credentials as pickled objects has been removed.
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
read_credentials = storage.get() read_credentials = storage.get()
self.assertIsNone(read_credentials) self.assertIsNone(read_credentials)
@@ -134,7 +133,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
datetime.timedelta(minutes=15)) datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration) credentials = self._create_test_credentials(expiration=expiration)
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
storage.put(credentials) storage.put(credentials)
credentials = storage.get() credentials = storage.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
@@ -143,13 +142,29 @@ class OAuth2ClientFileTests(unittest.TestCase):
access_token = '1/3w' access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600} token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([ response_content = json.dumps(token_response).encode('utf-8')
({'status': '200'}, json.dumps(token_response).encode('utf-8')), http = http_mock.HttpMock(data=response_content)
])
credentials._refresh(http.request) credentials._refresh(http)
self.assertEquals(credentials.access_token, access_token) self.assertEquals(credentials.access_token, access_token)
# Verify mocks.
self.assertEqual(http.requests, 1)
self.assertEqual(http.uri, credentials.token_uri)
self.assertEqual(http.method, 'POST')
expected_body = {
'grant_type': ['refresh_token'],
'client_id': [credentials.client_id],
'client_secret': [credentials.client_secret],
'refresh_token': [credentials.refresh_token],
}
self.assertEqual(urllib_parse.parse_qs(http.body), expected_body)
expected_headers = {
'content-type': 'application/x-www-form-urlencoded',
'user-agent': credentials.user_agent,
}
self.assertEqual(http.headers, expected_headers)
def test_token_refresh_store_expires_soon(self): def test_token_refresh_store_expires_soon(self):
# Tests the case where an access token that is valid when it is read # Tests the case where an access token that is valid when it is read
# from the store expires before the original request succeeds. # from the store expires before the original request succeeds.
@@ -157,7 +172,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
datetime.timedelta(minutes=15)) datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration) credentials = self._create_test_credentials(expiration=expiration)
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
storage.put(credentials) storage.put(credentials)
credentials = storage.get() credentials = storage.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
@@ -166,19 +181,19 @@ class OAuth2ClientFileTests(unittest.TestCase):
access_token = '1/3w' access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600} token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([ http = http_mock.HttpMockSequence([
({'status': str(int(http_client.UNAUTHORIZED))}, ({'status': http_client.UNAUTHORIZED},
b'Initial token expired'), b'Initial token expired'),
({'status': str(int(http_client.UNAUTHORIZED))}, ({'status': http_client.UNAUTHORIZED},
b'Store token expired'), b'Store token expired'),
({'status': str(int(http_client.OK))}, ({'status': http_client.OK},
json.dumps(token_response).encode('utf-8')), json.dumps(token_response).encode('utf-8')),
({'status': str(int(http_client.OK))}, ({'status': http_client.OK},
b'Valid response to original request') b'Valid response to original request')
]) ])
credentials.authorize(http) credentials.authorize(http)
http.request('https://example.com') transport.request(http, 'https://example.com')
self.assertEqual(credentials.access_token, access_token) self.assertEqual(credentials.access_token, access_token)
def test_token_refresh_good_store(self): def test_token_refresh_good_store(self):
@@ -186,7 +201,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
datetime.timedelta(minutes=15)) datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration) credentials = self._create_test_credentials(expiration=expiration)
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
storage.put(credentials) storage.put(credentials)
credentials = storage.get() credentials = storage.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
@@ -201,7 +216,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
datetime.timedelta(minutes=15)) datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration) credentials = self._create_test_credentials(expiration=expiration)
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
storage.put(credentials) storage.put(credentials)
credentials = storage.get() credentials = storage.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
@@ -211,27 +226,27 @@ class OAuth2ClientFileTests(unittest.TestCase):
valid_access_token = '1/3w' valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, token_response = {'access_token': valid_access_token,
'expires_in': 3600} 'expires_in': 3600}
http = HttpMockSequence([ http = http_mock.HttpMockSequence([
({'status': str(int(http_client.UNAUTHORIZED))}, ({'status': http_client.UNAUTHORIZED},
b'Initial token expired'), b'Initial token expired'),
({'status': str(int(http_client.UNAUTHORIZED))}, ({'status': http_client.UNAUTHORIZED},
b'Store token expired'), b'Store token expired'),
({'status': str(int(http_client.OK))}, ({'status': http_client.OK},
json.dumps(token_response).encode('utf-8')), json.dumps(token_response).encode('utf-8')),
({'status': str(int(http_client.OK))}, 'echo_request_body') ({'status': http_client.OK}, 'echo_request_body')
]) ])
body = six.StringIO('streaming body') body = six.StringIO('streaming body')
credentials.authorize(http) credentials.authorize(http)
_, content = http.request('https://example.com', body=body) _, content = transport.request(http, 'https://example.com', body=body)
self.assertEqual(content, 'streaming body') self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token) self.assertEqual(credentials.access_token, valid_access_token)
def test_credentials_delete(self): def test_credentials_delete(self):
credentials = self._create_test_credentials() credentials = self._create_test_credentials()
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
storage.put(credentials) storage.put(credentials)
credentials = storage.get() credentials = storage.get()
self.assertIsNotNone(credentials) self.assertIsNotNone(credentials)
@@ -245,7 +260,7 @@ class OAuth2ClientFileTests(unittest.TestCase):
credentials = client.AccessTokenCredentials(access_token, user_agent) credentials = client.AccessTokenCredentials(access_token, user_agent)
storage = file.Storage(FILENAME) storage = file_module.Storage(FILENAME)
credentials = storage.put(credentials) credentials = storage.put(credentials)
credentials = storage.get() credentials = storage.get()

View File

@@ -20,13 +20,15 @@ import time
import unittest import unittest
import mock import mock
from six.moves import http_client
from oauth2client import _helpers from oauth2client import _helpers
from oauth2client import client from oauth2client import client
from oauth2client import crypt from oauth2client import crypt
from oauth2client import file from oauth2client import file as file_module
from oauth2client import service_account from oauth2client import service_account
from .http_mock import HttpMockSequence from oauth2client import transport
from . import http_mock
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
@@ -114,25 +116,30 @@ class CryptTests(unittest.TestCase):
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
def _verify_http_mock(self, http):
self.assertEqual(http.requests, 1)
self.assertEqual(http.uri, client.ID_TOKEN_VERIFICATION_CERTS)
self.assertEqual(http.method, 'GET')
self.assertIsNone(http.body)
self.assertIsNone(http.headers)
def test_verify_id_token_with_certs_uri(self): def test_verify_id_token_with_certs_uri(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
http = HttpMockSequence([ http = http_mock.HttpMock(data=datafile('certs.json'))
({'status': '200'}, datafile('certs.json')),
])
contents = client.verify_id_token( contents = client.verify_id_token(
jwt, 'some_audience_address@testing.gserviceaccount.com', jwt, 'some_audience_address@testing.gserviceaccount.com',
http=http) http=http)
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
# Verify mocks.
self._verify_http_mock(http)
def test_verify_id_token_with_certs_uri_default_http(self): def test_verify_id_token_with_certs_uri_default_http(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
http = HttpMockSequence([ http = http_mock.HttpMock(data=datafile('certs.json'))
({'status': '200'}, datafile('certs.json')),
])
with mock.patch('oauth2client.transport._CACHED_HTTP', new=http): with mock.patch('oauth2client.transport._CACHED_HTTP', new=http):
contents = client.verify_id_token( contents = client.verify_id_token(
@@ -141,17 +148,23 @@ class CryptTests(unittest.TestCase):
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
# Verify mocks.
self._verify_http_mock(http)
def test_verify_id_token_with_certs_uri_fails(self): def test_verify_id_token_with_certs_uri_fails(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
test_email = 'some_audience_address@testing.gserviceaccount.com' test_email = 'some_audience_address@testing.gserviceaccount.com'
http = HttpMockSequence([ http = http_mock.HttpMock(
({'status': '404'}, datafile('certs.json')), headers={'status': http_client.NOT_FOUND},
]) data=datafile('certs.json'))
with self.assertRaises(client.VerifyJwtTokenError): with self.assertRaises(client.VerifyJwtTokenError):
client.verify_id_token(jwt, test_email, http=http) client.verify_id_token(jwt, test_email, http=http)
# Verify mocks.
self._verify_http_mock(http)
def test_verify_id_token_bad_tokens(self): def test_verify_id_token_bad_tokens(self):
private_key = datafile('privatekey.' + self.format_) private_key = datafile('privatekey.' + self.format_)
@@ -261,12 +274,13 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
def test_credentials_good(self): def test_credentials_good(self):
credentials = self._make_credentials() credentials = self._make_credentials()
http = HttpMockSequence([ http = http_mock.HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), ({'status': http_client.OK},
({'status': '200'}, 'echo_request_headers'), b'{"access_token":"1/3w","expires_in":3600}'),
({'status': http_client.OK}, 'echo_request_headers'),
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
resp, content = http.request('http://example.org') resp, content = transport.request(http, 'http://example.org')
self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_credentials_to_from_json(self): def test_credentials_to_from_json(self):
@@ -280,14 +294,16 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
self.assertEqual(credentials._kwargs, restored._kwargs) self.assertEqual(credentials._kwargs, restored._kwargs)
def _credentials_refresh(self, credentials): def _credentials_refresh(self, credentials):
http = HttpMockSequence([ http = http_mock.HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), ({'status': http_client.OK},
({'status': '401'}, b''), b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'), ({'status': http_client.UNAUTHORIZED}, b''),
({'status': '200'}, 'echo_request_headers'), ({'status': http_client.OK},
b'{"access_token":"3/3w","expires_in":3600}'),
({'status': http_client.OK}, 'echo_request_headers'),
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
_, content = http.request('http://example.org') _, content = transport.request(http, 'http://example.org')
return content return content
def test_credentials_refresh_without_storage(self): def test_credentials_refresh_without_storage(self):
@@ -300,7 +316,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
filehandle, filename = tempfile.mkstemp() filehandle, filename = tempfile.mkstemp()
os.close(filehandle) os.close(filehandle)
store = file.Storage(filename) store = file_module.Storage(filename)
store.put(credentials) store.put(credentials)
credentials.set_store(store) credentials.set_store(store)

View File

@@ -407,13 +407,15 @@ class JWTAccessCredentialsTests(unittest.TestCase):
time.return_value = T1 time.return_value = T1
token_info = self.jwt.get_access_token() token_info = self.jwt.get_access_token()
certs = {'key': datafile('public_cert.pem')}
payload = crypt.verify_signed_jwt_with_certs( payload = crypt.verify_signed_jwt_with_certs(
token_info.access_token, token_info.access_token, certs, audience=self.url)
{'key': datafile('public_cert.pem')}, audience=self.url) self.assertEqual(len(payload), 5)
self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email) self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], T1) self.assertEqual(payload['iat'], T1)
self.assertEqual(payload['exp'], T1_EXPIRY) self.assertEqual(payload['exp'], T1_EXPIRY)
self.assertEqual(payload['aud'], self.url)
self.assertEqual(token_info.expires_in, T1_EXPIRY - T1) self.assertEqual(token_info.expires_in, T1_EXPIRY - T1)
# Verify that we vend the same token after 100 seconds # Verify that we vend the same token after 100 seconds
@@ -444,19 +446,20 @@ class JWTAccessCredentialsTests(unittest.TestCase):
utcnow.return_value = T1_DATE utcnow.return_value = T1_DATE
time.return_value = T1 time.return_value = T1
token_info = self.jwt.get_access_token( audience = 'https://test2.url.com'
additional_claims={'aud': 'https://test2.url.com', subject = 'dummy2@google.com'
'sub': 'dummy2@google.com' claims = {'aud': audience, 'sub': subject}
}) token_info = self.jwt.get_access_token(additional_claims=claims)
certs = {'key': datafile('public_cert.pem')}
payload = crypt.verify_signed_jwt_with_certs( payload = crypt.verify_signed_jwt_with_certs(
token_info.access_token, token_info.access_token, certs, audience=audience)
{'key': datafile('public_cert.pem')},
audience='https://test2.url.com')
expires_in = token_info.expires_in expires_in = token_info.expires_in
self.assertEqual(len(payload), 5)
self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], 'dummy2@google.com') self.assertEqual(payload['sub'], subject)
self.assertEqual(payload['iat'], T1) self.assertEqual(payload['iat'], T1)
self.assertEqual(payload['exp'], T1_EXPIRY) self.assertEqual(payload['exp'], T1_EXPIRY)
self.assertEqual(payload['aud'], audience)
self.assertEqual(expires_in, T1_EXPIRY - T1) self.assertEqual(expires_in, T1_EXPIRY - T1)
def test_revoke(self): def test_revoke(self):
@@ -502,13 +505,15 @@ class JWTAccessCredentialsTests(unittest.TestCase):
self.assertIsNone(info['body']) self.assertIsNone(info['body'])
self.assertEqual(len(info['headers']), 1) self.assertEqual(len(info['headers']), 1)
bearer, token = info['headers'][b'Authorization'].split() bearer, token = info['headers'][b'Authorization'].split()
self.assertEqual(bearer, b'Bearer')
payload = crypt.verify_signed_jwt_with_certs( payload = crypt.verify_signed_jwt_with_certs(
token, certs, audience=self.url) token, certs, audience=self.url)
self.assertEqual(len(payload), 5)
self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email) self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], T1) self.assertEqual(payload['iat'], T1)
self.assertEqual(payload['exp'], T1_EXPIRY) self.assertEqual(payload['exp'], T1_EXPIRY)
self.assertEqual(bearer, b'Bearer') self.assertEqual(payload['aud'], self.url)
@mock.patch('oauth2client.client._UTCNOW') @mock.patch('oauth2client.client._UTCNOW')
@mock.patch('time.time') @mock.patch('time.time')
@@ -538,53 +543,108 @@ class JWTAccessCredentialsTests(unittest.TestCase):
self.assertIsNone(info['body']) self.assertIsNone(info['body'])
self.assertEqual(len(info['headers']), 1) self.assertEqual(len(info['headers']), 1)
bearer, token = info['headers'][b'Authorization'].split() bearer, token = info['headers'][b'Authorization'].split()
self.assertEqual(bearer, b'Bearer')
certs = {'key': datafile('public_cert.pem')} certs = {'key': datafile('public_cert.pem')}
payload = crypt.verify_signed_jwt_with_certs( payload = crypt.verify_signed_jwt_with_certs(
token, certs, audience=self.url) token, certs, audience=self.url)
self.assertEqual(len(payload), 5)
self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email) self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], T1) self.assertEqual(payload['iat'], T1)
self.assertEqual(payload['exp'], T1_EXPIRY) self.assertEqual(payload['exp'], T1_EXPIRY)
self.assertEqual(bearer, b'Bearer') self.assertEqual(payload['aud'], self.url)
@mock.patch('oauth2client.client._UTCNOW') @mock.patch('oauth2client.client._UTCNOW')
def test_authorize_stale_token(self, utcnow): def test_authorize_stale_token(self, utcnow):
utcnow.return_value = T1_DATE utcnow.return_value = T1_DATE
# Create an initial token # Create an initial token
h = http_mock.HttpMockSequence([ http = http_mock.HttpMockSequence([
({'status': http_client.OK}, b''), ({'status': http_client.OK}, b''),
({'status': http_client.OK}, b''), ({'status': http_client.OK}, b''),
]) ])
self.jwt.authorize(h) self.jwt.authorize(http)
h.request(self.url) transport.request(http, self.url)
token_1 = self.jwt.access_token token_1 = self.jwt.access_token
# Expire the token # Expire the token
utcnow.return_value = T3_DATE utcnow.return_value = T3_DATE
h.request(self.url) transport.request(http, self.url)
token_2 = self.jwt.access_token token_2 = self.jwt.access_token
self.assertEquals(self.jwt.token_expiry, T3_EXPIRY_DATE) self.assertEquals(self.jwt.token_expiry, T3_EXPIRY_DATE)
self.assertNotEqual(token_1, token_2) self.assertNotEqual(token_1, token_2)
# Verify mocks.
certs = {'key': datafile('public_cert.pem')}
self.assertEqual(len(http.requests), 2)
issued_at_vals = (T1, T3)
exp_vals = (T1_EXPIRY, T3_EXPIRY)
for info, issued_at, exp_val in zip(http.requests, issued_at_vals,
exp_vals):
self.assertEqual(info['uri'], self.url)
self.assertEqual(info['method'], 'GET')
self.assertIsNone(info['body'])
self.assertEqual(len(info['headers']), 1)
bearer, token = info['headers'][b'Authorization'].split()
self.assertEqual(bearer, b'Bearer')
# To parse the token, skip the time check, since this
# test intentionally has stale tokens.
with mock.patch('oauth2client.crypt._verify_time_range',
return_value=True):
payload = crypt.verify_signed_jwt_with_certs(
token, certs, audience=self.url)
self.assertEqual(len(payload), 5)
self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], issued_at)
self.assertEqual(payload['exp'], exp_val)
self.assertEqual(payload['aud'], self.url)
@mock.patch('oauth2client.client._UTCNOW') @mock.patch('oauth2client.client._UTCNOW')
def test_authorize_401(self, utcnow): def test_authorize_401(self, utcnow):
utcnow.return_value = T1_DATE utcnow.return_value = T1_DATE
h = http_mock.HttpMockSequence([ http = http_mock.HttpMockSequence([
({'status': http_client.OK}, b''), ({'status': http_client.OK}, b''),
({'status': http_client.UNAUTHORIZED}, b''), ({'status': http_client.UNAUTHORIZED}, b''),
({'status': http_client.OK}, b''), ({'status': http_client.OK}, b''),
]) ])
self.jwt.authorize(h) self.jwt.authorize(http)
h.request(self.url) transport.request(http, self.url)
token_1 = self.jwt.access_token token_1 = self.jwt.access_token
utcnow.return_value = T2_DATE utcnow.return_value = T2_DATE
self.assertEquals(h.request(self.url)[0].status, 200) response, _ = transport.request(http, self.url)
self.assertEquals(response.status, http_client.OK)
token_2 = self.jwt.access_token token_2 = self.jwt.access_token
# Check the 401 forced a new token # Check the 401 forced a new token
self.assertNotEqual(token_1, token_2) self.assertNotEqual(token_1, token_2)
# Verify mocks.
certs = {'key': datafile('public_cert.pem')}
self.assertEqual(len(http.requests), 3)
issued_at_vals = (T1, T1, T2)
exp_vals = (T1_EXPIRY, T1_EXPIRY, T2_EXPIRY)
for info, issued_at, exp_val in zip(http.requests, issued_at_vals,
exp_vals):
self.assertEqual(info['uri'], self.url)
self.assertEqual(info['method'], 'GET')
self.assertIsNone(info['body'])
self.assertEqual(len(info['headers']), 1)
bearer, token = info['headers'][b'Authorization'].split()
self.assertEqual(bearer, b'Bearer')
# To parse the token, skip the time check, since this
# test intentionally has stale tokens.
with mock.patch('oauth2client.crypt._verify_time_range',
return_value=True):
payload = crypt.verify_signed_jwt_with_certs(
token, certs, audience=self.url)
self.assertEqual(len(payload), 5)
self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], issued_at)
self.assertEqual(payload['exp'], exp_val)
self.assertEqual(payload['aud'], self.url)
@mock.patch('oauth2client.client._UTCNOW') @mock.patch('oauth2client.client._UTCNOW')
def test_refresh(self, utcnow): def test_refresh(self, utcnow):
utcnow.return_value = T1_DATE utcnow.return_value = T1_DATE