diff --git a/oauth2client/client.py b/oauth2client/client.py index c5b7214..ae0301d 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -409,7 +409,9 @@ def clean_headers(headers): clean = {} try: for k, v in six.iteritems(headers): - clean[k.encode('ascii')] = v.encode('ascii') + clean_k = k if isinstance(k, bytes) else str(k).encode('ascii') + clean_v = v if isinstance(v, bytes) else str(v).encode('ascii') + clean[clean_k] = clean_v except UnicodeEncodeError: raise NonAsciiHeaderError(k + ': ' + v) return clean diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index fea08c8..82ddb92 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -598,6 +598,37 @@ class BasicCredentialsTests(unittest.TestCase): instance = OAuth2Credentials.from_json(json.dumps(data)) self.assertTrue(isinstance(instance, OAuth2Credentials)) + def test_unicode_header_checks(self): + access_token = u'foo' + client_id = u'some_client_id' + client_secret = u'cOuDdkfjxxnv+' + refresh_token = u'1/0/a.df219fjls0' + token_expiry = str(datetime.datetime.utcnow()) + token_uri = str(GOOGLE_TOKEN_URI) + revoke_uri = str(GOOGLE_REVOKE_URI) + user_agent = u'refresh_checker/1.0' + credentials = OAuth2Credentials(access_token, client_id, client_secret, + refresh_token, token_expiry, token_uri, + user_agent, revoke_uri=revoke_uri) + + # First, test that we correctly encode basic objects, making sure + # to include a bytes object. Note that oauth2client will normalize + # everything to bytes, no matter what python version we're in. + http = credentials.authorize(HttpMock(headers={'status': '200'})) + headers = {u'foo': 3, b'bar': True, 'baz': b'abc'} + cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'} + http.request(u'http://example.com', method=u'GET', headers=headers) + for k, v in cleaned_headers.items(): + self.assertTrue(k in http.headers) + self.assertEqual(v, http.headers[k]) + + # Next, test that we do fail on unicode. + unicode_str = six.unichr(40960) + 'abcd' + self.assertRaises( + NonAsciiHeaderError, + http.request, + u'http://example.com', method=u'GET', headers={u'foo': unicode_str}) + def test_no_unicode_in_request_params(self): access_token = u'foo' client_id = u'some_client_id'