#!/usr/bin/python2.4
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Oauth2client tests

Unit tests for oauth2client.
"""

__author__ = 'jcgregorio@google.com (Joe Gregorio)'

import base64
import datetime
import httplib2
import unittest
import urlparse

# parse_qs moved into urlparse in Python 2.6; older interpreters only have
# the cgi version. All tests below use this name rather than urlparse.parse_qs.
try:
    from urlparse import parse_qs
except ImportError:
    from cgi import parse_qs

from apiclient.http import HttpMockSequence
from oauth2client.anyjson import simplejson
from oauth2client.client import AccessTokenCredentials
from oauth2client.client import AccessTokenCredentialsError
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import AssertionCredentials
from oauth2client.client import FlowExchangeError
from oauth2client.client import OAuth2Credentials
from oauth2client.client import OAuth2WebServerFlow
from oauth2client.client import OOB_CALLBACK_URN
from oauth2client.client import VerifyJwtTokenError
from oauth2client.client import _extract_id_token


class OAuth2CredentialsTests(unittest.TestCase):
    """Tests for OAuth2Credentials (access token plus refresh token)."""

    def setUp(self):
        access_token = "foo"
        client_id = "some_client_id"
        client_secret = "cOuDdkfjxxnv+"
        refresh_token = "1/0/a.df219fjls0"
        # NOTE(review): an expiry of "now" presumably makes the token count as
        # already expired; confirm against OAuth2Credentials.
        token_expiry = datetime.datetime.utcnow()
        token_uri = "https://www.google.com/accounts/o8/oauth2/token"
        user_agent = "refresh_checker/1.0"
        self.credentials = OAuth2Credentials(
            access_token, client_id, client_secret,
            refresh_token, token_expiry, token_uri,
            user_agent)

    def test_token_refresh_success(self):
        """A 401 leads to a refresh; the retried request uses the new token."""
        # Mock responses are served in order: the original request (401), the
        # token refresh (200 with a new token), then the retried request,
        # which echoes back the request headers it received.
        http = HttpMockSequence([
            ({'status': '401'}, ''),
            ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
            ({'status': '200'}, 'echo_request_headers'),
        ])
        http = self.credentials.authorize(http)
        resp, content = http.request("http://example.com")
        self.assertEqual('Bearer 1/3w', content['Authorization'])

    def test_token_refresh_failure(self):
        """A refresh rejected by the token endpoint raises AccessTokenRefreshError."""
        http = HttpMockSequence([
            ({'status': '401'}, ''),
            ({'status': '400'}, '{"error":"access_denied"}'),
        ])
        http = self.credentials.authorize(http)
        self.assertRaises(AccessTokenRefreshError,
                          http.request, "http://example.com")

    def test_non_401_error_response(self):
        """Non-401 errors are returned to the caller unchanged."""
        http = HttpMockSequence([
            ({'status': '400'}, ''),
        ])
        http = self.credentials.authorize(http)
        resp, content = http.request("http://example.com")
        self.assertEqual(400, resp.status)

    def test_to_from_json(self):
        """Credentials survive a to_json()/from_json() round trip."""
        json = self.credentials.to_json()
        instance = OAuth2Credentials.from_json(json)
        self.assertEqual(OAuth2Credentials, type(instance))
        # token_expiry is cleared on both sides before comparing.
        # NOTE(review): presumably because the datetime does not round-trip
        # exactly through JSON; confirm.
        instance.token_expiry = None
        self.credentials.token_expiry = None

        self.assertEqual(instance.__dict__, self.credentials.__dict__)


class AccessTokenCredentialsTests(unittest.TestCase):
    """Tests for AccessTokenCredentials (access token only, see setUp)."""

    def setUp(self):
        access_token = "foo"
        user_agent = "refresh_checker/1.0"
        self.credentials = AccessTokenCredentials(access_token, user_agent)

    def test_token_refresh_success(self):
        """A 401 must raise AccessTokenCredentialsError instead of refreshing.

        The method name is historical: this checks the failure path. These
        credentials hold only an access token, so an expired token cannot be
        recovered from.
        """
        http = HttpMockSequence([
            ({'status': '401'}, ''),
        ])
        http = self.credentials.authorize(http)
        # assertRaises (callable form, for old-Python compatibility) fails
        # when nothing is raised and lets any other exception propagate.
        self.assertRaises(AccessTokenCredentialsError,
                          http.request, "http://example.com")

    def test_non_401_error_response(self):
        """Non-401 errors are returned to the caller unchanged."""
        http = HttpMockSequence([
            ({'status': '400'}, ''),
        ])
        http = self.credentials.authorize(http)
        resp, content = http.request('http://example.com')
        self.assertEqual(400, resp.status)

    def test_auth_header_sent(self):
        """The access token is sent as a Bearer Authorization header."""
        http = HttpMockSequence([
            ({'status': '200'}, 'echo_request_headers'),
        ])
        http = self.credentials.authorize(http)
        resp, content = http.request('http://example.com')
        self.assertEqual('Bearer foo', content['Authorization'])


class TestAssertionCredentials(unittest.TestCase):
    """Tests for AssertionCredentials, using a fixed assertion string."""

    assertion_text = "This is the assertion"
    assertion_type = "http://www.google.com/assertionType"

    class AssertionCredentialsTestImpl(AssertionCredentials):
        """Concrete subclass that always produces the canned assertion."""

        def _generate_assertion(self):
            return TestAssertionCredentials.assertion_text

    def setUp(self):
        user_agent = "fun/2.0"
        self.credentials = self.AssertionCredentialsTestImpl(
            self.assertion_type, user_agent)

    def test_assertion_body(self):
        """The refresh request body carries the assertion and its type."""
        body = parse_qs(self.credentials._generate_refresh_request_body())
        self.assertEqual(self.assertion_text, body['assertion'][0])
        self.assertEqual(self.assertion_type, body['assertion_type'][0])

    def test_assertion_refresh(self):
        """The token obtained via the assertion is used on the next request."""
        http = HttpMockSequence([
            ({'status': '200'}, '{"access_token":"1/3w"}'),
            ({'status': '200'}, 'echo_request_headers'),
        ])
        http = self.credentials.authorize(http)
        resp, content = http.request("http://example.com")
        self.assertEqual('Bearer 1/3w', content['Authorization'])


class ExtractIdTokenText(unittest.TestCase):
    """Tests _extract_id_token()."""

    def test_extract_success(self):
        """A three-segment token yields the decoded JSON of its middle segment."""
        body = {'foo': 'bar'}
        # The padding is stripped, so _extract_id_token must cope with it
        # being absent.
        payload = base64.urlsafe_b64encode(simplejson.dumps(body)).strip('=')
        jwt = 'stuff.' + payload + '.signature'

        extracted = _extract_id_token(jwt)
        self.assertEqual(extracted, body)

    def test_extract_failure(self):
        """A token with only two segments is rejected."""
        body = {'foo': 'bar'}
        payload = base64.urlsafe_b64encode(simplejson.dumps(body)).strip('=')
        jwt = 'stuff.' + payload

        self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt)


class OAuth2WebServerFlowTest(unittest.TestCase):
    """Tests for building authorize URLs and exchanging codes for tokens."""

    def setUp(self):
        self.flow = OAuth2WebServerFlow(
            client_id='client_id+1',
            client_secret='secret+1',
            scope='foo',
            user_agent='unittest-sample/1.0',
        )

    def test_construct_authorize_url(self):
        """The authorize URL carries the expected query parameters."""
        authorize_url = self.flow.step1_get_authorize_url(OOB_CALLBACK_URN)

        parsed = urlparse.urlparse(authorize_url)
        q = parse_qs(parsed[4])  # index 4 is the query string
        self.assertEqual('client_id+1', q['client_id'][0])
        self.assertEqual('code', q['response_type'][0])
        self.assertEqual('foo', q['scope'][0])
        self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
        self.assertEqual('offline', q['access_type'][0])

    def test_override_flow_access_type(self):
        """Passing access_type overrides the default."""
        flow = OAuth2WebServerFlow(
            client_id='client_id+1',
            client_secret='secret+1',
            scope='foo',
            user_agent='unittest-sample/1.0',
            access_type='online'
        )
        authorize_url = flow.step1_get_authorize_url(OOB_CALLBACK_URN)

        parsed = urlparse.urlparse(authorize_url)
        q = parse_qs(parsed[4])
        self.assertEqual('client_id+1', q['client_id'][0])
        self.assertEqual('code', q['response_type'][0])
        self.assertEqual('foo', q['scope'][0])
        self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
        self.assertEqual('online', q['access_type'][0])

    def test_exchange_failure(self):
        """A non-200 exchange response raises FlowExchangeError."""
        http = HttpMockSequence([
            ({'status': '400'}, '{"error":"invalid_request"}'),
        ])

        self.assertRaises(FlowExchangeError, self.flow.step2_exchange,
                          'some random code', http)

    def test_exchange_success(self):
        """A successful exchange populates the token, expiry and refresh token."""
        http = HttpMockSequence([
            ({'status': '200'},
             """{ "access_token":"SlAV32hkKG",
       "expires_in":3600,
       "refresh_token":"8xLOxBtZp8" }"""),
        ])

        credentials = self.flow.step2_exchange('some random code', http)
        self.assertEqual('SlAV32hkKG', credentials.access_token)
        self.assertNotEqual(None, credentials.token_expiry)
        self.assertEqual('8xLOxBtZp8', credentials.refresh_token)

    def test_exchange_no_expires_in(self):
        """Without expires_in in the response, token_expiry stays None."""
        http = HttpMockSequence([
            ({'status': '200'}, """{ "access_token":"SlAV32hkKG",
       "refresh_token":"8xLOxBtZp8" }"""),
        ])

        credentials = self.flow.step2_exchange('some random code', http)
        self.assertEqual(None, credentials.token_expiry)

    def test_exchange_id_token_fail(self):
        """A malformed (two-segment) id_token makes the exchange fail."""
        http = HttpMockSequence([
            ({'status': '200'}, """{ "access_token":"SlAV32hkKG",
       "refresh_token":"8xLOxBtZp8",
       "id_token": "stuff.payload"}"""),
        ])

        self.assertRaises(VerifyJwtTokenError, self.flow.step2_exchange,
                          'some random code', http)

    # Previously also named test_exchange_id_token_fail, which shadowed the
    # failure test above so that it never ran.
    def test_exchange_id_token_success(self):
        """A well-formed id_token is decoded onto credentials.id_token."""
        body = {'foo': 'bar'}
        payload = base64.urlsafe_b64encode(simplejson.dumps(body)).strip('=')
        jwt = (base64.urlsafe_b64encode('stuff') + '.' + payload + '.' +
               base64.urlsafe_b64encode('signature'))

        http = HttpMockSequence([
            ({'status': '200'}, """{ "access_token":"SlAV32hkKG",
       "refresh_token":"8xLOxBtZp8",
       "id_token": "%s"}""" % jwt),
        ])

        credentials = self.flow.step2_exchange('some random code', http)
        self.assertEqual(credentials.id_token, body)


if __name__ == '__main__':
    unittest.main()