Merge pull request #272 from dhermes/add-from-bytes
Adding _from_bytes helpers as a foil for _to_bytes.
This commit is contained in:
@@ -68,6 +68,27 @@ def _to_bytes(value, encoding='ascii'):
|
||||
raise ValueError('%r could not be converted to bytes' % (value,))
|
||||
|
||||
|
||||
def _from_bytes(value):
|
||||
"""Converts bytes to a string value, if necessary.
|
||||
|
||||
Args:
|
||||
value: The string/bytes value to be converted.
|
||||
|
||||
Returns:
|
||||
The original value converted to unicode (if bytes) or as passed in
|
||||
if it started out as unicode.
|
||||
|
||||
Raises:
|
||||
ValueError if the value could not be converted to unicode.
|
||||
"""
|
||||
result = (value.decode('utf-8')
|
||||
if isinstance(value, six.binary_type) else value)
|
||||
if isinstance(result, six.text_type):
|
||||
return result
|
||||
else:
|
||||
raise ValueError('%r could not be converted to unicode' % (value,))
|
||||
|
||||
|
||||
def _urlsafe_b64encode(raw_bytes):
|
||||
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
|
||||
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
|
||||
|
||||
@@ -40,6 +40,7 @@ from oauth2client import GOOGLE_DEVICE_URI
|
||||
from oauth2client import GOOGLE_REVOKE_URI
|
||||
from oauth2client import GOOGLE_TOKEN_URI
|
||||
from oauth2client import GOOGLE_TOKEN_INFO_URI
|
||||
from oauth2client._helpers import _from_bytes
|
||||
from oauth2client._helpers import _to_bytes
|
||||
from oauth2client._helpers import _urlsafe_b64decode
|
||||
from oauth2client import clientsecrets
|
||||
@@ -269,32 +270,32 @@ class Credentials(object):
|
||||
|
||||
@classmethod
|
||||
def new_from_json(cls, s):
|
||||
"""Utility class method to instantiate a Credentials subclass from a JSON
|
||||
representation produced by to_json().
|
||||
"""Utility class method to instantiate a Credentials subclass from JSON.
|
||||
|
||||
Expects the JSON string to have been produced by to_json().
|
||||
|
||||
Args:
|
||||
s: string, JSON from to_json().
|
||||
s: string or bytes, JSON from to_json().
|
||||
|
||||
Returns:
|
||||
An instance of the subclass of Credentials that was serialized with
|
||||
to_json().
|
||||
"""
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode('utf-8')
|
||||
data = json.loads(s)
|
||||
json_string_as_unicode = _from_bytes(s)
|
||||
data = json.loads(json_string_as_unicode)
|
||||
# Find and call the right classmethod from_json() to restore the object.
|
||||
module = data['_module']
|
||||
module_name = data['_module']
|
||||
try:
|
||||
m = __import__(module)
|
||||
module_obj = __import__(module_name)
|
||||
except ImportError:
|
||||
# In case there's an object from the old package structure, update it
|
||||
module = module.replace('.googleapiclient', '')
|
||||
m = __import__(module)
|
||||
module_name = module_name.replace('.googleapiclient', '')
|
||||
module_obj = __import__(module_name)
|
||||
|
||||
m = __import__(module, fromlist=module.split('.')[:-1])
|
||||
kls = getattr(m, data['_class'])
|
||||
module_obj = __import__(module_name, fromlist=module_name.split('.')[:-1])
|
||||
kls = getattr(module_obj, data['_class'])
|
||||
from_json = getattr(kls, 'from_json')
|
||||
return from_json(s)
|
||||
return from_json(json_string_as_unicode)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, unused_data):
|
||||
@@ -673,8 +674,7 @@ class OAuth2Credentials(Credentials):
|
||||
Returns:
|
||||
An instance of a Credentials subclass.
|
||||
"""
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode('utf-8')
|
||||
s = _from_bytes(s)
|
||||
data = json.loads(s)
|
||||
if (data.get('token_expiry') and
|
||||
not isinstance(data['token_expiry'], datetime.datetime)):
|
||||
@@ -845,8 +845,7 @@ class OAuth2Credentials(Credentials):
|
||||
logger.info('Refreshing access_token')
|
||||
resp, content = http_request(
|
||||
self.token_uri, method='POST', body=body, headers=headers)
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode('utf-8')
|
||||
content = _from_bytes(content)
|
||||
if resp.status == 200:
|
||||
d = json.loads(content)
|
||||
self.token_response = d
|
||||
@@ -905,16 +904,12 @@ class OAuth2Credentials(Credentials):
|
||||
query_params = {'token': token}
|
||||
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
|
||||
resp, content = http_request(token_revoke_uri)
|
||||
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode('utf-8')
|
||||
|
||||
if resp.status == 200:
|
||||
self.invalid = True
|
||||
else:
|
||||
error_msg = 'Invalid response %s.' % resp.status
|
||||
try:
|
||||
d = json.loads(content)
|
||||
d = json.loads(_from_bytes(content))
|
||||
if 'error' in d:
|
||||
error_msg = d['error']
|
||||
except (TypeError, ValueError):
|
||||
@@ -949,10 +944,7 @@ class OAuth2Credentials(Credentials):
|
||||
query_params = {'access_token': token, 'fields': 'scope'}
|
||||
token_info_uri = _update_query_params(self.token_info_uri, query_params)
|
||||
resp, content = http_request(token_info_uri)
|
||||
|
||||
if six.PY3 and isinstance(content, bytes):
|
||||
content = content.decode('utf-8')
|
||||
|
||||
content = _from_bytes(content)
|
||||
if resp.status == 200:
|
||||
d = json.loads(content)
|
||||
self.scopes = set(util.string_to_scopes(d.get('scope', '')))
|
||||
@@ -1018,9 +1010,7 @@ class AccessTokenCredentials(OAuth2Credentials):
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode('utf-8')
|
||||
data = json.loads(s)
|
||||
data = json.loads(_from_bytes(s))
|
||||
retval = AccessTokenCredentials(
|
||||
data['access_token'],
|
||||
data['user_agent'])
|
||||
@@ -1612,7 +1602,7 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
data = json.loads(s)
|
||||
data = json.loads(_from_bytes(s))
|
||||
retval = SignedJwtAssertionCredentials(
|
||||
data['service_account_name'],
|
||||
base64.b64decode(data['private_key']),
|
||||
@@ -1675,9 +1665,8 @@ def verify_id_token(id_token, audience, http=None,
|
||||
http = _cached_http
|
||||
|
||||
resp, content = http.request(cert_uri)
|
||||
|
||||
if resp.status == 200:
|
||||
certs = json.loads(content.decode('utf-8'))
|
||||
certs = json.loads(_from_bytes(content))
|
||||
return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
|
||||
else:
|
||||
raise VerifyJwtTokenError('Status code: %d' % resp.status)
|
||||
@@ -1703,7 +1692,7 @@ def _extract_id_token(id_token):
|
||||
raise VerifyJwtTokenError(
|
||||
'Wrong number of segments in token: %s' % id_token)
|
||||
|
||||
return json.loads(_urlsafe_b64decode(segments[1]).decode('utf-8'))
|
||||
return json.loads(_from_bytes(_urlsafe_b64decode(segments[1])))
|
||||
|
||||
|
||||
def _parse_exchange_token_response(content):
|
||||
@@ -1720,12 +1709,12 @@ def _parse_exchange_token_response(content):
|
||||
i.e. {}. That basically indicates a failure.
|
||||
"""
|
||||
resp = {}
|
||||
content = _from_bytes(content)
|
||||
try:
|
||||
resp = json.loads(content.decode('utf-8'))
|
||||
resp = json.loads(content)
|
||||
except Exception:
|
||||
# different JSON libs raise different exceptions,
|
||||
# so we just do a catch-all here
|
||||
content = content.decode('utf-8')
|
||||
resp = dict(urllib.parse.parse_qsl(content))
|
||||
|
||||
# some providers respond with 'expires', others with 'expires_in'
|
||||
@@ -2000,6 +1989,7 @@ class OAuth2WebServerFlow(Flow):
|
||||
|
||||
resp, content = http.request(self.device_uri, method='POST', body=body,
|
||||
headers=headers)
|
||||
content = _from_bytes(content)
|
||||
if resp.status == 200:
|
||||
try:
|
||||
flow_info = json.loads(content)
|
||||
|
||||
@@ -19,6 +19,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from oauth2client._helpers import _from_bytes
|
||||
from oauth2client._helpers import _json_encode
|
||||
from oauth2client._helpers import _to_bytes
|
||||
from oauth2client._helpers import _urlsafe_b64decode
|
||||
@@ -124,7 +125,7 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
|
||||
# Parse token.
|
||||
json_body = _urlsafe_b64decode(segments[1])
|
||||
try:
|
||||
parsed = json.loads(json_body.decode('utf-8'))
|
||||
parsed = json.loads(_from_bytes(json_body))
|
||||
except:
|
||||
raise AppIdentityError('Can\'t parse token: %s' % json_body)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import json
|
||||
import logging
|
||||
from six.moves import urllib
|
||||
|
||||
from oauth2client._helpers import _from_bytes
|
||||
from oauth2client import util
|
||||
from oauth2client.client import AccessTokenRefreshError
|
||||
from oauth2client.client import AssertionCredentials
|
||||
@@ -63,7 +64,7 @@ class AppAssertionCredentials(AssertionCredentials):
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data):
|
||||
data = json.loads(json_data)
|
||||
data = json.loads(_from_bytes(json_data))
|
||||
return AppAssertionCredentials(data['scope'])
|
||||
|
||||
def _refresh(self, http_request):
|
||||
@@ -81,6 +82,7 @@ class AppAssertionCredentials(AssertionCredentials):
|
||||
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
|
||||
uri = META.replace('{?scope}', query)
|
||||
response, content = http_request(uri)
|
||||
content = _from_bytes(content)
|
||||
if response.status == 200:
|
||||
try:
|
||||
d = json.loads(content)
|
||||
|
||||
@@ -18,7 +18,6 @@ This credentials class is implemented on top of rsa library.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import six
|
||||
import time
|
||||
|
||||
from pyasn1.codec.ber import decoder
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from oauth2client._helpers import _from_bytes
|
||||
from oauth2client._helpers import _json_encode
|
||||
from oauth2client._helpers import _parse_pem_key
|
||||
from oauth2client._helpers import _to_bytes
|
||||
@@ -66,6 +67,22 @@ class Test__to_bytes(unittest.TestCase):
|
||||
self.assertRaises(ValueError, _to_bytes, value)
|
||||
|
||||
|
||||
class Test__from_bytes(unittest.TestCase):
|
||||
|
||||
def test_with_unicode(self):
|
||||
value = u'bytes-val'
|
||||
self.assertEqual(_from_bytes(value), value)
|
||||
|
||||
def test_with_bytes(self):
|
||||
value = b'string-val'
|
||||
decoded_value = u'string-val'
|
||||
self.assertEqual(_from_bytes(value), decoded_value)
|
||||
|
||||
def test_with_nonstring_type(self):
|
||||
value = object()
|
||||
self.assertRaises(ValueError, _from_bytes, value)
|
||||
|
||||
|
||||
class Test__urlsafe_b64encode(unittest.TestCase):
|
||||
|
||||
DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
|
||||
|
||||
@@ -29,8 +29,8 @@ import os
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
import dev_appserver
|
||||
dev_appserver.fix_sys_path()
|
||||
@@ -554,7 +554,7 @@ class DecoratorTests(unittest.TestCase):
|
||||
self.assertEqual(self.decorator.credentials, None)
|
||||
response = self.app.get('http://localhost/foo_path')
|
||||
self.assertTrue(response.status.startswith('302'))
|
||||
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
|
||||
q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
|
||||
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
|
||||
self.assertEqual('foo_client_id', q['client_id'][0])
|
||||
self.assertEqual('foo_scope bar_scope', q['scope'][0])
|
||||
@@ -575,10 +575,10 @@ class DecoratorTests(unittest.TestCase):
|
||||
self.assertEqual('http://localhost/foo_path', parts[0])
|
||||
self.assertEqual(None, self.decorator.credentials)
|
||||
if self.decorator._token_response_param:
|
||||
response_query = urlparse.parse_qs(parts[1])
|
||||
response_query = urllib.parse.parse_qs(parts[1])
|
||||
response = response_query[self.decorator._token_response_param][0]
|
||||
self.assertEqual(Http2Mock.content,
|
||||
json.loads(urllib.unquote(response)))
|
||||
json.loads(urllib.parse.unquote(response)))
|
||||
self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
|
||||
self.assertEqual(self.decorator.credentials,
|
||||
self.decorator._tls.credentials)
|
||||
@@ -609,7 +609,7 @@ class DecoratorTests(unittest.TestCase):
|
||||
# Invalid Credentials should start the OAuth dance again.
|
||||
response = self.app.get('/foo_path')
|
||||
self.assertTrue(response.status.startswith('302'))
|
||||
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
|
||||
q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
|
||||
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
|
||||
|
||||
def test_storage_delete(self):
|
||||
@@ -654,7 +654,7 @@ class DecoratorTests(unittest.TestCase):
|
||||
self.assertEqual('200 OK', response.status)
|
||||
self.assertEqual(False, self.decorator.has_credentials())
|
||||
url = self.decorator.authorize_url()
|
||||
q = urlparse.parse_qs(url.split('?', 1)[1])
|
||||
q = urllib.parse.parse_qs(url.split('?', 1)[1])
|
||||
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
|
||||
self.assertEqual('foo_client_id', q['client_id'][0])
|
||||
self.assertEqual('foo_scope bar_scope', q['scope'][0])
|
||||
|
||||
@@ -20,11 +20,14 @@ Unit tests for oauth2client.gce.
|
||||
|
||||
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
||||
|
||||
import json
|
||||
from six.moves import urllib
|
||||
import unittest
|
||||
|
||||
import httplib2
|
||||
import mock
|
||||
|
||||
from oauth2client._helpers import _to_bytes
|
||||
from oauth2client.client import AccessTokenRefreshError
|
||||
from oauth2client.client import Credentials
|
||||
from oauth2client.client import save_to_well_known_file
|
||||
@@ -33,22 +36,32 @@ from oauth2client.gce import AppAssertionCredentials
|
||||
|
||||
class AssertionCredentialsTests(unittest.TestCase):
|
||||
|
||||
def test_good_refresh(self):
|
||||
def _refresh_success_helper(self, bytes_response=False):
|
||||
access_token = u'this-is-a-token'
|
||||
return_val = json.dumps({u'accessToken': access_token})
|
||||
if bytes_response:
|
||||
return_val = _to_bytes(return_val)
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(
|
||||
return_value=(mock.Mock(status=200),
|
||||
'{"accessToken": "this-is-a-token"}'))
|
||||
return_value=(mock.Mock(status=200), return_val))
|
||||
|
||||
c = AppAssertionCredentials(scope=['http://example.com/a',
|
||||
'http://example.com/b'])
|
||||
self.assertEquals(None, c.access_token)
|
||||
c.refresh(http)
|
||||
self.assertEquals('this-is-a-token', c.access_token)
|
||||
scopes = ['http://example.com/a', 'http://example.com/b']
|
||||
credentials = AppAssertionCredentials(scope=scopes)
|
||||
self.assertEquals(None, credentials.access_token)
|
||||
credentials.refresh(http)
|
||||
self.assertEquals(access_token, credentials.access_token)
|
||||
|
||||
http.request.assert_called_once_with(
|
||||
'http://metadata.google.internal/0.1/meta-data/service-accounts/'
|
||||
'default/acquire'
|
||||
'?scope=http%3A%2F%2Fexample.com%2Fa%20http%3A%2F%2Fexample.com%2Fb')
|
||||
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
|
||||
'service-accounts/default/acquire')
|
||||
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
|
||||
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
|
||||
http.request.assert_called_once_with(request_uri)
|
||||
|
||||
def test_refresh_success(self):
|
||||
self._refresh_success_helper(bytes_response=False)
|
||||
|
||||
def test_refresh_success_bytes(self):
|
||||
self._refresh_success_helper(bytes_response=True)
|
||||
|
||||
def test_fail_refresh(self):
|
||||
http = mock.MagicMock()
|
||||
|
||||
@@ -801,8 +801,8 @@ class BasicCredentialsTests(unittest.TestCase):
|
||||
http = credentials.authorize(http)
|
||||
http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
|
||||
for k, v in six.iteritems(http.headers):
|
||||
self.assertEqual(six.binary_type, type(k))
|
||||
self.assertEqual(six.binary_type, type(v))
|
||||
self.assertTrue(isinstance(k, six.binary_type))
|
||||
self.assertTrue(isinstance(v, six.binary_type))
|
||||
|
||||
# Test again with unicode strings that can't simply be converted to ASCII.
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user