Begin transport module for httplib2 specific pieces.

Towards #554.
This commit is contained in:
Danny Hermes
2016-07-20 13:24:41 -07:00
parent 3e9b857ffa
commit 213050dd67
8 changed files with 353 additions and 145 deletions

View File

@@ -19,6 +19,7 @@ Submodules
oauth2client.file
oauth2client.service_account
oauth2client.tools
oauth2client.transport
oauth2client.util
Module contents

View File

@@ -0,0 +1,7 @@
oauth2client.transport module
=============================
.. automodule:: oauth2client.transport
:members:
:undoc-members:
:show-inheritance:

View File

@@ -28,7 +28,6 @@ import socket
import sys
import tempfile
import httplib2
import six
from six.moves import http_client
from six.moves import urllib
@@ -39,9 +38,9 @@ from oauth2client import GOOGLE_DEVICE_URI
from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_INFO_URI
from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import transport
from oauth2client import util
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
@@ -71,9 +70,6 @@ ID_TOKEN_VERIFICATON_CERTS = ID_TOKEN_VERIFICATION_CERTS
# Constant to use for the out of band OAuth 2.0 flow.
OOB_CALLBACK_URN = 'urn:ietf:wg:oauth:2.0:oob'
# Google Data client libraries may need to set this to [401, 403].
REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,)
# The value representing user credentials.
AUTHORIZED_USER = 'authorized_user'
@@ -120,6 +116,12 @@ _DESIRED_METADATA_FLAVOR = 'Google'
# easier testing (by replacing with a stub).
_UTCNOW = datetime.datetime.utcnow
# NOTE: These names were previously defined in this module but have been
# moved into `oauth2client.transport`,
clean_headers = transport.clean_headers
MemoryCache = transport.MemoryCache
REFRESH_STATUS_CODES = transport.REFRESH_STATUS_CODES
class SETTINGS(object):
"""Settings namespace for globally defined values."""
@@ -177,22 +179,6 @@ class CryptoUnavailableError(Error, NotImplementedError):
"""Raised when a crypto library is required, but none is available."""
class MemoryCache(object):
"""httplib2 Cache implementation which only caches locally."""
def __init__(self):
self.cache = {}
def get(self, key):
return self.cache.get(key)
def set(self, key, value):
self.cache[key] = value
def delete(self, key):
self.cache.pop(key, None)
def _parse_expiry(expiry):
if expiry and isinstance(expiry, datetime.datetime):
return expiry.strftime(EXPIRY_FORMAT)
@@ -451,32 +437,6 @@ class Storage(object):
self.release_lock()
def clean_headers(headers):
"""Forces header keys and values to be strings, i.e not unicode.
The httplib module just concats the header keys and values in a way that
may make the message header a unicode string, which, if it then tries to
contatenate to a binary request body may result in a unicode decode error.
Args:
headers: dict, A dictionary of headers.
Returns:
The same dictionary but with all the keys converted to strings.
"""
clean = {}
try:
for k, v in six.iteritems(headers):
if not isinstance(k, six.binary_type):
k = str(k)
if not isinstance(v, six.binary_type):
v = str(v)
clean[_to_bytes(k)] = _to_bytes(v)
except UnicodeEncodeError:
raise NonAsciiHeaderError(k, ': ', v)
return clean
def _update_query_params(uri, params):
"""Updates a URI with new query parameters.
@@ -494,26 +454,6 @@ def _update_query_params(uri, params):
return urllib.parse.urlunparse(new_parts)
def _initialize_headers(headers):
"""Creates a copy of the headers."""
if headers is None:
headers = {}
else:
headers = dict(headers)
return headers
def _apply_user_agent(headers, user_agent):
"""Adds a user-agent to the headers."""
if user_agent is not None:
if 'user-agent' in headers:
headers['user-agent'] = (user_agent + ' ' + headers['user-agent'])
else:
headers['user-agent'] = user_agent
return headers
class OAuth2Credentials(Credentials):
"""Credentials object for OAuth 2.0.
@@ -604,58 +544,7 @@ class OAuth2Credentials(Credentials):
that adds in the Authorization header and then calls the original
version of 'request()'.
"""
request_orig = http.request
# The closure that will replace 'httplib2.Http.request'.
def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
if not self.access_token:
logger.info('Attempting refresh to obtain '
'initial access_token')
self._refresh(request_orig)
# Clone and modify the request headers to add the appropriate
# Authorization header.
headers = _initialize_headers(headers)
self.apply(headers)
_apply_user_agent(headers, self.user_agent)
body_stream_position = None
if all(getattr(body, stream_prop, None) for stream_prop in
('read', 'seek', 'tell')):
body_stream_position = body.tell()
resp, content = request_orig(uri, method, body,
clean_headers(headers),
redirections, connection_type)
# A stored token may expire between the time it is retrieved and
# the time the request is made, so we may need to try twice.
max_refresh_attempts = 2
for refresh_attempt in range(max_refresh_attempts):
if resp.status not in REFRESH_STATUS_CODES:
break
logger.info('Refreshing due to a %s (attempt %s/%s)',
resp.status, refresh_attempt + 1,
max_refresh_attempts)
self._refresh(request_orig)
self.apply(headers)
if body_stream_position is not None:
body.seek(body_stream_position)
resp, content = request_orig(uri, method, body,
clean_headers(headers),
redirections, connection_type)
return (resp, content)
# Replace the request method with our own closure.
http.request = new_request
# Set credentials as a property of the request method.
setattr(http.request, 'credentials', self)
transport.wrap_http_for_auth(self, http)
return http
def refresh(self, http):
@@ -781,7 +670,7 @@ class OAuth2Credentials(Credentials):
"""
if not self.access_token or self.access_token_expired:
if not http:
http = httplib2.Http()
http = transport.get_http_object()
self.refresh(http)
return AccessTokenInfo(access_token=self.access_token,
expires_in=self._expires_in())
@@ -1654,11 +1543,6 @@ def _require_crypto_or_die():
raise CryptoUnavailableError('No crypto library available')
# Only used in verify_id_token(), which is always calling to the same URI
# for the certs.
_cached_http = httplib2.Http(MemoryCache())
@util.positional(2)
def verify_id_token(id_token, audience, http=None,
cert_uri=ID_TOKEN_VERIFICATION_CERTS):
@@ -1684,7 +1568,7 @@ def verify_id_token(id_token, audience, http=None,
"""
_require_crypto_or_die()
if http is None:
http = _cached_http
http = transport.get_cached_http()
resp, content = http.request(cert_uri)
if resp.status == http_client.OK:
@@ -2027,7 +1911,7 @@ class OAuth2WebServerFlow(Flow):
headers['user-agent'] = self.user_agent
if http is None:
http = httplib2.Http()
http = transport.get_http_object()
resp, content = http.request(self.device_uri, method='POST', body=body,
headers=headers)
@@ -2110,7 +1994,7 @@ class OAuth2WebServerFlow(Flow):
headers['user-agent'] = self.user_agent
if http is None:
http = httplib2.Http()
http = transport.get_http_object()
resp, content = http.request(self.token_uri, method='POST', body=body,
headers=headers)

View File

@@ -27,14 +27,14 @@ from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import util
from oauth2client._helpers import _from_bytes
from oauth2client.client import _apply_user_agent
from oauth2client.client import _initialize_headers
from oauth2client.client import _UTCNOW
from oauth2client.client import AccessTokenInfo
from oauth2client.client import AssertionCredentials
from oauth2client.client import clean_headers
from oauth2client.client import EXPIRY_FORMAT
from oauth2client.client import SERVICE_ACCOUNT
from oauth2client.transport import _apply_user_agent
from oauth2client.transport import _initialize_headers
from oauth2client.transport import clean_headers
_PASSWORD_DEFAULT = 'notasecret'

198
oauth2client/transport.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import httplib2
import six
from six.moves import http_client
from oauth2client._helpers import _to_bytes
_LOGGER = logging.getLogger(__name__)
# Properties present in file-like streams / buffers.
_STREAM_PROPERTIES = ('read', 'seek', 'tell')
# Google Data client libraries may need to set this to [401, 403].
REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,)
class MemoryCache(object):
"""httplib2 Cache implementation which only caches locally."""
def __init__(self):
self.cache = {}
def get(self, key):
return self.cache.get(key)
def set(self, key, value):
self.cache[key] = value
def delete(self, key):
self.cache.pop(key, None)
def get_cached_http():
"""Return an HTTP object which caches results returned.
This is intended to be used in methods like
oauth2client.client.verify_id_token(), which calls to the same URI
to retrieve certs.
Returns:
httplib2.Http, an HTTP object with a MemoryCache
"""
return _CACHED_HTTP
def get_http_object():
"""Return a new HTTP object.
Returns:
httplib2.Http, an HTTP object.
"""
return httplib2.Http()
def _initialize_headers(headers):
"""Creates a copy of the headers.
Args:
headers: dict, request headers to copy.
Returns:
dict, the copied headers or a new dictionary if the headers
were None.
"""
return {} if headers is None else dict(headers)
def _apply_user_agent(headers, user_agent):
"""Adds a user-agent to the headers.
Args:
headers: dict, request headers to add / modify user
agent within.
user_agent: str, the user agent to add.
Returns:
dict, the original headers passed in, but modified if the
user agent is not None.
"""
if user_agent is not None:
if 'user-agent' in headers:
headers['user-agent'] = (user_agent + ' ' + headers['user-agent'])
else:
headers['user-agent'] = user_agent
return headers
def clean_headers(headers):
"""Forces header keys and values to be strings, i.e not unicode.
The httplib module just concats the header keys and values in a way that
may make the message header a unicode string, which, if it then tries to
contatenate to a binary request body may result in a unicode decode error.
Args:
headers: dict, A dictionary of headers.
Returns:
The same dictionary but with all the keys converted to strings.
"""
clean = {}
try:
for k, v in six.iteritems(headers):
if not isinstance(k, six.binary_type):
k = str(k)
if not isinstance(v, six.binary_type):
v = str(v)
clean[_to_bytes(k)] = _to_bytes(v)
except UnicodeEncodeError:
from oauth2client.client import NonAsciiHeaderError
raise NonAsciiHeaderError(k, ': ', v)
return clean
def wrap_http_for_auth(credentials, http):
"""Prepares an HTTP object's request method for auth.
Wraps HTTP requests with logic to catch auth failures (typically
identified via a 401 status code). In the event of failure, tries
to refresh the token used and then retry the original request.
Args:
credentials: Credentials, the credentials used to identify
the authenticated user.
http: httplib2.Http, an http object to be used to make
auth requests.
"""
orig_request_method = http.request
# The closure that will replace 'httplib2.Http.request'.
def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
if not credentials.access_token:
_LOGGER.info('Attempting refresh to obtain '
'initial access_token')
credentials._refresh(orig_request_method)
# Clone and modify the request headers to add the appropriate
# Authorization header.
headers = _initialize_headers(headers)
credentials.apply(headers)
_apply_user_agent(headers, credentials.user_agent)
body_stream_position = None
# Check if the body is a file-like stream.
if all(getattr(body, stream_prop, None) for stream_prop in
_STREAM_PROPERTIES):
body_stream_position = body.tell()
resp, content = orig_request_method(uri, method, body,
clean_headers(headers),
redirections, connection_type)
# A stored token may expire between the time it is retrieved and
# the time the request is made, so we may need to try twice.
max_refresh_attempts = 2
for refresh_attempt in range(max_refresh_attempts):
if resp.status not in REFRESH_STATUS_CODES:
break
_LOGGER.info('Refreshing due to a %s (attempt %s/%s)',
resp.status, refresh_attempt + 1,
max_refresh_attempts)
credentials._refresh(orig_request_method)
credentials.apply(headers)
if body_stream_position is not None:
body.seek(body_stream_position)
resp, content = orig_request_method(uri, method, body,
clean_headers(headers),
redirections, connection_type)
return resp, content
# Replace the request method with our own closure.
http.request = new_request
# Set credentials as a property of the request method.
setattr(http.request, 'credentials', credentials)
_CACHED_HTTP = httplib2.Http(MemoryCache())

View File

@@ -67,7 +67,6 @@ from oauth2client.client import FlowExchangeError
from oauth2client.client import GOOGLE_APPLICATION_CREDENTIALS
from oauth2client.client import GoogleCredentials
from oauth2client.client import HttpAccessTokenRefreshError
from oauth2client.client import MemoryCache
from oauth2client.client import NonAsciiHeaderError
from oauth2client.client import OAuth2Credentials
from oauth2client.client import OAuth2WebServerFlow
@@ -2242,18 +2241,6 @@ class CredentialsFromCodeTests(unittest2.TestCase):
self.code, http=http)
class MemoryCacheTests(unittest2.TestCase):
def test_get_set_delete(self):
m = MemoryCache()
self.assertEqual(None, m.get('foo'))
self.assertEqual(None, m.delete('foo'))
m.set('foo', 'bar')
self.assertEqual('bar', m.get('foo'))
m.delete('foo')
self.assertEqual(None, m.get('foo'))
class Test__save_private_file(unittest2.TestCase):
def _save_helper(self, filename):

View File

@@ -138,7 +138,7 @@ class CryptTests(unittest2.TestCase):
({'status': '200'}, datafile('certs.json')),
])
with mock.patch('oauth2client.client._cached_http', new=http):
with mock.patch('oauth2client.transport._CACHED_HTTP', new=http):
contents = verify_id_token(
jwt, 'some_audience_address@testing.gserviceaccount.com')

131
tests/test_transport.py Normal file
View File

@@ -0,0 +1,131 @@
# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import httplib2
import mock
import unittest2
from oauth2client import client
from oauth2client import transport
class TestMemoryCache(unittest2.TestCase):
def test_get_set_delete(self):
cache = transport.MemoryCache()
self.assertIsNone(cache.get('foo'))
self.assertIsNone(cache.delete('foo'))
cache.set('foo', 'bar')
self.assertEqual('bar', cache.get('foo'))
cache.delete('foo')
self.assertIsNone(cache.get('foo'))
class Test_get_cached_http(unittest2.TestCase):
def test_global(self):
cached_http = transport.get_cached_http()
self.assertIsInstance(cached_http, httplib2.Http)
self.assertIsInstance(cached_http.cache, transport.MemoryCache)
def test_value(self):
cache = object()
with mock.patch('oauth2client.transport._CACHED_HTTP', new=cache):
result = transport.get_cached_http()
self.assertIs(result, cache)
class Test_get_http_object(unittest2.TestCase):
@mock.patch.object(httplib2, 'Http', return_value=object())
def test_it(self, http_klass):
result = transport.get_http_object()
self.assertEqual(result, http_klass.return_value)
class Test__initialize_headers(unittest2.TestCase):
def test_null(self):
result = transport._initialize_headers(None)
self.assertEqual(result, {})
def test_copy(self):
headers = {'a': 1, 'b': 2}
result = transport._initialize_headers(headers)
self.assertEqual(result, headers)
self.assertIsNot(result, headers)
class Test__apply_user_agent(unittest2.TestCase):
def test_null(self):
headers = object()
result = transport._apply_user_agent(headers, None)
self.assertIs(result, headers)
def test_new_agent(self):
headers = {}
user_agent = 'foo'
result = transport._apply_user_agent(headers, user_agent)
self.assertIs(result, headers)
self.assertEqual(result, {'user-agent': user_agent})
def test_append(self):
orig_agent = 'bar'
headers = {'user-agent': orig_agent}
user_agent = 'baz'
result = transport._apply_user_agent(headers, user_agent)
self.assertIs(result, headers)
final_agent = user_agent + ' ' + orig_agent
self.assertEqual(result, {'user-agent': final_agent})
class Test_clean_headers(unittest2.TestCase):
def test_no_modify(self):
headers = {b'key': b'val'}
result = transport.clean_headers(headers)
self.assertIsNot(result, headers)
self.assertEqual(result, headers)
def test_cast_unicode(self):
headers = {u'key': u'val'}
header_bytes = {b'key': b'val'}
result = transport.clean_headers(headers)
self.assertIsNot(result, headers)
self.assertEqual(result, header_bytes)
def test_unicode_failure(self):
headers = {u'key': u'\u2603'}
with self.assertRaises(client.NonAsciiHeaderError):
transport.clean_headers(headers)
def test_cast_object(self):
headers = {b'key': True}
header_str = {b'key': b'True'}
result = transport.clean_headers(headers)
self.assertIsNot(result, headers)
self.assertEqual(result, header_str)
class Test_wrap_http_for_auth(unittest2.TestCase):
def test_wrap(self):
credentials = object()
http = mock.Mock()
http.request = orig_req_method = object()
result = transport.wrap_http_for_auth(credentials, http)
self.assertIsNone(result)
self.assertNotEqual(http.request, orig_req_method)
self.assertIs(http.request.credentials, credentials)