Merge pull request #272 from dhermes/add-from-bytes

Adding _from_bytes helpers as a foil for _to_bytes.
This commit is contained in:
Danny Hermes
2015-08-18 21:06:50 -07:00
9 changed files with 102 additions and 59 deletions

View File

@@ -68,6 +68,27 @@ def _to_bytes(value, encoding='ascii'):
raise ValueError('%r could not be converted to bytes' % (value,))
def _from_bytes(value):
"""Converts bytes to a string value, if necessary.
Args:
value: The string/bytes value to be converted.
Returns:
The original value converted to unicode (if bytes) or as passed in
if it started out as unicode.
Raises:
ValueError if the value could not be converted to unicode.
"""
result = (value.decode('utf-8')
if isinstance(value, six.binary_type) else value)
if isinstance(result, six.text_type):
return result
else:
raise ValueError('%r could not be converted to unicode' % (value,))
def _urlsafe_b64encode(raw_bytes):
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')

View File

@@ -40,6 +40,7 @@ from oauth2client import GOOGLE_DEVICE_URI
from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import GOOGLE_TOKEN_INFO_URI
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
from oauth2client import clientsecrets
@@ -269,32 +270,32 @@ class Credentials(object):
@classmethod
def new_from_json(cls, s):
"""Utility class method to instantiate a Credentials subclass from a JSON
representation produced by to_json().
"""Utility class method to instantiate a Credentials subclass from JSON.
Expects the JSON string to have been produced by to_json().
Args:
s: string, JSON from to_json().
s: string or bytes, JSON from to_json().
Returns:
An instance of the subclass of Credentials that was serialized with
to_json().
"""
if isinstance(s, bytes):
s = s.decode('utf-8')
data = json.loads(s)
json_string_as_unicode = _from_bytes(s)
data = json.loads(json_string_as_unicode)
# Find and call the right classmethod from_json() to restore the object.
module = data['_module']
module_name = data['_module']
try:
m = __import__(module)
module_obj = __import__(module_name)
except ImportError:
# In case there's an object from the old package structure, update it
module = module.replace('.googleapiclient', '')
m = __import__(module)
module_name = module_name.replace('.googleapiclient', '')
module_obj = __import__(module_name)
m = __import__(module, fromlist=module.split('.')[:-1])
kls = getattr(m, data['_class'])
module_obj = __import__(module_name, fromlist=module_name.split('.')[:-1])
kls = getattr(module_obj, data['_class'])
from_json = getattr(kls, 'from_json')
return from_json(s)
return from_json(json_string_as_unicode)
@classmethod
def from_json(cls, unused_data):
@@ -673,8 +674,7 @@ class OAuth2Credentials(Credentials):
Returns:
An instance of a Credentials subclass.
"""
if isinstance(s, bytes):
s = s.decode('utf-8')
s = _from_bytes(s)
data = json.loads(s)
if (data.get('token_expiry') and
not isinstance(data['token_expiry'], datetime.datetime)):
@@ -845,8 +845,7 @@ class OAuth2Credentials(Credentials):
logger.info('Refreshing access_token')
resp, content = http_request(
self.token_uri, method='POST', body=body, headers=headers)
if isinstance(content, bytes):
content = content.decode('utf-8')
content = _from_bytes(content)
if resp.status == 200:
d = json.loads(content)
self.token_response = d
@@ -905,16 +904,12 @@ class OAuth2Credentials(Credentials):
query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
resp, content = http_request(token_revoke_uri)
if isinstance(content, bytes):
content = content.decode('utf-8')
if resp.status == 200:
self.invalid = True
else:
error_msg = 'Invalid response %s.' % resp.status
try:
d = json.loads(content)
d = json.loads(_from_bytes(content))
if 'error' in d:
error_msg = d['error']
except (TypeError, ValueError):
@@ -949,10 +944,7 @@ class OAuth2Credentials(Credentials):
query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri, query_params)
resp, content = http_request(token_info_uri)
if six.PY3 and isinstance(content, bytes):
content = content.decode('utf-8')
content = _from_bytes(content)
if resp.status == 200:
d = json.loads(content)
self.scopes = set(util.string_to_scopes(d.get('scope', '')))
@@ -1018,9 +1010,7 @@ class AccessTokenCredentials(OAuth2Credentials):
@classmethod
def from_json(cls, s):
if isinstance(s, bytes):
s = s.decode('utf-8')
data = json.loads(s)
data = json.loads(_from_bytes(s))
retval = AccessTokenCredentials(
data['access_token'],
data['user_agent'])
@@ -1612,7 +1602,7 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
@classmethod
def from_json(cls, s):
data = json.loads(s)
data = json.loads(_from_bytes(s))
retval = SignedJwtAssertionCredentials(
data['service_account_name'],
base64.b64decode(data['private_key']),
@@ -1675,9 +1665,8 @@ def verify_id_token(id_token, audience, http=None,
http = _cached_http
resp, content = http.request(cert_uri)
if resp.status == 200:
certs = json.loads(content.decode('utf-8'))
certs = json.loads(_from_bytes(content))
return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
else:
raise VerifyJwtTokenError('Status code: %d' % resp.status)
@@ -1703,7 +1692,7 @@ def _extract_id_token(id_token):
raise VerifyJwtTokenError(
'Wrong number of segments in token: %s' % id_token)
return json.loads(_urlsafe_b64decode(segments[1]).decode('utf-8'))
return json.loads(_from_bytes(_urlsafe_b64decode(segments[1])))
def _parse_exchange_token_response(content):
@@ -1720,12 +1709,12 @@ def _parse_exchange_token_response(content):
i.e. {}. That basically indicates a failure.
"""
resp = {}
content = _from_bytes(content)
try:
resp = json.loads(content.decode('utf-8'))
resp = json.loads(content)
except Exception:
# different JSON libs raise different exceptions,
# so we just do a catch-all here
content = content.decode('utf-8')
resp = dict(urllib.parse.parse_qsl(content))
# some providers respond with 'expires', others with 'expires_in'
@@ -2000,6 +1989,7 @@ class OAuth2WebServerFlow(Flow):
resp, content = http.request(self.device_uri, method='POST', body=body,
headers=headers)
content = _from_bytes(content)
if resp.status == 200:
try:
flow_info = json.loads(content)

View File

@@ -19,6 +19,7 @@ import json
import logging
import time
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _json_encode
from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
@@ -124,7 +125,7 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
# Parse token.
json_body = _urlsafe_b64decode(segments[1])
try:
parsed = json.loads(json_body.decode('utf-8'))
parsed = json.loads(_from_bytes(json_body))
except:
raise AppIdentityError('Can\'t parse token: %s' % json_body)

View File

@@ -23,6 +23,7 @@ import json
import logging
from six.moves import urllib
from oauth2client._helpers import _from_bytes
from oauth2client import util
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import AssertionCredentials
@@ -63,7 +64,7 @@ class AppAssertionCredentials(AssertionCredentials):
@classmethod
def from_json(cls, json_data):
data = json.loads(json_data)
data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope'])
def _refresh(self, http_request):
@@ -81,6 +82,7 @@ class AppAssertionCredentials(AssertionCredentials):
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query)
response, content = http_request(uri)
content = _from_bytes(content)
if response.status == 200:
try:
d = json.loads(content)

View File

@@ -18,7 +18,6 @@ This credentials class is implemented on top of rsa library.
"""
import base64
import six
import time
from pyasn1.codec.ber import decoder

View File

@@ -15,6 +15,7 @@
import unittest
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _json_encode
from oauth2client._helpers import _parse_pem_key
from oauth2client._helpers import _to_bytes
@@ -66,6 +67,22 @@ class Test__to_bytes(unittest.TestCase):
self.assertRaises(ValueError, _to_bytes, value)
class Test__from_bytes(unittest.TestCase):
def test_with_unicode(self):
value = u'bytes-val'
self.assertEqual(_from_bytes(value), value)
def test_with_bytes(self):
value = b'string-val'
decoded_value = u'string-val'
self.assertEqual(_from_bytes(value), decoded_value)
def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _from_bytes, value)
class Test__urlsafe_b64encode(unittest.TestCase):
DEADBEEF_ENCODED = b'ZGVhZGJlZWY'

View File

@@ -29,8 +29,8 @@ import os
import tempfile
import time
import unittest
import urllib
import urlparse
from six.moves import urllib
import dev_appserver
dev_appserver.fix_sys_path()
@@ -554,7 +554,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual(self.decorator.credentials, None)
response = self.app.get('http://localhost/foo_path')
self.assertTrue(response.status.startswith('302'))
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0])
@@ -575,10 +575,10 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('http://localhost/foo_path', parts[0])
self.assertEqual(None, self.decorator.credentials)
if self.decorator._token_response_param:
response_query = urlparse.parse_qs(parts[1])
response_query = urllib.parse.parse_qs(parts[1])
response = response_query[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content,
json.loads(urllib.unquote(response)))
json.loads(urllib.parse.unquote(response)))
self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
self.assertEqual(self.decorator.credentials,
self.decorator._tls.credentials)
@@ -609,7 +609,7 @@ class DecoratorTests(unittest.TestCase):
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302'))
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
def test_storage_delete(self):
@@ -654,7 +654,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('200 OK', response.status)
self.assertEqual(False, self.decorator.has_credentials())
url = self.decorator.authorize_url()
q = urlparse.parse_qs(url.split('?', 1)[1])
q = urllib.parse.parse_qs(url.split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0])

View File

@@ -20,11 +20,14 @@ Unit tests for oauth2client.gce.
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
import json
from six.moves import urllib
import unittest
import httplib2
import mock
from oauth2client._helpers import _to_bytes
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import Credentials
from oauth2client.client import save_to_well_known_file
@@ -33,22 +36,32 @@ from oauth2client.gce import AppAssertionCredentials
class AssertionCredentialsTests(unittest.TestCase):
def test_good_refresh(self):
def _refresh_success_helper(self, bytes_response=False):
access_token = u'this-is-a-token'
return_val = json.dumps({u'accessToken': access_token})
if bytes_response:
return_val = _to_bytes(return_val)
http = mock.MagicMock()
http.request = mock.MagicMock(
return_value=(mock.Mock(status=200),
'{"accessToken": "this-is-a-token"}'))
return_value=(mock.Mock(status=200), return_val))
c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b'])
self.assertEquals(None, c.access_token)
c.refresh(http)
self.assertEquals('this-is-a-token', c.access_token)
scopes = ['http://example.com/a', 'http://example.com/b']
credentials = AppAssertionCredentials(scope=scopes)
self.assertEquals(None, credentials.access_token)
credentials.refresh(http)
self.assertEquals(access_token, credentials.access_token)
http.request.assert_called_once_with(
'http://metadata.google.internal/0.1/meta-data/service-accounts/'
'default/acquire'
'?scope=http%3A%2F%2Fexample.com%2Fa%20http%3A%2F%2Fexample.com%2Fb')
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
'service-accounts/default/acquire')
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
http.request.assert_called_once_with(request_uri)
def test_refresh_success(self):
self._refresh_success_helper(bytes_response=False)
def test_refresh_success_bytes(self):
self._refresh_success_helper(bytes_response=True)
def test_fail_refresh(self):
http = mock.MagicMock()

View File

@@ -801,8 +801,8 @@ class BasicCredentialsTests(unittest.TestCase):
http = credentials.authorize(http)
http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
for k, v in six.iteritems(http.headers):
self.assertEqual(six.binary_type, type(k))
self.assertEqual(six.binary_type, type(v))
self.assertTrue(isinstance(k, six.binary_type))
self.assertTrue(isinstance(v, six.binary_type))
# Test again with unicode strings that can't simply be converted to ASCII.
try: