Fix bugs with auth, batch and retries.
Reviewed in http://codereview.appspot.com/5633052/.
This commit is contained in:
@@ -37,6 +37,7 @@ import urllib
|
|||||||
import urlparse
|
import urlparse
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from email.generator import Generator
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from email.mime.nonmultipart import MIMENonMultipart
|
from email.mime.nonmultipart import MIMENonMultipart
|
||||||
from email.parser import FeedParser
|
from email.parser import FeedParser
|
||||||
@@ -498,9 +499,12 @@ class BatchHttpRequest(object):
|
|||||||
# Global callback to be called for each individual response in the batch.
|
# Global callback to be called for each individual response in the batch.
|
||||||
self._callback = callback
|
self._callback = callback
|
||||||
|
|
||||||
# A map from id to (request, callback) pairs.
|
# A map from id to request.
|
||||||
self._requests = {}
|
self._requests = {}
|
||||||
|
|
||||||
|
# A map from id to callback.
|
||||||
|
self._callbacks = {}
|
||||||
|
|
||||||
# List of request ids, in the order in which they were added.
|
# List of request ids, in the order in which they were added.
|
||||||
self._order = []
|
self._order = []
|
||||||
|
|
||||||
@@ -510,6 +514,39 @@ class BatchHttpRequest(object):
|
|||||||
# Unique ID on which to base the Content-ID headers.
|
# Unique ID on which to base the Content-ID headers.
|
||||||
self._base_id = None
|
self._base_id = None
|
||||||
|
|
||||||
|
# A map from request id to (headers, content) response pairs
|
||||||
|
self._responses = {}
|
||||||
|
|
||||||
|
# A map of id(Credentials) that have been refreshed.
|
||||||
|
self._refreshed_credentials = {}
|
||||||
|
|
||||||
|
def _refresh_and_apply_credentials(self, request, http):
|
||||||
|
"""Refresh the credentials and apply to the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: HttpRequest, the request.
|
||||||
|
http: httplib2.Http, the global http object for the batch.
|
||||||
|
"""
|
||||||
|
# For the credentials to refresh, but only once per refresh_token
|
||||||
|
# If there is no http per the request then refresh the http passed in
|
||||||
|
# via execute()
|
||||||
|
creds = None
|
||||||
|
if request.http is not None and hasattr(request.http.request,
|
||||||
|
'credentials'):
|
||||||
|
creds = request.http.request.credentials
|
||||||
|
elif http is not None and hasattr(http.request, 'credentials'):
|
||||||
|
creds = http.request.credentials
|
||||||
|
if creds is not None:
|
||||||
|
if id(creds) not in self._refreshed_credentials:
|
||||||
|
creds.refresh(http)
|
||||||
|
self._refreshed_credentials[id(creds)] = 1
|
||||||
|
|
||||||
|
# Only apply the credentials if we are using the http object passed in,
|
||||||
|
# otherwise apply() will get called during _serialize_request().
|
||||||
|
if request.http is None or not hasattr(request.http.request,
|
||||||
|
'credentials'):
|
||||||
|
creds.apply(request.headers)
|
||||||
|
|
||||||
def _id_to_header(self, id_):
|
def _id_to_header(self, id_):
|
||||||
"""Convert an id to a Content-ID header value.
|
"""Convert an id to a Content-ID header value.
|
||||||
|
|
||||||
@@ -568,6 +605,10 @@ class BatchHttpRequest(object):
|
|||||||
msg = MIMENonMultipart(major, minor)
|
msg = MIMENonMultipart(major, minor)
|
||||||
headers = request.headers.copy()
|
headers = request.headers.copy()
|
||||||
|
|
||||||
|
if request.http is not None and hasattr(request.http.request,
|
||||||
|
'credentials'):
|
||||||
|
request.http.request.credentials.apply(headers)
|
||||||
|
|
||||||
# MIMENonMultipart adds its own Content-Type header.
|
# MIMENonMultipart adds its own Content-Type header.
|
||||||
if 'content-type' in headers:
|
if 'content-type' in headers:
|
||||||
del headers['content-type']
|
del headers['content-type']
|
||||||
@@ -581,7 +622,13 @@ class BatchHttpRequest(object):
|
|||||||
msg.set_payload(request.body)
|
msg.set_payload(request.body)
|
||||||
msg['content-length'] = str(len(request.body))
|
msg['content-length'] = str(len(request.body))
|
||||||
|
|
||||||
body = msg.as_string(False)
|
# Serialize the mime message.
|
||||||
|
fp = StringIO.StringIO()
|
||||||
|
# maxheaderlen=0 means don't line wrap headers.
|
||||||
|
g = Generator(fp, maxheaderlen=0)
|
||||||
|
g.flatten(msg, unixfrom=False)
|
||||||
|
body = fp.getvalue()
|
||||||
|
|
||||||
# Strip off the \n\n that the MIME lib tacks onto the end of the payload.
|
# Strip off the \n\n that the MIME lib tacks onto the end of the payload.
|
||||||
if request.body is None:
|
if request.body is None:
|
||||||
body = body[:-2]
|
body = body[:-2]
|
||||||
@@ -661,9 +708,71 @@ class BatchHttpRequest(object):
|
|||||||
raise BatchError("Resumable requests cannot be used in a batch request.")
|
raise BatchError("Resumable requests cannot be used in a batch request.")
|
||||||
if request_id in self._requests:
|
if request_id in self._requests:
|
||||||
raise KeyError("A request with this ID already exists: %s" % request_id)
|
raise KeyError("A request with this ID already exists: %s" % request_id)
|
||||||
self._requests[request_id] = (request, callback)
|
self._requests[request_id] = request
|
||||||
|
self._callbacks[request_id] = callback
|
||||||
self._order.append(request_id)
|
self._order.append(request_id)
|
||||||
|
|
||||||
|
def _execute(self, http, order, requests):
|
||||||
|
"""Serialize batch request, send to server, process response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
http: httplib2.Http, an http object to be used to make the request with.
|
||||||
|
order: list, list of request ids in the order they were added to the
|
||||||
|
batch.
|
||||||
|
request: list, list of request objects to send.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
httplib2.Error if a transport error has occured.
|
||||||
|
apiclient.errors.BatchError if the response is the wrong format.
|
||||||
|
"""
|
||||||
|
message = MIMEMultipart('mixed')
|
||||||
|
# Message should not write out it's own headers.
|
||||||
|
setattr(message, '_write_headers', lambda self: None)
|
||||||
|
|
||||||
|
# Add all the individual requests.
|
||||||
|
for request_id in order:
|
||||||
|
request = requests[request_id]
|
||||||
|
|
||||||
|
msg = MIMENonMultipart('application', 'http')
|
||||||
|
msg['Content-Transfer-Encoding'] = 'binary'
|
||||||
|
msg['Content-ID'] = self._id_to_header(request_id)
|
||||||
|
|
||||||
|
body = self._serialize_request(request)
|
||||||
|
msg.set_payload(body)
|
||||||
|
message.attach(msg)
|
||||||
|
|
||||||
|
body = message.as_string()
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
headers['content-type'] = ('multipart/mixed; '
|
||||||
|
'boundary="%s"') % message.get_boundary()
|
||||||
|
|
||||||
|
resp, content = http.request(self._batch_uri, 'POST', body=body,
|
||||||
|
headers=headers)
|
||||||
|
|
||||||
|
if resp.status >= 300:
|
||||||
|
raise HttpError(resp, content, self._batch_uri)
|
||||||
|
|
||||||
|
# Now break out the individual responses and store each one.
|
||||||
|
boundary, _ = content.split(None, 1)
|
||||||
|
|
||||||
|
# Prepend with a content-type header so FeedParser can handle it.
|
||||||
|
header = 'content-type: %s\r\n\r\n' % resp['content-type']
|
||||||
|
for_parser = header + content
|
||||||
|
|
||||||
|
parser = FeedParser()
|
||||||
|
parser.feed(for_parser)
|
||||||
|
mime_response = parser.close()
|
||||||
|
|
||||||
|
if not mime_response.is_multipart():
|
||||||
|
raise BatchError("Response not in multipart/mixed format.", resp,
|
||||||
|
content)
|
||||||
|
|
||||||
|
for part in mime_response.get_payload():
|
||||||
|
request_id = self._header_to_id(part['Content-ID'])
|
||||||
|
headers, content = self._deserialize_response(part.get_payload())
|
||||||
|
self._responses[request_id] = (headers, content)
|
||||||
|
|
||||||
def execute(self, http=None):
|
def execute(self, http=None):
|
||||||
"""Execute all the requests as a single batched HTTP request.
|
"""Execute all the requests as a single batched HTTP request.
|
||||||
|
|
||||||
@@ -676,84 +785,61 @@ class BatchHttpRequest(object):
|
|||||||
None
|
None
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
apiclient.errors.HttpError if the response was not a 2xx.
|
|
||||||
httplib2.Error if a transport error has occured.
|
httplib2.Error if a transport error has occured.
|
||||||
apiclient.errors.BatchError if the response is the wrong format.
|
apiclient.errors.BatchError if the response is the wrong format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# If http is not supplied use the first valid one given in the requests.
|
||||||
if http is None:
|
if http is None:
|
||||||
for request_id in self._order:
|
for request_id in self._order:
|
||||||
request, callback = self._requests[request_id]
|
request = self._requests[request_id]
|
||||||
if request is not None:
|
if request is not None:
|
||||||
http = request.http
|
http = request.http
|
||||||
break
|
break
|
||||||
|
|
||||||
if http is None:
|
if http is None:
|
||||||
raise ValueError("Missing a valid http object.")
|
raise ValueError("Missing a valid http object.")
|
||||||
|
|
||||||
|
self._execute(http, self._order, self._requests)
|
||||||
|
|
||||||
msgRoot = MIMEMultipart('mixed')
|
# Loop over all the requests and check for 401s. For each 401 request the
|
||||||
# msgRoot should not write out it's own headers
|
# credentials should be refreshed and then sent again in a separate batch.
|
||||||
setattr(msgRoot, '_write_headers', lambda self: None)
|
redo_requests = {}
|
||||||
|
redo_order = []
|
||||||
|
|
||||||
# Add all the individual requests.
|
|
||||||
for request_id in self._order:
|
for request_id in self._order:
|
||||||
request, callback = self._requests[request_id]
|
headers, content = self._responses[request_id]
|
||||||
|
if headers['status'] == '401':
|
||||||
|
redo_order.append(request_id)
|
||||||
|
request = self._requests[request_id]
|
||||||
|
self._refresh_and_apply_credentials(request, http)
|
||||||
|
redo_requests[request_id] = request
|
||||||
|
|
||||||
msg = MIMENonMultipart('application', 'http')
|
if redo_requests:
|
||||||
msg['Content-Transfer-Encoding'] = 'binary'
|
self._execute(http, redo_order, redo_requests)
|
||||||
msg['Content-ID'] = self._id_to_header(request_id)
|
|
||||||
|
|
||||||
body = self._serialize_request(request)
|
# Now process all callbacks that are erroring, and raise an exception for
|
||||||
msg.set_payload(body)
|
# ones that return a non-2xx response? Or add extra parameter to callback
|
||||||
msgRoot.attach(msg)
|
# that contains an HttpError?
|
||||||
|
|
||||||
body = msgRoot.as_string()
|
for request_id in self._order:
|
||||||
|
headers, content = self._responses[request_id]
|
||||||
|
|
||||||
headers = {}
|
request = self._requests[request_id]
|
||||||
headers['content-type'] = ('multipart/mixed; '
|
callback = self._callbacks[request_id]
|
||||||
'boundary="%s"') % msgRoot.get_boundary()
|
|
||||||
|
|
||||||
resp, content = http.request(self._batch_uri, 'POST', body=body,
|
response = None
|
||||||
headers=headers)
|
exception = None
|
||||||
|
try:
|
||||||
|
r = httplib2.Response(headers)
|
||||||
|
response = request.postproc(r, content)
|
||||||
|
except HttpError, e:
|
||||||
|
exception = e
|
||||||
|
|
||||||
if resp.status >= 300:
|
if callback is not None:
|
||||||
raise HttpError(resp, content, self._batch_uri)
|
callback(request_id, response, exception)
|
||||||
|
|
||||||
# Now break up the response and process each one with the correct postproc
|
|
||||||
# and trigger the right callbacks.
|
|
||||||
boundary, _ = content.split(None, 1)
|
|
||||||
|
|
||||||
# Prepend with a content-type header so FeedParser can handle it.
|
|
||||||
header = 'content-type: %s\r\n\r\n' % resp['content-type']
|
|
||||||
for_parser = header + content
|
|
||||||
|
|
||||||
parser = FeedParser()
|
|
||||||
parser.feed(for_parser)
|
|
||||||
respRoot = parser.close()
|
|
||||||
|
|
||||||
if not respRoot.is_multipart():
|
|
||||||
raise BatchError("Response not in multipart/mixed format.", resp,
|
|
||||||
content)
|
|
||||||
|
|
||||||
parts = respRoot.get_payload()
|
|
||||||
for part in parts:
|
|
||||||
request_id = self._header_to_id(part['Content-ID'])
|
|
||||||
|
|
||||||
headers, content = self._deserialize_response(part.get_payload())
|
|
||||||
|
|
||||||
# TODO(jcgregorio) Remove this temporary hack once the server stops
|
|
||||||
# gzipping individual response bodies.
|
|
||||||
if content[0] != '{':
|
|
||||||
gzipped_content = content
|
|
||||||
content = gzip.GzipFile(
|
|
||||||
fileobj=StringIO.StringIO(gzipped_content)).read()
|
|
||||||
|
|
||||||
request, cb = self._requests[request_id]
|
|
||||||
postproc = request.postproc
|
|
||||||
response = postproc(resp, content)
|
|
||||||
if cb is not None:
|
|
||||||
cb(request_id, response)
|
|
||||||
if self._callback is not None:
|
if self._callback is not None:
|
||||||
self._callback(request_id, response)
|
self._callback(request_id, response, exception)
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestMock(object):
|
class HttpRequestMock(object):
|
||||||
|
|||||||
@@ -129,6 +129,23 @@ class Credentials(object):
|
|||||||
"""
|
"""
|
||||||
_abstract()
|
_abstract()
|
||||||
|
|
||||||
|
def refresh(self, http):
|
||||||
|
"""Forces a refresh of the access_token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
http: httplib2.Http, an http object to be used to make the refresh
|
||||||
|
request.
|
||||||
|
"""
|
||||||
|
_abstract()
|
||||||
|
|
||||||
|
def apply(self, headers):
|
||||||
|
"""Add the authorization to the headers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers: dict, the headers to add the Authorization header to.
|
||||||
|
"""
|
||||||
|
_abstract()
|
||||||
|
|
||||||
def _to_json(self, strip):
|
def _to_json(self, strip):
|
||||||
"""Utility function for creating a JSON representation of an instance of Credentials.
|
"""Utility function for creating a JSON representation of an instance of Credentials.
|
||||||
|
|
||||||
@@ -324,6 +341,92 @@ class OAuth2Credentials(Credentials):
|
|||||||
# refreshed.
|
# refreshed.
|
||||||
self.invalid = False
|
self.invalid = False
|
||||||
|
|
||||||
|
def authorize(self, http):
|
||||||
|
"""Authorize an httplib2.Http instance with these credentials.
|
||||||
|
|
||||||
|
The modified http.request method will add authentication headers to each
|
||||||
|
request and will refresh access_tokens when a 401 is received on a
|
||||||
|
request. In addition the http.request method has a credentials property,
|
||||||
|
http.request.credentials, which is the Credentials object that authorized
|
||||||
|
it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
http: An instance of httplib2.Http
|
||||||
|
or something that acts like it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A modified instance of http that was passed in.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
h = httplib2.Http()
|
||||||
|
h = credentials.authorize(h)
|
||||||
|
|
||||||
|
You can't create a new OAuth subclass of httplib2.Authenication
|
||||||
|
because it never gets passed the absolute URI, which is needed for
|
||||||
|
signing. So instead we have to overload 'request' with a closure
|
||||||
|
that adds in the Authorization header and then calls the original
|
||||||
|
version of 'request()'.
|
||||||
|
"""
|
||||||
|
request_orig = http.request
|
||||||
|
|
||||||
|
# The closure that will replace 'httplib2.Http.request'.
|
||||||
|
def new_request(uri, method='GET', body=None, headers=None,
|
||||||
|
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
|
||||||
|
connection_type=None):
|
||||||
|
if not self.access_token:
|
||||||
|
logger.info('Attempting refresh to obtain initial access_token')
|
||||||
|
self._refresh(request_orig)
|
||||||
|
|
||||||
|
# Modify the request headers to add the appropriate
|
||||||
|
# Authorization header.
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
self.apply(headers)
|
||||||
|
|
||||||
|
if self.user_agent is not None:
|
||||||
|
if 'user-agent' in headers:
|
||||||
|
headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
|
||||||
|
else:
|
||||||
|
headers['user-agent'] = self.user_agent
|
||||||
|
|
||||||
|
resp, content = request_orig(uri, method, body, headers,
|
||||||
|
redirections, connection_type)
|
||||||
|
|
||||||
|
if resp.status == 401:
|
||||||
|
logger.info('Refreshing due to a 401')
|
||||||
|
self._refresh(request_orig)
|
||||||
|
self.apply(headers)
|
||||||
|
return request_orig(uri, method, body, headers,
|
||||||
|
redirections, connection_type)
|
||||||
|
else:
|
||||||
|
return (resp, content)
|
||||||
|
|
||||||
|
# Replace the request method with our own closure.
|
||||||
|
http.request = new_request
|
||||||
|
|
||||||
|
# Set credentials as a property of the request method.
|
||||||
|
setattr(http.request, 'credentials', self)
|
||||||
|
|
||||||
|
return http
|
||||||
|
|
||||||
|
def refresh(self, http):
|
||||||
|
"""Forces a refresh of the access_token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
http: httplib2.Http, an http object to be used to make the refresh
|
||||||
|
request.
|
||||||
|
"""
|
||||||
|
self._refresh(http.request)
|
||||||
|
|
||||||
|
def apply(self, headers):
|
||||||
|
"""Add the authorization to the headers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers: dict, the headers to add the Authorization header to.
|
||||||
|
"""
|
||||||
|
headers['Authorization'] = 'Bearer ' + self.access_token
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
|
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
|
||||||
|
|
||||||
@@ -431,6 +534,13 @@ class OAuth2Credentials(Credentials):
|
|||||||
This method first checks by reading the Storage object if available.
|
This method first checks by reading the Storage object if available.
|
||||||
If a refresh is still needed, it holds the Storage lock until the
|
If a refresh is still needed, it holds the Storage lock until the
|
||||||
refresh is completed.
|
refresh is completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
http_request: callable, a callable that matches the method signature of
|
||||||
|
httplib2.Http.request, used to make the refresh request.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AccessTokenRefreshError: When the refresh fails.
|
||||||
"""
|
"""
|
||||||
if not self.store:
|
if not self.store:
|
||||||
self._do_refresh_request(http_request)
|
self._do_refresh_request(http_request)
|
||||||
@@ -451,8 +561,8 @@ class OAuth2Credentials(Credentials):
|
|||||||
"""Refresh the access_token using the refresh_token.
|
"""Refresh the access_token using the refresh_token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
http: An instance of httplib2.Http.request
|
http_request: callable, a callable that matches the method signature of
|
||||||
or something that acts like it.
|
httplib2.Http.request, used to make the refresh request.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AccessTokenRefreshError: When the refresh fails.
|
AccessTokenRefreshError: When the refresh fails.
|
||||||
@@ -491,64 +601,6 @@ class OAuth2Credentials(Credentials):
|
|||||||
pass
|
pass
|
||||||
raise AccessTokenRefreshError(error_msg)
|
raise AccessTokenRefreshError(error_msg)
|
||||||
|
|
||||||
def authorize(self, http):
|
|
||||||
"""Authorize an httplib2.Http instance with these credentials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
http: An instance of httplib2.Http
|
|
||||||
or something that acts like it.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A modified instance of http that was passed in.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
h = httplib2.Http()
|
|
||||||
h = credentials.authorize(h)
|
|
||||||
|
|
||||||
You can't create a new OAuth subclass of httplib2.Authenication
|
|
||||||
because it never gets passed the absolute URI, which is needed for
|
|
||||||
signing. So instead we have to overload 'request' with a closure
|
|
||||||
that adds in the Authorization header and then calls the original
|
|
||||||
version of 'request()'.
|
|
||||||
"""
|
|
||||||
request_orig = http.request
|
|
||||||
|
|
||||||
# The closure that will replace 'httplib2.Http.request'.
|
|
||||||
def new_request(uri, method='GET', body=None, headers=None,
|
|
||||||
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
|
|
||||||
connection_type=None):
|
|
||||||
if not self.access_token:
|
|
||||||
logger.info('Attempting refresh to obtain initial access_token')
|
|
||||||
self._refresh(request_orig)
|
|
||||||
|
|
||||||
# Modify the request headers to add the appropriate
|
|
||||||
# Authorization header.
|
|
||||||
if headers is None:
|
|
||||||
headers = {}
|
|
||||||
headers['authorization'] = 'OAuth ' + self.access_token
|
|
||||||
|
|
||||||
if self.user_agent is not None:
|
|
||||||
if 'user-agent' in headers:
|
|
||||||
headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
|
|
||||||
else:
|
|
||||||
headers['user-agent'] = self.user_agent
|
|
||||||
|
|
||||||
resp, content = request_orig(uri, method, body, headers,
|
|
||||||
redirections, connection_type)
|
|
||||||
|
|
||||||
if resp.status == 401:
|
|
||||||
logger.info('Refreshing due to a 401')
|
|
||||||
self._refresh(request_orig)
|
|
||||||
headers['authorization'] = 'OAuth ' + self.access_token
|
|
||||||
return request_orig(uri, method, body, headers,
|
|
||||||
redirections, connection_type)
|
|
||||||
else:
|
|
||||||
return (resp, content)
|
|
||||||
|
|
||||||
http.request = new_request
|
|
||||||
return http
|
|
||||||
|
|
||||||
|
|
||||||
class AccessTokenCredentials(OAuth2Credentials):
|
class AccessTokenCredentials(OAuth2Credentials):
|
||||||
"""Credentials object for OAuth 2.0.
|
"""Credentials object for OAuth 2.0.
|
||||||
|
|||||||
@@ -35,6 +35,52 @@ from apiclient.http import MediaUpload
|
|||||||
from apiclient.http import MediaInMemoryUpload
|
from apiclient.http import MediaInMemoryUpload
|
||||||
from apiclient.http import set_user_agent
|
from apiclient.http import set_user_agent
|
||||||
from apiclient.model import JsonModel
|
from apiclient.model import JsonModel
|
||||||
|
from oauth2client.client import Credentials
|
||||||
|
|
||||||
|
|
||||||
|
class MockCredentials(Credentials):
|
||||||
|
"""Mock class for all Credentials objects."""
|
||||||
|
def __init__(self, bearer_token):
|
||||||
|
super(MockCredentials, self).__init__()
|
||||||
|
self._authorized = 0
|
||||||
|
self._refreshed = 0
|
||||||
|
self._applied = 0
|
||||||
|
self._bearer_token = bearer_token
|
||||||
|
|
||||||
|
def authorize(self, http):
|
||||||
|
self._authorized += 1
|
||||||
|
|
||||||
|
request_orig = http.request
|
||||||
|
|
||||||
|
# The closure that will replace 'httplib2.Http.request'.
|
||||||
|
def new_request(uri, method='GET', body=None, headers=None,
|
||||||
|
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
|
||||||
|
connection_type=None):
|
||||||
|
# Modify the request headers to add the appropriate
|
||||||
|
# Authorization header.
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
self.apply(headers)
|
||||||
|
|
||||||
|
resp, content = request_orig(uri, method, body, headers,
|
||||||
|
redirections, connection_type)
|
||||||
|
|
||||||
|
return resp, content
|
||||||
|
|
||||||
|
# Replace the request method with our own closure.
|
||||||
|
http.request = new_request
|
||||||
|
|
||||||
|
# Set credentials as a property of the request method.
|
||||||
|
setattr(http.request, 'credentials', self)
|
||||||
|
|
||||||
|
return http
|
||||||
|
|
||||||
|
def refresh(self, http):
|
||||||
|
self._refreshed += 1
|
||||||
|
|
||||||
|
def apply(self, headers):
|
||||||
|
self._applied += 1
|
||||||
|
headers['authorization'] = self._bearer_token + ' ' + str(self._refreshed)
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
|
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
|
||||||
@@ -52,7 +98,7 @@ class TestUserAgent(unittest.TestCase):
|
|||||||
|
|
||||||
http = set_user_agent(http, "my_app/5.5")
|
http = set_user_agent(http, "my_app/5.5")
|
||||||
resp, content = http.request("http://example.com")
|
resp, content = http.request("http://example.com")
|
||||||
self.assertEqual(content['user-agent'], 'my_app/5.5')
|
self.assertEqual('my_app/5.5', content['user-agent'])
|
||||||
|
|
||||||
def test_set_user_agent_nested(self):
|
def test_set_user_agent_nested(self):
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
@@ -62,25 +108,25 @@ class TestUserAgent(unittest.TestCase):
|
|||||||
http = set_user_agent(http, "my_app/5.5")
|
http = set_user_agent(http, "my_app/5.5")
|
||||||
http = set_user_agent(http, "my_library/0.1")
|
http = set_user_agent(http, "my_library/0.1")
|
||||||
resp, content = http.request("http://example.com")
|
resp, content = http.request("http://example.com")
|
||||||
self.assertEqual(content['user-agent'], 'my_app/5.5 my_library/0.1')
|
self.assertEqual('my_app/5.5 my_library/0.1', content['user-agent'])
|
||||||
|
|
||||||
def test_media_file_upload_to_from_json(self):
|
def test_media_file_upload_to_from_json(self):
|
||||||
upload = MediaFileUpload(
|
upload = MediaFileUpload(
|
||||||
datafile('small.png'), chunksize=500, resumable=True)
|
datafile('small.png'), chunksize=500, resumable=True)
|
||||||
self.assertEquals('image/png', upload.mimetype())
|
self.assertEqual('image/png', upload.mimetype())
|
||||||
self.assertEquals(190, upload.size())
|
self.assertEqual(190, upload.size())
|
||||||
self.assertEquals(True, upload.resumable())
|
self.assertEqual(True, upload.resumable())
|
||||||
self.assertEquals(500, upload.chunksize())
|
self.assertEqual(500, upload.chunksize())
|
||||||
self.assertEquals('PNG', upload.getbytes(1, 3))
|
self.assertEqual('PNG', upload.getbytes(1, 3))
|
||||||
|
|
||||||
json = upload.to_json()
|
json = upload.to_json()
|
||||||
new_upload = MediaUpload.new_from_json(json)
|
new_upload = MediaUpload.new_from_json(json)
|
||||||
|
|
||||||
self.assertEquals('image/png', new_upload.mimetype())
|
self.assertEqual('image/png', new_upload.mimetype())
|
||||||
self.assertEquals(190, new_upload.size())
|
self.assertEqual(190, new_upload.size())
|
||||||
self.assertEquals(True, new_upload.resumable())
|
self.assertEqual(True, new_upload.resumable())
|
||||||
self.assertEquals(500, new_upload.chunksize())
|
self.assertEqual(500, new_upload.chunksize())
|
||||||
self.assertEquals('PNG', new_upload.getbytes(1, 3))
|
self.assertEqual('PNG', new_upload.getbytes(1, 3))
|
||||||
|
|
||||||
def test_http_request_to_from_json(self):
|
def test_http_request_to_from_json(self):
|
||||||
|
|
||||||
@@ -103,13 +149,13 @@ class TestUserAgent(unittest.TestCase):
|
|||||||
json = req.to_json()
|
json = req.to_json()
|
||||||
new_req = HttpRequest.from_json(json, http, _postproc)
|
new_req = HttpRequest.from_json(json, http, _postproc)
|
||||||
|
|
||||||
self.assertEquals(new_req.headers,
|
self.assertEqual({'content-type':
|
||||||
{'content-type':
|
'multipart/related; boundary="---flubber"'},
|
||||||
'multipart/related; boundary="---flubber"'})
|
new_req.headers)
|
||||||
self.assertEquals(new_req.uri, 'http://example.com')
|
self.assertEqual('http://example.com', new_req.uri)
|
||||||
self.assertEquals(new_req.body, '{}')
|
self.assertEqual('{}', new_req.body)
|
||||||
self.assertEquals(new_req.http, http)
|
self.assertEqual(http, new_req.http)
|
||||||
self.assertEquals(new_req.resumable.to_json(), media_upload.to_json())
|
self.assertEqual(media_upload.to_json(), new_req.resumable.to_json())
|
||||||
|
|
||||||
EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
|
EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
@@ -153,6 +199,50 @@ ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
|
|||||||
--batch_foobarbaz--"""
|
--batch_foobarbaz--"""
|
||||||
|
|
||||||
|
|
||||||
|
BATCH_RESPONSE_WITH_401 = """--batch_foobarbaz
|
||||||
|
Content-Type: application/http
|
||||||
|
Content-Transfer-Encoding: binary
|
||||||
|
Content-ID: <randomness+1>
|
||||||
|
|
||||||
|
HTTP/1.1 401 Authoration Required
|
||||||
|
Content-Type application/json
|
||||||
|
Content-Length: 14
|
||||||
|
ETag: "etag/pony"\r\n\r\n{"error": {"message":
|
||||||
|
"Authorizaton failed."}}
|
||||||
|
|
||||||
|
--batch_foobarbaz
|
||||||
|
Content-Type: application/http
|
||||||
|
Content-Transfer-Encoding: binary
|
||||||
|
Content-ID: <randomness+2>
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
Content-Type application/json
|
||||||
|
Content-Length: 14
|
||||||
|
ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
|
||||||
|
--batch_foobarbaz--"""
|
||||||
|
|
||||||
|
|
||||||
|
BATCH_SINGLE_RESPONSE = """--batch_foobarbaz
|
||||||
|
Content-Type: application/http
|
||||||
|
Content-Transfer-Encoding: binary
|
||||||
|
Content-ID: <randomness+1>
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
Content-Type application/json
|
||||||
|
Content-Length: 14
|
||||||
|
ETag: "etag/pony"\r\n\r\n{"foo": 42}
|
||||||
|
--batch_foobarbaz--"""
|
||||||
|
|
||||||
|
class Callbacks(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.responses = {}
|
||||||
|
self.exceptions = {}
|
||||||
|
|
||||||
|
def f(self, request_id, response, exception):
|
||||||
|
self.responses[request_id] = response
|
||||||
|
self.exceptions[request_id] = exception
|
||||||
|
|
||||||
|
|
||||||
class TestBatch(unittest.TestCase):
|
class TestBatch(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -196,7 +286,7 @@ class TestBatch(unittest.TestCase):
|
|||||||
methodId=None,
|
methodId=None,
|
||||||
resumable=None)
|
resumable=None)
|
||||||
s = batch._serialize_request(request).splitlines()
|
s = batch._serialize_request(request).splitlines()
|
||||||
self.assertEquals(s, EXPECTED.splitlines())
|
self.assertEqual(EXPECTED.splitlines(), s)
|
||||||
|
|
||||||
def test_serialize_request_media_body(self):
|
def test_serialize_request_media_body(self):
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
@@ -213,9 +303,9 @@ class TestBatch(unittest.TestCase):
|
|||||||
headers={'content-type': 'application/json'},
|
headers={'content-type': 'application/json'},
|
||||||
methodId=None,
|
methodId=None,
|
||||||
resumable=None)
|
resumable=None)
|
||||||
|
# Just testing it shouldn't raise an exception.
|
||||||
s = batch._serialize_request(request).splitlines()
|
s = batch._serialize_request(request).splitlines()
|
||||||
|
|
||||||
|
|
||||||
def test_serialize_request_no_body(self):
|
def test_serialize_request_no_body(self):
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
request = HttpRequest(
|
request = HttpRequest(
|
||||||
@@ -228,30 +318,30 @@ class TestBatch(unittest.TestCase):
|
|||||||
methodId=None,
|
methodId=None,
|
||||||
resumable=None)
|
resumable=None)
|
||||||
s = batch._serialize_request(request).splitlines()
|
s = batch._serialize_request(request).splitlines()
|
||||||
self.assertEquals(s, NO_BODY_EXPECTED.splitlines())
|
self.assertEqual(NO_BODY_EXPECTED.splitlines(), s)
|
||||||
|
|
||||||
def test_deserialize_response(self):
|
def test_deserialize_response(self):
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
resp, content = batch._deserialize_response(RESPONSE)
|
resp, content = batch._deserialize_response(RESPONSE)
|
||||||
|
|
||||||
self.assertEquals(resp.status, 200)
|
self.assertEqual(200, resp.status)
|
||||||
self.assertEquals(resp.reason, 'OK')
|
self.assertEqual('OK', resp.reason)
|
||||||
self.assertEquals(resp.version, 11)
|
self.assertEqual(11, resp.version)
|
||||||
self.assertEquals(content, '{"answer": 42}')
|
self.assertEqual('{"answer": 42}', content)
|
||||||
|
|
||||||
def test_new_id(self):
|
def test_new_id(self):
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
|
|
||||||
id_ = batch._new_id()
|
id_ = batch._new_id()
|
||||||
self.assertEquals(id_, '1')
|
self.assertEqual('1', id_)
|
||||||
|
|
||||||
id_ = batch._new_id()
|
id_ = batch._new_id()
|
||||||
self.assertEquals(id_, '2')
|
self.assertEqual('2', id_)
|
||||||
|
|
||||||
batch.add(self.request1, request_id='3')
|
batch.add(self.request1, request_id='3')
|
||||||
|
|
||||||
id_ = batch._new_id()
|
id_ = batch._new_id()
|
||||||
self.assertEquals(id_, '4')
|
self.assertEqual('4', id_)
|
||||||
|
|
||||||
def test_add(self):
|
def test_add(self):
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
@@ -267,13 +357,6 @@ class TestBatch(unittest.TestCase):
|
|||||||
self.assertRaises(BatchError, batch.add, self.request1, request_id='1')
|
self.assertRaises(BatchError, batch.add, self.request1, request_id='1')
|
||||||
|
|
||||||
def test_execute(self):
|
def test_execute(self):
|
||||||
class Callbacks(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.responses = {}
|
|
||||||
|
|
||||||
def f(self, request_id, response):
|
|
||||||
self.responses[request_id] = response
|
|
||||||
|
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
callbacks = Callbacks()
|
callbacks = Callbacks()
|
||||||
|
|
||||||
@@ -285,8 +368,10 @@ class TestBatch(unittest.TestCase):
|
|||||||
BATCH_RESPONSE),
|
BATCH_RESPONSE),
|
||||||
])
|
])
|
||||||
batch.execute(http)
|
batch.execute(http)
|
||||||
self.assertEqual(callbacks.responses['1'], {'foo': 42})
|
self.assertEqual({'foo': 42}, callbacks.responses['1'])
|
||||||
self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
|
self.assertEqual(None, callbacks.exceptions['1'])
|
||||||
|
self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
|
||||||
|
self.assertEqual(None, callbacks.exceptions['2'])
|
||||||
|
|
||||||
def test_execute_request_body(self):
|
def test_execute_request_body(self):
|
||||||
batch = BatchHttpRequest()
|
batch = BatchHttpRequest()
|
||||||
@@ -311,14 +396,82 @@ class TestBatch(unittest.TestCase):
|
|||||||
header = parts[1].splitlines()[1]
|
header = parts[1].splitlines()[1]
|
||||||
self.assertEqual('Content-Type: application/http', header)
|
self.assertEqual('Content-Type: application/http', header)
|
||||||
|
|
||||||
|
def test_execute_refresh_and_retry_on_401(self):
|
||||||
|
batch = BatchHttpRequest()
|
||||||
|
callbacks = Callbacks()
|
||||||
|
cred_1 = MockCredentials('Foo')
|
||||||
|
cred_2 = MockCredentials('Bar')
|
||||||
|
|
||||||
|
http = HttpMockSequence([
|
||||||
|
({'status': '200',
|
||||||
|
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
|
||||||
|
BATCH_RESPONSE_WITH_401),
|
||||||
|
({'status': '200',
|
||||||
|
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
|
||||||
|
BATCH_SINGLE_RESPONSE),
|
||||||
|
])
|
||||||
|
|
||||||
|
creds_http_1 = HttpMockSequence([])
|
||||||
|
cred_1.authorize(creds_http_1)
|
||||||
|
|
||||||
|
creds_http_2 = HttpMockSequence([])
|
||||||
|
cred_2.authorize(creds_http_2)
|
||||||
|
|
||||||
|
self.request1.http = creds_http_1
|
||||||
|
self.request2.http = creds_http_2
|
||||||
|
|
||||||
|
batch.add(self.request1, callback=callbacks.f)
|
||||||
|
batch.add(self.request2, callback=callbacks.f)
|
||||||
|
batch.execute(http)
|
||||||
|
|
||||||
|
self.assertEqual({'foo': 42}, callbacks.responses['1'])
|
||||||
|
self.assertEqual(None, callbacks.exceptions['1'])
|
||||||
|
self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
|
||||||
|
self.assertEqual(None, callbacks.exceptions['2'])
|
||||||
|
|
||||||
|
self.assertEqual(1, cred_1._refreshed)
|
||||||
|
self.assertEqual(0, cred_2._refreshed)
|
||||||
|
|
||||||
|
self.assertEqual(1, cred_1._authorized)
|
||||||
|
self.assertEqual(1, cred_2._authorized)
|
||||||
|
|
||||||
|
self.assertEqual(1, cred_2._applied)
|
||||||
|
self.assertEqual(2, cred_1._applied)
|
||||||
|
|
||||||
|
def test_http_errors_passed_to_callback(self):
|
||||||
|
batch = BatchHttpRequest()
|
||||||
|
callbacks = Callbacks()
|
||||||
|
cred_1 = MockCredentials('Foo')
|
||||||
|
cred_2 = MockCredentials('Bar')
|
||||||
|
|
||||||
|
http = HttpMockSequence([
|
||||||
|
({'status': '200',
|
||||||
|
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
|
||||||
|
BATCH_RESPONSE_WITH_401),
|
||||||
|
({'status': '200',
|
||||||
|
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
|
||||||
|
BATCH_RESPONSE_WITH_401),
|
||||||
|
])
|
||||||
|
|
||||||
|
creds_http_1 = HttpMockSequence([])
|
||||||
|
cred_1.authorize(creds_http_1)
|
||||||
|
|
||||||
|
creds_http_2 = HttpMockSequence([])
|
||||||
|
cred_2.authorize(creds_http_2)
|
||||||
|
|
||||||
|
self.request1.http = creds_http_1
|
||||||
|
self.request2.http = creds_http_2
|
||||||
|
|
||||||
|
batch.add(self.request1, callback=callbacks.f)
|
||||||
|
batch.add(self.request2, callback=callbacks.f)
|
||||||
|
batch.execute(http)
|
||||||
|
|
||||||
|
self.assertEqual(None, callbacks.responses['1'])
|
||||||
|
self.assertEqual(401, callbacks.exceptions['1'].resp.status)
|
||||||
|
self.assertEqual({u'baz': u'qux'}, callbacks.responses['2'])
|
||||||
|
self.assertEqual(None, callbacks.exceptions['2'])
|
||||||
|
|
||||||
def test_execute_global_callback(self):
|
def test_execute_global_callback(self):
|
||||||
class Callbacks(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.responses = {}
|
|
||||||
|
|
||||||
def f(self, request_id, response):
|
|
||||||
self.responses[request_id] = response
|
|
||||||
|
|
||||||
callbacks = Callbacks()
|
callbacks = Callbacks()
|
||||||
batch = BatchHttpRequest(callback=callbacks.f)
|
batch = BatchHttpRequest(callback=callbacks.f)
|
||||||
|
|
||||||
@@ -330,8 +483,8 @@ class TestBatch(unittest.TestCase):
|
|||||||
BATCH_RESPONSE),
|
BATCH_RESPONSE),
|
||||||
])
|
])
|
||||||
batch.execute(http)
|
batch.execute(http)
|
||||||
self.assertEqual(callbacks.responses['1'], {'foo': 42})
|
self.assertEqual({'foo': 42}, callbacks.responses['1'])
|
||||||
self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
|
self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
|
||||||
|
|
||||||
def test_media_inmemory_upload(self):
|
def test_media_inmemory_upload(self):
|
||||||
media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10,
|
media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10,
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class OAuth2CredentialsTests(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
http = self.credentials.authorize(http)
|
http = self.credentials.authorize(http)
|
||||||
resp, content = http.request("http://example.com")
|
resp, content = http.request("http://example.com")
|
||||||
self.assertEqual(content['authorization'], 'OAuth 1/3w')
|
self.assertEqual('Bearer 1/3w', content['Authorization'])
|
||||||
|
|
||||||
def test_token_refresh_failure(self):
|
def test_token_refresh_failure(self):
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
@@ -95,11 +95,11 @@ class OAuth2CredentialsTests(unittest.TestCase):
|
|||||||
def test_to_from_json(self):
|
def test_to_from_json(self):
|
||||||
json = self.credentials.to_json()
|
json = self.credentials.to_json()
|
||||||
instance = OAuth2Credentials.from_json(json)
|
instance = OAuth2Credentials.from_json(json)
|
||||||
self.assertEquals(type(instance), OAuth2Credentials)
|
self.assertEqual(OAuth2Credentials, type(instance))
|
||||||
instance.token_expiry = None
|
instance.token_expiry = None
|
||||||
self.credentials.token_expiry = None
|
self.credentials.token_expiry = None
|
||||||
|
|
||||||
self.assertEquals(self.credentials.__dict__, instance.__dict__)
|
self.assertEqual(instance.__dict__, self.credentials.__dict__)
|
||||||
|
|
||||||
|
|
||||||
class AccessTokenCredentialsTests(unittest.TestCase):
|
class AccessTokenCredentialsTests(unittest.TestCase):
|
||||||
@@ -136,7 +136,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
http = self.credentials.authorize(http)
|
http = self.credentials.authorize(http)
|
||||||
resp, content = http.request('http://example.com')
|
resp, content = http.request('http://example.com')
|
||||||
self.assertEqual(content['authorization'], 'OAuth foo')
|
self.assertEqual('Bearer foo', content['Authorization'])
|
||||||
|
|
||||||
|
|
||||||
class TestAssertionCredentials(unittest.TestCase):
|
class TestAssertionCredentials(unittest.TestCase):
|
||||||
@@ -155,8 +155,8 @@ class TestAssertionCredentials(unittest.TestCase):
|
|||||||
|
|
||||||
def test_assertion_body(self):
|
def test_assertion_body(self):
|
||||||
body = urlparse.parse_qs(self.credentials._generate_refresh_request_body())
|
body = urlparse.parse_qs(self.credentials._generate_refresh_request_body())
|
||||||
self.assertEqual(body['assertion'][0], self.assertion_text)
|
self.assertEqual(self.assertion_text, body['assertion'][0])
|
||||||
self.assertEqual(body['assertion_type'][0], self.assertion_type)
|
self.assertEqual(self.assertion_type, body['assertion_type'][0])
|
||||||
|
|
||||||
def test_assertion_refresh(self):
|
def test_assertion_refresh(self):
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
@@ -165,7 +165,7 @@ class TestAssertionCredentials(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
http = self.credentials.authorize(http)
|
http = self.credentials.authorize(http)
|
||||||
resp, content = http.request("http://example.com")
|
resp, content = http.request("http://example.com")
|
||||||
self.assertEqual(content['authorization'], 'OAuth 1/3w')
|
self.assertEqual('Bearer 1/3w', content['Authorization'])
|
||||||
|
|
||||||
|
|
||||||
class ExtractIdTokenText(unittest.TestCase):
|
class ExtractIdTokenText(unittest.TestCase):
|
||||||
@@ -177,7 +177,7 @@ class ExtractIdTokenText(unittest.TestCase):
|
|||||||
jwt = 'stuff.' + payload + '.signature'
|
jwt = 'stuff.' + payload + '.signature'
|
||||||
|
|
||||||
extracted = _extract_id_token(jwt)
|
extracted = _extract_id_token(jwt)
|
||||||
self.assertEqual(body, extracted)
|
self.assertEqual(extracted, body)
|
||||||
|
|
||||||
def test_extract_failure(self):
|
def test_extract_failure(self):
|
||||||
body = {'foo': 'bar'}
|
body = {'foo': 'bar'}
|
||||||
@@ -201,11 +201,11 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||||||
|
|
||||||
parsed = urlparse.urlparse(authorize_url)
|
parsed = urlparse.urlparse(authorize_url)
|
||||||
q = parse_qs(parsed[4])
|
q = parse_qs(parsed[4])
|
||||||
self.assertEqual(q['client_id'][0], 'client_id+1')
|
self.assertEqual('client_id+1', q['client_id'][0])
|
||||||
self.assertEqual(q['response_type'][0], 'code')
|
self.assertEqual('code', q['response_type'][0])
|
||||||
self.assertEqual(q['scope'][0], 'foo')
|
self.assertEqual('foo', q['scope'][0])
|
||||||
self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN')
|
self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
|
||||||
self.assertEqual(q['access_type'][0], 'offline')
|
self.assertEqual('offline', q['access_type'][0])
|
||||||
|
|
||||||
def test_override_flow_access_type(self):
|
def test_override_flow_access_type(self):
|
||||||
"""Passing access_type overrides the default."""
|
"""Passing access_type overrides the default."""
|
||||||
@@ -220,11 +220,11 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||||||
|
|
||||||
parsed = urlparse.urlparse(authorize_url)
|
parsed = urlparse.urlparse(authorize_url)
|
||||||
q = parse_qs(parsed[4])
|
q = parse_qs(parsed[4])
|
||||||
self.assertEqual(q['client_id'][0], 'client_id+1')
|
self.assertEqual('client_id+1', q['client_id'][0])
|
||||||
self.assertEqual(q['response_type'][0], 'code')
|
self.assertEqual('code', q['response_type'][0])
|
||||||
self.assertEqual(q['scope'][0], 'foo')
|
self.assertEqual('foo', q['scope'][0])
|
||||||
self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN')
|
self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
|
||||||
self.assertEqual(q['access_type'][0], 'online')
|
self.assertEqual('online', q['access_type'][0])
|
||||||
|
|
||||||
def test_exchange_failure(self):
|
def test_exchange_failure(self):
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
@@ -246,9 +246,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
credentials = self.flow.step2_exchange('some random code', http)
|
credentials = self.flow.step2_exchange('some random code', http)
|
||||||
self.assertEqual(credentials.access_token, 'SlAV32hkKG')
|
self.assertEqual('SlAV32hkKG', credentials.access_token)
|
||||||
self.assertNotEqual(credentials.token_expiry, None)
|
self.assertNotEqual(None, credentials.token_expiry)
|
||||||
self.assertEqual(credentials.refresh_token, '8xLOxBtZp8')
|
self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
|
||||||
|
|
||||||
def test_exchange_no_expires_in(self):
|
def test_exchange_no_expires_in(self):
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
@@ -257,7 +257,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
credentials = self.flow.step2_exchange('some random code', http)
|
credentials = self.flow.step2_exchange('some random code', http)
|
||||||
self.assertEqual(credentials.token_expiry, None)
|
self.assertEqual(None, credentials.token_expiry)
|
||||||
|
|
||||||
def test_exchange_id_token_fail(self):
|
def test_exchange_id_token_fail(self):
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
@@ -282,7 +282,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
credentials = self.flow.step2_exchange('some random code', http)
|
credentials = self.flow.step2_exchange('some random code', http)
|
||||||
self.assertEquals(body, credentials.id_token)
|
self.assertEqual(credentials.id_token, body)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ class CryptTests(unittest.TestCase):
|
|||||||
certs = {'foo': public_key }
|
certs = {'foo': public_key }
|
||||||
audience = 'some_audience_address@testing.gserviceaccount.com'
|
audience = 'some_audience_address@testing.gserviceaccount.com'
|
||||||
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
|
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
|
||||||
self.assertEquals('billy bob', contents['user'])
|
self.assertEqual('billy bob', contents['user'])
|
||||||
self.assertEquals('data', contents['metadata']['meta'])
|
self.assertEqual('data', contents['metadata']['meta'])
|
||||||
|
|
||||||
def test_verify_id_token_with_certs_uri(self):
|
def test_verify_id_token_with_certs_uri(self):
|
||||||
jwt = self._create_signed_jwt()
|
jwt = self._create_signed_jwt()
|
||||||
@@ -112,8 +112,8 @@ class CryptTests(unittest.TestCase):
|
|||||||
|
|
||||||
contents = verify_id_token(jwt,
|
contents = verify_id_token(jwt,
|
||||||
'some_audience_address@testing.gserviceaccount.com', http)
|
'some_audience_address@testing.gserviceaccount.com', http)
|
||||||
self.assertEquals('billy bob', contents['user'])
|
self.assertEqual('billy bob', contents['user'])
|
||||||
self.assertEquals('data', contents['metadata']['meta'])
|
self.assertEqual('data', contents['metadata']['meta'])
|
||||||
|
|
||||||
def test_verify_id_token_with_certs_uri_fails(self):
|
def test_verify_id_token_with_certs_uri_fails(self):
|
||||||
jwt = self._create_signed_jwt()
|
jwt = self._create_signed_jwt()
|
||||||
@@ -195,7 +195,7 @@ class CryptTests(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
http = credentials.authorize(http)
|
http = credentials.authorize(http)
|
||||||
resp, content = http.request('http://example.org')
|
resp, content = http.request('http://example.org')
|
||||||
self.assertEquals(content['authorization'], 'OAuth 1/3w')
|
self.assertEqual('Bearer 1/3w', content['Authorization'])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user