Merge pull request #272 from dhermes/add-from-bytes

Adding _from_bytes helpers as a foil for _to_bytes.
This commit is contained in:
Danny Hermes
2015-08-18 21:06:50 -07:00
9 changed files with 102 additions and 59 deletions

View File

@@ -68,6 +68,27 @@ def _to_bytes(value, encoding='ascii'):
raise ValueError('%r could not be converted to bytes' % (value,)) raise ValueError('%r could not be converted to bytes' % (value,))
def _from_bytes(value):
"""Converts bytes to a string value, if necessary.
Args:
value: The string/bytes value to be converted.
Returns:
The original value converted to unicode (if bytes) or as passed in
if it started out as unicode.
Raises:
ValueError if the value could not be converted to unicode.
"""
result = (value.decode('utf-8')
if isinstance(value, six.binary_type) else value)
if isinstance(result, six.text_type):
return result
else:
raise ValueError('%r could not be converted to unicode' % (value,))
def _urlsafe_b64encode(raw_bytes): def _urlsafe_b64encode(raw_bytes):
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8') raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=') return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')

View File

@@ -40,6 +40,7 @@ from oauth2client import GOOGLE_DEVICE_URI
from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import GOOGLE_TOKEN_INFO_URI from oauth2client import GOOGLE_TOKEN_INFO_URI
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _to_bytes from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode from oauth2client._helpers import _urlsafe_b64decode
from oauth2client import clientsecrets from oauth2client import clientsecrets
@@ -269,32 +270,32 @@ class Credentials(object):
@classmethod @classmethod
def new_from_json(cls, s): def new_from_json(cls, s):
"""Utility class method to instantiate a Credentials subclass from a JSON """Utility class method to instantiate a Credentials subclass from JSON.
representation produced by to_json().
Expects the JSON string to have been produced by to_json().
Args: Args:
s: string, JSON from to_json(). s: string or bytes, JSON from to_json().
Returns: Returns:
An instance of the subclass of Credentials that was serialized with An instance of the subclass of Credentials that was serialized with
to_json(). to_json().
""" """
if isinstance(s, bytes): json_string_as_unicode = _from_bytes(s)
s = s.decode('utf-8') data = json.loads(json_string_as_unicode)
data = json.loads(s)
# Find and call the right classmethod from_json() to restore the object. # Find and call the right classmethod from_json() to restore the object.
module = data['_module'] module_name = data['_module']
try: try:
m = __import__(module) module_obj = __import__(module_name)
except ImportError: except ImportError:
# In case there's an object from the old package structure, update it # In case there's an object from the old package structure, update it
module = module.replace('.googleapiclient', '') module_name = module_name.replace('.googleapiclient', '')
m = __import__(module) module_obj = __import__(module_name)
m = __import__(module, fromlist=module.split('.')[:-1]) module_obj = __import__(module_name, fromlist=module_name.split('.')[:-1])
kls = getattr(m, data['_class']) kls = getattr(module_obj, data['_class'])
from_json = getattr(kls, 'from_json') from_json = getattr(kls, 'from_json')
return from_json(s) return from_json(json_string_as_unicode)
@classmethod @classmethod
def from_json(cls, unused_data): def from_json(cls, unused_data):
@@ -673,8 +674,7 @@ class OAuth2Credentials(Credentials):
Returns: Returns:
An instance of a Credentials subclass. An instance of a Credentials subclass.
""" """
if isinstance(s, bytes): s = _from_bytes(s)
s = s.decode('utf-8')
data = json.loads(s) data = json.loads(s)
if (data.get('token_expiry') and if (data.get('token_expiry') and
not isinstance(data['token_expiry'], datetime.datetime)): not isinstance(data['token_expiry'], datetime.datetime)):
@@ -845,8 +845,7 @@ class OAuth2Credentials(Credentials):
logger.info('Refreshing access_token') logger.info('Refreshing access_token')
resp, content = http_request( resp, content = http_request(
self.token_uri, method='POST', body=body, headers=headers) self.token_uri, method='POST', body=body, headers=headers)
if isinstance(content, bytes): content = _from_bytes(content)
content = content.decode('utf-8')
if resp.status == 200: if resp.status == 200:
d = json.loads(content) d = json.loads(content)
self.token_response = d self.token_response = d
@@ -905,16 +904,12 @@ class OAuth2Credentials(Credentials):
query_params = {'token': token} query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params) token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
resp, content = http_request(token_revoke_uri) resp, content = http_request(token_revoke_uri)
if isinstance(content, bytes):
content = content.decode('utf-8')
if resp.status == 200: if resp.status == 200:
self.invalid = True self.invalid = True
else: else:
error_msg = 'Invalid response %s.' % resp.status error_msg = 'Invalid response %s.' % resp.status
try: try:
d = json.loads(content) d = json.loads(_from_bytes(content))
if 'error' in d: if 'error' in d:
error_msg = d['error'] error_msg = d['error']
except (TypeError, ValueError): except (TypeError, ValueError):
@@ -949,10 +944,7 @@ class OAuth2Credentials(Credentials):
query_params = {'access_token': token, 'fields': 'scope'} query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri, query_params) token_info_uri = _update_query_params(self.token_info_uri, query_params)
resp, content = http_request(token_info_uri) resp, content = http_request(token_info_uri)
content = _from_bytes(content)
if six.PY3 and isinstance(content, bytes):
content = content.decode('utf-8')
if resp.status == 200: if resp.status == 200:
d = json.loads(content) d = json.loads(content)
self.scopes = set(util.string_to_scopes(d.get('scope', ''))) self.scopes = set(util.string_to_scopes(d.get('scope', '')))
@@ -1018,9 +1010,7 @@ class AccessTokenCredentials(OAuth2Credentials):
@classmethod @classmethod
def from_json(cls, s): def from_json(cls, s):
if isinstance(s, bytes): data = json.loads(_from_bytes(s))
s = s.decode('utf-8')
data = json.loads(s)
retval = AccessTokenCredentials( retval = AccessTokenCredentials(
data['access_token'], data['access_token'],
data['user_agent']) data['user_agent'])
@@ -1612,7 +1602,7 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
@classmethod @classmethod
def from_json(cls, s): def from_json(cls, s):
data = json.loads(s) data = json.loads(_from_bytes(s))
retval = SignedJwtAssertionCredentials( retval = SignedJwtAssertionCredentials(
data['service_account_name'], data['service_account_name'],
base64.b64decode(data['private_key']), base64.b64decode(data['private_key']),
@@ -1675,9 +1665,8 @@ def verify_id_token(id_token, audience, http=None,
http = _cached_http http = _cached_http
resp, content = http.request(cert_uri) resp, content = http.request(cert_uri)
if resp.status == 200: if resp.status == 200:
certs = json.loads(content.decode('utf-8')) certs = json.loads(_from_bytes(content))
return crypt.verify_signed_jwt_with_certs(id_token, certs, audience) return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
else: else:
raise VerifyJwtTokenError('Status code: %d' % resp.status) raise VerifyJwtTokenError('Status code: %d' % resp.status)
@@ -1703,7 +1692,7 @@ def _extract_id_token(id_token):
raise VerifyJwtTokenError( raise VerifyJwtTokenError(
'Wrong number of segments in token: %s' % id_token) 'Wrong number of segments in token: %s' % id_token)
return json.loads(_urlsafe_b64decode(segments[1]).decode('utf-8')) return json.loads(_from_bytes(_urlsafe_b64decode(segments[1])))
def _parse_exchange_token_response(content): def _parse_exchange_token_response(content):
@@ -1720,12 +1709,12 @@ def _parse_exchange_token_response(content):
i.e. {}. That basically indicates a failure. i.e. {}. That basically indicates a failure.
""" """
resp = {} resp = {}
content = _from_bytes(content)
try: try:
resp = json.loads(content.decode('utf-8')) resp = json.loads(content)
except Exception: except Exception:
# different JSON libs raise different exceptions, # different JSON libs raise different exceptions,
# so we just do a catch-all here # so we just do a catch-all here
content = content.decode('utf-8')
resp = dict(urllib.parse.parse_qsl(content)) resp = dict(urllib.parse.parse_qsl(content))
# some providers respond with 'expires', others with 'expires_in' # some providers respond with 'expires', others with 'expires_in'
@@ -2000,6 +1989,7 @@ class OAuth2WebServerFlow(Flow):
resp, content = http.request(self.device_uri, method='POST', body=body, resp, content = http.request(self.device_uri, method='POST', body=body,
headers=headers) headers=headers)
content = _from_bytes(content)
if resp.status == 200: if resp.status == 200:
try: try:
flow_info = json.loads(content) flow_info = json.loads(content)

View File

@@ -19,6 +19,7 @@ import json
import logging import logging
import time import time
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _json_encode from oauth2client._helpers import _json_encode
from oauth2client._helpers import _to_bytes from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode from oauth2client._helpers import _urlsafe_b64decode
@@ -124,7 +125,7 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
# Parse token. # Parse token.
json_body = _urlsafe_b64decode(segments[1]) json_body = _urlsafe_b64decode(segments[1])
try: try:
parsed = json.loads(json_body.decode('utf-8')) parsed = json.loads(_from_bytes(json_body))
except: except:
raise AppIdentityError('Can\'t parse token: %s' % json_body) raise AppIdentityError('Can\'t parse token: %s' % json_body)

View File

@@ -23,6 +23,7 @@ import json
import logging import logging
from six.moves import urllib from six.moves import urllib
from oauth2client._helpers import _from_bytes
from oauth2client import util from oauth2client import util
from oauth2client.client import AccessTokenRefreshError from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import AssertionCredentials from oauth2client.client import AssertionCredentials
@@ -63,7 +64,7 @@ class AppAssertionCredentials(AssertionCredentials):
@classmethod @classmethod
def from_json(cls, json_data): def from_json(cls, json_data):
data = json.loads(json_data) data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope']) return AppAssertionCredentials(data['scope'])
def _refresh(self, http_request): def _refresh(self, http_request):
@@ -81,6 +82,7 @@ class AppAssertionCredentials(AssertionCredentials):
query = '?scope=%s' % urllib.parse.quote(self.scope, '') query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query) uri = META.replace('{?scope}', query)
response, content = http_request(uri) response, content = http_request(uri)
content = _from_bytes(content)
if response.status == 200: if response.status == 200:
try: try:
d = json.loads(content) d = json.loads(content)

View File

@@ -18,7 +18,6 @@ This credentials class is implemented on top of rsa library.
""" """
import base64 import base64
import six
import time import time
from pyasn1.codec.ber import decoder from pyasn1.codec.ber import decoder

View File

@@ -15,6 +15,7 @@
import unittest import unittest
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _json_encode from oauth2client._helpers import _json_encode
from oauth2client._helpers import _parse_pem_key from oauth2client._helpers import _parse_pem_key
from oauth2client._helpers import _to_bytes from oauth2client._helpers import _to_bytes
@@ -66,6 +67,22 @@ class Test__to_bytes(unittest.TestCase):
self.assertRaises(ValueError, _to_bytes, value) self.assertRaises(ValueError, _to_bytes, value)
class Test__from_bytes(unittest.TestCase):
def test_with_unicode(self):
value = u'bytes-val'
self.assertEqual(_from_bytes(value), value)
def test_with_bytes(self):
value = b'string-val'
decoded_value = u'string-val'
self.assertEqual(_from_bytes(value), decoded_value)
def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _from_bytes, value)
class Test__urlsafe_b64encode(unittest.TestCase): class Test__urlsafe_b64encode(unittest.TestCase):
DEADBEEF_ENCODED = b'ZGVhZGJlZWY' DEADBEEF_ENCODED = b'ZGVhZGJlZWY'

View File

@@ -29,8 +29,8 @@ import os
import tempfile import tempfile
import time import time
import unittest import unittest
import urllib
import urlparse from six.moves import urllib
import dev_appserver import dev_appserver
dev_appserver.fix_sys_path() dev_appserver.fix_sys_path()
@@ -554,7 +554,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual(self.decorator.credentials, None) self.assertEqual(self.decorator.credentials, None)
response = self.app.get('http://localhost/foo_path') response = self.app.get('http://localhost/foo_path')
self.assertTrue(response.status.startswith('302')) self.assertTrue(response.status.startswith('302'))
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1]) q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0]) self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0]) self.assertEqual('foo_scope bar_scope', q['scope'][0])
@@ -575,10 +575,10 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('http://localhost/foo_path', parts[0]) self.assertEqual('http://localhost/foo_path', parts[0])
self.assertEqual(None, self.decorator.credentials) self.assertEqual(None, self.decorator.credentials)
if self.decorator._token_response_param: if self.decorator._token_response_param:
response_query = urlparse.parse_qs(parts[1]) response_query = urllib.parse.parse_qs(parts[1])
response = response_query[self.decorator._token_response_param][0] response = response_query[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content, self.assertEqual(Http2Mock.content,
json.loads(urllib.unquote(response))) json.loads(urllib.parse.unquote(response)))
self.assertEqual(self.decorator.flow, self.decorator._tls.flow) self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
self.assertEqual(self.decorator.credentials, self.assertEqual(self.decorator.credentials,
self.decorator._tls.credentials) self.decorator._tls.credentials)
@@ -609,7 +609,7 @@ class DecoratorTests(unittest.TestCase):
# Invalid Credentials should start the OAuth dance again. # Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path') response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302')) self.assertTrue(response.status.startswith('302'))
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1]) q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
def test_storage_delete(self): def test_storage_delete(self):
@@ -654,7 +654,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('200 OK', response.status) self.assertEqual('200 OK', response.status)
self.assertEqual(False, self.decorator.has_credentials()) self.assertEqual(False, self.decorator.has_credentials())
url = self.decorator.authorize_url() url = self.decorator.authorize_url()
q = urlparse.parse_qs(url.split('?', 1)[1]) q = urllib.parse.parse_qs(url.split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0]) self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0]) self.assertEqual('foo_scope bar_scope', q['scope'][0])

View File

@@ -20,11 +20,14 @@ Unit tests for oauth2client.gce.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import json
from six.moves import urllib
import unittest import unittest
import httplib2 import httplib2
import mock import mock
from oauth2client._helpers import _to_bytes
from oauth2client.client import AccessTokenRefreshError from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import Credentials from oauth2client.client import Credentials
from oauth2client.client import save_to_well_known_file from oauth2client.client import save_to_well_known_file
@@ -33,22 +36,32 @@ from oauth2client.gce import AppAssertionCredentials
class AssertionCredentialsTests(unittest.TestCase): class AssertionCredentialsTests(unittest.TestCase):
def test_good_refresh(self): def _refresh_success_helper(self, bytes_response=False):
access_token = u'this-is-a-token'
return_val = json.dumps({u'accessToken': access_token})
if bytes_response:
return_val = _to_bytes(return_val)
http = mock.MagicMock() http = mock.MagicMock()
http.request = mock.MagicMock( http.request = mock.MagicMock(
return_value=(mock.Mock(status=200), return_value=(mock.Mock(status=200), return_val))
'{"accessToken": "this-is-a-token"}'))
c = AppAssertionCredentials(scope=['http://example.com/a', scopes = ['http://example.com/a', 'http://example.com/b']
'http://example.com/b']) credentials = AppAssertionCredentials(scope=scopes)
self.assertEquals(None, c.access_token) self.assertEquals(None, credentials.access_token)
c.refresh(http) credentials.refresh(http)
self.assertEquals('this-is-a-token', c.access_token) self.assertEquals(access_token, credentials.access_token)
http.request.assert_called_once_with( base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
'http://metadata.google.internal/0.1/meta-data/service-accounts/' 'service-accounts/default/acquire')
'default/acquire' escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
'?scope=http%3A%2F%2Fexample.com%2Fa%20http%3A%2F%2Fexample.com%2Fb') request_uri = base_metadata_uri + '?scope=' + escaped_scopes
http.request.assert_called_once_with(request_uri)
def test_refresh_success(self):
self._refresh_success_helper(bytes_response=False)
def test_refresh_success_bytes(self):
self._refresh_success_helper(bytes_response=True)
def test_fail_refresh(self): def test_fail_refresh(self):
http = mock.MagicMock() http = mock.MagicMock()

View File

@@ -801,8 +801,8 @@ class BasicCredentialsTests(unittest.TestCase):
http = credentials.authorize(http) http = credentials.authorize(http)
http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'}) http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
for k, v in six.iteritems(http.headers): for k, v in six.iteritems(http.headers):
self.assertEqual(six.binary_type, type(k)) self.assertTrue(isinstance(k, six.binary_type))
self.assertEqual(six.binary_type, type(v)) self.assertTrue(isinstance(v, six.binary_type))
# Test again with unicode strings that can't simply be converted to ASCII. # Test again with unicode strings that can't simply be converted to ASCII.
try: try: