Re-enable Python3 tests and fixing Py2/Py3 issues in tests.

Partially fixes #85.
This commit is contained in:
Danny Hermes
2014-12-06 03:02:47 -08:00
parent b20d3d6f66
commit 3c647bd40c
5 changed files with 46 additions and 36 deletions

View File

@@ -5,6 +5,8 @@ env:
- TOX_ENV=py26openssl14 - TOX_ENV=py26openssl14
- TOX_ENV=py27openssl13 - TOX_ENV=py27openssl13
- TOX_ENV=py27openssl14 - TOX_ENV=py27openssl14
- TOX_ENV=py33openssl14
- TOX_ENV=py34openssl14
- TOX_ENV=pypyopenssl13 - TOX_ENV=pypyopenssl13
- TOX_ENV=pypyopenssl14 - TOX_ENV=pypyopenssl14
install: install:

View File

@@ -409,7 +409,7 @@ def clean_headers(headers):
clean = {} clean = {}
try: try:
for k, v in six.iteritems(headers): for k, v in six.iteritems(headers):
clean[str(k)] = str(v) clean[k.encode('ascii')] = v.encode('ascii')
except UnicodeEncodeError: except UnicodeEncodeError:
raise NonAsciiHeaderError(k + ': ' + v) raise NonAsciiHeaderError(k + ': ' + v)
return clean return clean
@@ -1252,16 +1252,14 @@ def _get_well_known_file():
return default_config_path return default_config_path
def _get_application_default_credential_from_file( def _get_application_default_credential_from_file(filename):
application_default_credential_filename):
"""Build the Application Default Credentials from file.""" """Build the Application Default Credentials from file."""
from oauth2client import service_account from oauth2client import service_account
# read the credentials from the file # read the credentials from the file
with open(application_default_credential_filename) as ( with open(filename) as file_obj:
application_default_credential): client_credentials = json.load(file_obj)
client_credentials = json.load(application_default_credential)
credentials_type = client_credentials.get('type') credentials_type = client_credentials.get('type')
if credentials_type == AUTHORIZED_USER: if credentials_type == AUTHORIZED_USER:
@@ -1545,12 +1543,15 @@ def _extract_id_token(id_token):
Does the extraction w/o checking the signature. Does the extraction w/o checking the signature.
Args: Args:
id_token: string, OAuth 2.0 id_token. id_token: string or bytestring, OAuth 2.0 id_token.
Returns: Returns:
object, The deserialized JSON payload. object, The deserialized JSON payload.
""" """
segments = id_token.split('.') if type(id_token) == bytes:
segments = id_token.split(b'.')
else:
segments = id_token.split(u'.')
if len(segments) != 3: if len(segments) != 3:
raise VerifyJwtTokenError( raise VerifyJwtTokenError(
@@ -1578,6 +1579,7 @@ def _parse_exchange_token_response(content):
except Exception: except Exception:
# different JSON libs raise different exceptions, # different JSON libs raise different exceptions,
# so we just do a catch-all here # so we just do a catch-all here
content = content.decode('utf-8')
resp = dict(urllib.parse.parse_qsl(content)) resp = dict(urllib.parse.parse_qsl(content))
# some providers respond with 'expires', others with 'expires_in' # some providers respond with 'expires', others with 'expires_in'

View File

@@ -19,6 +19,7 @@ This credentials class is implemented on top of rsa library.
import base64 import base64
import json import json
import six
import time import time
from pyasn1.codec.ber import decoder from pyasn1.codec.ber import decoder
@@ -131,6 +132,8 @@ def _urlsafe_b64encode(data):
def _get_private_key(private_key_pkcs8_text): def _get_private_key(private_key_pkcs8_text):
"""Get an RSA private key object from a pkcs8 representation.""" """Get an RSA private key object from a pkcs8 representation."""
if not isinstance(private_key_pkcs8_text, six.binary_type):
private_key_pkcs8_text = private_key_pkcs8_text.encode('ascii')
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
return rsa.PrivateKey.load_pkcs1( return rsa.PrivateKey.load_pkcs1(

View File

@@ -217,7 +217,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
resp, content = http.request('http://example.org') resp, content = http.request('http://example.org')
self.assertEqual('Bearer 1/3w', content['Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_credentials_to_from_json(self): def test_credentials_to_from_json(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
@@ -254,7 +254,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
content = self._credentials_refresh(credentials) content = self._credentials_refresh(credentials)
self.assertEqual('Bearer 3/3w', content['Authorization']) self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
def test_credentials_refresh_with_storage(self): def test_credentials_refresh_with_storage(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
@@ -272,7 +272,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
content = self._credentials_refresh(credentials) content = self._credentials_refresh(credentials)
self.assertEqual('Bearer 3/3w', content['Authorization']) self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
os.unlink(filename) os.unlink(filename)

View File

@@ -545,7 +545,7 @@ class BasicCredentialsTests(unittest.TestCase):
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = http.request('http://example.com')
self.assertEqual('Bearer 1/3w', content['Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response, self.credentials.token_response) self.assertEqual(token_response, self.credentials.token_response)
@@ -615,10 +615,10 @@ class BasicCredentialsTests(unittest.TestCase):
http = credentials.authorize(http) http = credentials.authorize(http)
http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'}) http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
for k, v in six.iteritems(http.headers): for k, v in six.iteritems(http.headers):
self.assertEqual(str, type(k)) self.assertEqual(six.binary_type, type(k))
self.assertEqual(str, type(v)) self.assertEqual(six.binary_type, type(v))
# Test again with unicode strings that can't simple be converted to ASCII. # Test again with unicode strings that can't simply be converted to ASCII.
try: try:
http.request( http.request(
u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'}) u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'})
@@ -707,7 +707,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = http.request('http://example.com')
self.assertEqual('Bearer foo', content['Authorization']) self.assertEqual(b'Bearer foo', content[b'Authorization'])
class TestAssertionCredentials(unittest.TestCase): class TestAssertionCredentials(unittest.TestCase):
@@ -738,7 +738,7 @@ class TestAssertionCredentials(unittest.TestCase):
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = http.request('http://example.com')
self.assertEqual('Bearer 1/3w', content['Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_token_revoke_success(self): def test_token_revoke_success(self):
_token_revoke_test_helper( _token_revoke_test_helper(
@@ -769,16 +769,18 @@ class ExtractIdTokenTest(unittest.TestCase):
def test_extract_success(self): def test_extract_success(self):
body = {'foo': 'bar'} body = {'foo': 'bar'}
payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=') body_json = json.dumps(body).encode('ascii')
jwt = 'stuff.' + payload + '.signature' payload = base64.urlsafe_b64encode(body_json).strip(b'=')
jwt = b'stuff.' + payload + b'.signature'
extracted = _extract_id_token(jwt) extracted = _extract_id_token(jwt)
self.assertEqual(extracted, body) self.assertEqual(extracted, body)
def test_extract_failure(self): def test_extract_failure(self):
body = {'foo': 'bar'} body = {'foo': 'bar'}
payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=') body_json = json.dumps(body).encode('ascii')
jwt = 'stuff.' + payload payload = base64.urlsafe_b64encode(body_json).strip(b'=')
jwt = b'stuff.' + payload
self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt) self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt)
@@ -840,14 +842,14 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
def test_urlencoded_exchange_failure(self): def test_urlencoded_exchange_failure(self):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '400'}, 'error=invalid_request'), ({'status': '400'}, b'error=invalid_request'),
]) ])
try: try:
credentials = self.flow.step2_exchange('some random code', http=http) credentials = self.flow.step2_exchange('some random code', http=http)
self.fail('should raise exception if exchange doesn\'t get 200') self.fail('should raise exception if exchange doesn\'t get 200')
except FlowExchangeError as e: except FlowExchangeError as e:
self.assertEquals('invalid_request', str(e)) self.assertEqual('invalid_request', str(e))
def test_exchange_failure_with_json_error(self): def test_exchange_failure_with_json_error(self):
# Some providers have 'error' attribute as a JSON object # Some providers have 'error' attribute as a JSON object
@@ -894,12 +896,12 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
code = 'some random code' code = 'some random code'
not_a_dict = FakeDict({'code': code}) not_a_dict = FakeDict({'code': code})
http = HttpMockSequence([ payload = (b'{'
({'status': '200'}, b' "access_token":"SlAV32hkKG",'
"""{ "access_token":"SlAV32hkKG", b' "expires_in":3600,'
"expires_in":3600, b' "refresh_token":"8xLOxBtZp8"'
"refresh_token":"8xLOxBtZp8" }"""), b'}')
]) http = HttpMockSequence([({'status': '200'}, payload),])
credentials = self.flow.step2_exchange(not_a_dict, http=http) credentials = self.flow.step2_exchange(not_a_dict, http=http)
self.assertEqual('SlAV32hkKG', credentials.access_token) self.assertEqual('SlAV32hkKG', credentials.access_token)
@@ -972,9 +974,10 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
def test_exchange_id_token(self): def test_exchange_id_token(self):
body = {'foo': 'bar'} body = {'foo': 'bar'}
payload = base64.urlsafe_b64encode(json.dumps(body)).strip('=') body_json = json.dumps(body).encode('ascii')
jwt = (base64.urlsafe_b64encode('stuff')+ '.' + payload + '.' + payload = base64.urlsafe_b64encode(body_json).strip(b'=')
base64.urlsafe_b64encode('signature')) jwt = (base64.urlsafe_b64encode(b'stuff') + b'.' + payload + b'.' +
base64.urlsafe_b64encode(b'signature'))
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, ("""{ "access_token":"SlAV32hkKG", ({'status': '200'}, ("""{ "access_token":"SlAV32hkKG",
@@ -994,7 +997,7 @@ class FlowFromCachedClientsecrets(unittest.TestCase):
flow = flow_from_clientsecrets( flow = flow_from_clientsecrets(
'some_secrets', '', redirect_uri='oob', cache=cache_mock) 'some_secrets', '', redirect_uri='oob', cache=cache_mock)
self.assertEquals('foo_client_secret', flow.client_secret) self.assertEqual('foo_client_secret', flow.client_secret)
class CredentialsFromCodeTests(unittest.TestCase): class CredentialsFromCodeTests(unittest.TestCase):
@@ -1014,7 +1017,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
credentials = credentials_from_code(self.client_id, self.client_secret, credentials = credentials_from_code(self.client_id, self.client_secret,
self.scope, self.code, redirect_uri=self.redirect_uri, self.scope, self.code, redirect_uri=self.redirect_uri,
http=http) http=http)
self.assertEquals(credentials.access_token, token) self.assertEqual(credentials.access_token, token)
self.assertNotEqual(None, credentials.token_expiry) self.assertNotEqual(None, credentials.token_expiry)
def test_exchange_code_for_token_fail(self): def test_exchange_code_for_token_fail(self):
@@ -1039,7 +1042,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
credentials = credentials_from_clientsecrets_and_code( credentials = credentials_from_clientsecrets_and_code(
datafile('client_secrets.json'), self.scope, datafile('client_secrets.json'), self.scope,
self.code, http=http) self.code, http=http)
self.assertEquals(credentials.access_token, 'asdfghjkl') self.assertEqual(credentials.access_token, 'asdfghjkl')
self.assertNotEqual(None, credentials.token_expiry) self.assertNotEqual(None, credentials.token_expiry)
def test_exchange_code_and_cached_file_for_token(self): def test_exchange_code_and_cached_file_for_token(self):
@@ -1052,7 +1055,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
credentials = credentials_from_clientsecrets_and_code( credentials = credentials_from_clientsecrets_and_code(
'some_secrets', self.scope, 'some_secrets', self.scope,
self.code, http=http, cache=cache_mock) self.code, http=http, cache=cache_mock)
self.assertEquals(credentials.access_token, 'asdfghjkl') self.assertEqual(credentials.access_token, 'asdfghjkl')
def test_exchange_code_and_file_for_token_fail(self): def test_exchange_code_and_file_for_token_fail(self):
http = HttpMockSequence([ http = HttpMockSequence([