Flows no longer need to be saved between uses.

Also introduces util.positional declarations.

Reviewed in http://codereview.appspot.com/6441056/.

Fixes issue #136.
This commit is contained in:
Joe Gregorio
2012-08-03 16:17:40 -04:00
parent ba5c790bd0
commit 68a8cfe26d
35 changed files with 569 additions and 624 deletions

View File

@@ -57,6 +57,7 @@ from apiclient.model import RawModel
from apiclient.schema import Schemas from apiclient.schema import Schemas
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.nonmultipart import MIMENonMultipart from email.mime.nonmultipart import MIMENonMultipart
from oauth2client import util
from oauth2client.anyjson import simplejson from oauth2client.anyjson import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -139,6 +140,7 @@ def key2param(key):
return ''.join(result) return ''.join(result)
@util.positional(2)
def build(serviceName, def build(serviceName,
version, version,
http=None, http=None,
@@ -194,7 +196,7 @@ def build(serviceName,
raise UnknownApiNameOrVersion("name: %s version: %s" % (serviceName, raise UnknownApiNameOrVersion("name: %s version: %s" % (serviceName,
version)) version))
if resp.status >= 400: if resp.status >= 400:
raise HttpError(resp, content, requested_url) raise HttpError(resp, content, uri=requested_url)
try: try:
service = simplejson.loads(content) service = simplejson.loads(content)
@@ -202,10 +204,11 @@ def build(serviceName,
logger.error('Failed to parse as JSON: ' + content) logger.error('Failed to parse as JSON: ' + content)
raise InvalidJsonError() raise InvalidJsonError()
return build_from_document(content, discoveryServiceUrl, http=http, return build_from_document(content, base=discoveryServiceUrl, http=http,
developerKey=developerKey, model=model, requestBuilder=requestBuilder) developerKey=developerKey, model=model, requestBuilder=requestBuilder)
@util.positional(1)
def build_from_document( def build_from_document(
service, service,
base=None, base=None,
@@ -529,7 +532,8 @@ def _createResource(http, baseUrl, model, requestBuilder,
raise UnknownFileType(media_filename) raise UnknownFileType(media_filename)
if not mimeparse.best_match([media_mime_type], ','.join(accept)): if not mimeparse.best_match([media_mime_type], ','.join(accept)):
raise UnacceptableMimeTypeError(media_mime_type) raise UnacceptableMimeTypeError(media_mime_type)
media_upload = MediaFileUpload(media_filename, media_mime_type) media_upload = MediaFileUpload(media_filename,
mimetype=media_mime_type)
elif isinstance(media_filename, MediaUpload): elif isinstance(media_filename, MediaUpload):
media_upload = media_filename media_upload = media_filename
else: else:

View File

@@ -23,6 +23,7 @@ should be defined in this file.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
from oauth2client import util
from oauth2client.anyjson import simplejson from oauth2client.anyjson import simplejson
@@ -34,6 +35,7 @@ class Error(Exception):
class HttpError(Error): class HttpError(Error):
"""HTTP data was invalid or unexpected.""" """HTTP data was invalid or unexpected."""
@util.positional(3)
def __init__(self, resp, content, uri=None): def __init__(self, resp, content, uri=None):
self.resp = resp self.resp = resp
self.content = content self.content = content
@@ -92,6 +94,7 @@ class ResumableUploadError(Error):
class BatchError(HttpError): class BatchError(HttpError):
"""Error occured during batch operations.""" """Error occured during batch operations."""
@util.positional(2)
def __init__(self, reason, resp=None, content=None): def __init__(self, reason, resp=None, content=None):
self.resp = resp self.resp = resp
self.content = content self.content = content
@@ -106,6 +109,7 @@ class BatchError(HttpError):
class UnexpectedMethodError(Error): class UnexpectedMethodError(Error):
"""Exception raised by RequestMockBuilder on unexpected calls.""" """Exception raised by RequestMockBuilder on unexpected calls."""
@util.positional(1)
def __init__(self, methodId=None): def __init__(self, methodId=None):
"""Constructor for an UnexpectedMethodError.""" """Constructor for an UnexpectedMethodError."""
super(UnexpectedMethodError, self).__init__( super(UnexpectedMethodError, self).__init__(

View File

@@ -43,6 +43,7 @@ from errors import ResumableUploadError
from errors import UnexpectedBodyError from errors import UnexpectedBodyError
from errors import UnexpectedMethodError from errors import UnexpectedMethodError
from model import JsonModel from model import JsonModel
from oauth2client import util
from oauth2client.anyjson import simplejson from oauth2client.anyjson import simplejson
@@ -162,6 +163,7 @@ class MediaUpload(object):
""" """
raise NotImplementedError() raise NotImplementedError()
@util.positional(1)
def _to_json(self, strip=None): def _to_json(self, strip=None):
"""Utility function for creating a JSON representation of a MediaUpload. """Utility function for creating a JSON representation of a MediaUpload.
@@ -226,6 +228,7 @@ class MediaFileUpload(MediaUpload):
media_body=media).execute() media_body=media).execute()
""" """
@util.positional(2)
def __init__(self, filename, mimetype=None, chunksize=DEFAULT_CHUNK_SIZE, resumable=False): def __init__(self, filename, mimetype=None, chunksize=DEFAULT_CHUNK_SIZE, resumable=False):
"""Constructor. """Constructor.
@@ -302,13 +305,13 @@ class MediaFileUpload(MediaUpload):
string, a JSON representation of this instance, suitable to pass to string, a JSON representation of this instance, suitable to pass to
from_json(). from_json().
""" """
return self._to_json(['_fd']) return self._to_json(strip=['_fd'])
@staticmethod @staticmethod
def from_json(s): def from_json(s):
d = simplejson.loads(s) d = simplejson.loads(s)
return MediaFileUpload( return MediaFileUpload(d['_filename'], mimetype=d['_mimetype'],
d['_filename'], d['_mimetype'], d['_chunksize'], d['_resumable']) chunksize=d['_chunksize'], resumable=d['_resumable'])
class MediaIoBaseUpload(MediaUpload): class MediaIoBaseUpload(MediaUpload):
@@ -326,6 +329,7 @@ class MediaIoBaseUpload(MediaUpload):
media_body=media).execute() media_body=media).execute()
""" """
@util.positional(3)
def __init__(self, fd, mimetype, chunksize=DEFAULT_CHUNK_SIZE, def __init__(self, fd, mimetype, chunksize=DEFAULT_CHUNK_SIZE,
resumable=False): resumable=False):
"""Constructor. """Constructor.
@@ -414,6 +418,7 @@ class MediaInMemoryUpload(MediaUpload):
method. method.
""" """
@util.positional(2)
def __init__(self, body, mimetype='application/octet-stream', def __init__(self, body, mimetype='application/octet-stream',
chunksize=DEFAULT_CHUNK_SIZE, resumable=False): chunksize=DEFAULT_CHUNK_SIZE, resumable=False):
"""Create a new MediaBytesUpload. """Create a new MediaBytesUpload.
@@ -496,8 +501,9 @@ class MediaInMemoryUpload(MediaUpload):
def from_json(s): def from_json(s):
d = simplejson.loads(s) d = simplejson.loads(s)
return MediaInMemoryUpload(base64.b64decode(d['_b64body']), return MediaInMemoryUpload(base64.b64decode(d['_b64body']),
d['_mimetype'], d['_chunksize'], mimetype=d['_mimetype'],
d['_resumable']) chunksize=d['_chunksize'],
resumable=d['_resumable'])
class MediaIoBaseDownload(object): class MediaIoBaseDownload(object):
@@ -520,6 +526,7 @@ class MediaIoBaseDownload(object):
print "Download Complete!" print "Download Complete!"
""" """
@util.positional(3)
def __init__(self, fd, request, chunksize=DEFAULT_CHUNK_SIZE): def __init__(self, fd, request, chunksize=DEFAULT_CHUNK_SIZE):
"""Constructor. """Constructor.
@@ -574,12 +581,13 @@ class MediaIoBaseDownload(object):
self._done = True self._done = True
return MediaDownloadProgress(self._progress, self._total_size), self._done return MediaDownloadProgress(self._progress, self._total_size), self._done
else: else:
raise HttpError(resp, content, self._uri) raise HttpError(resp, content, uri=self._uri)
class HttpRequest(object): class HttpRequest(object):
"""Encapsulates a single HTTP request.""" """Encapsulates a single HTTP request."""
@util.positional(4)
def __init__(self, http, postproc, uri, def __init__(self, http, postproc, uri,
method='GET', method='GET',
body=None, body=None,
@@ -623,6 +631,7 @@ class HttpRequest(object):
# The bytes that have been uploaded. # The bytes that have been uploaded.
self.resumable_progress = 0 self.resumable_progress = 0
@util.positional(1)
def execute(self, http=None): def execute(self, http=None):
"""Execute the request. """Execute the request.
@@ -643,7 +652,7 @@ class HttpRequest(object):
if self.resumable: if self.resumable:
body = None body = None
while body is None: while body is None:
_, body = self.next_chunk(http) _, body = self.next_chunk(http=http)
return body return body
else: else:
if 'content-length' not in self.headers: if 'content-length' not in self.headers:
@@ -661,13 +670,14 @@ class HttpRequest(object):
self.body = parsed.query self.body = parsed.query
self.headers['content-length'] = str(len(self.body)) self.headers['content-length'] = str(len(self.body))
resp, content = http.request(self.uri, self.method, resp, content = http.request(self.uri, method=self.method,
body=self.body, body=self.body,
headers=self.headers) headers=self.headers)
if resp.status >= 300: if resp.status >= 300:
raise HttpError(resp, content, self.uri) raise HttpError(resp, content, uri=self.uri)
return self.postproc(resp, content) return self.postproc(resp, content)
@util.positional(1)
def next_chunk(self, http=None): def next_chunk(self, http=None):
"""Execute the next step of a resumable upload. """Execute the next step of a resumable upload.
@@ -782,7 +792,7 @@ class HttpRequest(object):
self.resumable_uri = resp['location'] self.resumable_uri = resp['location']
else: else:
self._in_error_state = True self._in_error_state = True
raise HttpError(resp, content, self.uri) raise HttpError(resp, content, uri=self.uri)
return (MediaUploadProgress(self.resumable_progress, self.resumable.size()), return (MediaUploadProgress(self.resumable_progress, self.resumable.size()),
None) None)
@@ -844,9 +854,10 @@ class BatchHttpRequest(object):
batch.add(service.animals().list(), list_animals) batch.add(service.animals().list(), list_animals)
batch.add(service.farmers().list(), list_farmers) batch.add(service.farmers().list(), list_farmers)
batch.execute(http) batch.execute(http=http)
""" """
@util.positional(1)
def __init__(self, callback=None, batch_uri=None): def __init__(self, callback=None, batch_uri=None):
"""Constructor for a BatchHttpRequest. """Constructor for a BatchHttpRequest.
@@ -1042,6 +1053,7 @@ class BatchHttpRequest(object):
self._last_auto_id += 1 self._last_auto_id += 1
return str(self._last_auto_id) return str(self._last_auto_id)
@util.positional(2)
def add(self, request, callback=None, request_id=None): def add(self, request, callback=None, request_id=None):
"""Add a new request. """Add a new request.
@@ -1119,7 +1131,7 @@ class BatchHttpRequest(object):
headers=headers) headers=headers)
if resp.status >= 300: if resp.status >= 300:
raise HttpError(resp, content, self._batch_uri) raise HttpError(resp, content, uri=self._batch_uri)
# Now break out the individual responses and store each one. # Now break out the individual responses and store each one.
boundary, _ = content.split(None, 1) boundary, _ = content.split(None, 1)
@@ -1133,14 +1145,15 @@ class BatchHttpRequest(object):
mime_response = parser.close() mime_response = parser.close()
if not mime_response.is_multipart(): if not mime_response.is_multipart():
raise BatchError("Response not in multipart/mixed format.", resp, raise BatchError("Response not in multipart/mixed format.", resp=resp,
content) content=content)
for part in mime_response.get_payload(): for part in mime_response.get_payload():
request_id = self._header_to_id(part['Content-ID']) request_id = self._header_to_id(part['Content-ID'])
response, content = self._deserialize_response(part.get_payload()) response, content = self._deserialize_response(part.get_payload())
self._responses[request_id] = (response, content) self._responses[request_id] = (response, content)
@util.positional(1)
def execute(self, http=None): def execute(self, http=None):
"""Execute all the requests as a single batched HTTP request. """Execute all the requests as a single batched HTTP request.
@@ -1200,7 +1213,7 @@ class BatchHttpRequest(object):
exception = None exception = None
try: try:
if resp.status >= 300: if resp.status >= 300:
raise HttpError(resp, content, request.uri) raise HttpError(resp, content, uri=request.uri)
response = request.postproc(resp, content) response = request.postproc(resp, content)
except HttpError, e: except HttpError, e:
exception = e exception = e
@@ -1310,7 +1323,7 @@ class RequestMockBuilder(object):
raise UnexpectedBodyError(expected_body, body) raise UnexpectedBodyError(expected_body, body)
return HttpRequestMock(resp, content, postproc) return HttpRequestMock(resp, content, postproc)
elif self.check_unexpected: elif self.check_unexpected:
raise UnexpectedMethodError(methodId) raise UnexpectedMethodError(methodId=methodId)
else: else:
model = JsonModel(False) model = JsonModel(False)
return HttpRequestMock(None, '{}', model.response) return HttpRequestMock(None, '{}', model.response)

View File

@@ -62,6 +62,8 @@ The constructor takes a discovery document in which to look up named schema.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import copy import copy
from oauth2client import util
from oauth2client.anyjson import simplejson from oauth2client.anyjson import simplejson
@@ -80,6 +82,7 @@ class Schemas(object):
# Cache of pretty printed schemas. # Cache of pretty printed schemas.
self.pretty = {} self.pretty = {}
@util.positional(2)
def _prettyPrintByName(self, name, seen=None, dent=0): def _prettyPrintByName(self, name, seen=None, dent=0):
"""Get pretty printed object prototype from the schema name. """Get pretty printed object prototype from the schema name.
@@ -102,7 +105,7 @@ class Schemas(object):
if name not in self.pretty: if name not in self.pretty:
self.pretty[name] = _SchemaToStruct(self.schemas[name], self.pretty[name] = _SchemaToStruct(self.schemas[name],
seen, dent).to_str(self._prettyPrintByName) seen, dent=dent).to_str(self._prettyPrintByName)
seen.pop() seen.pop()
@@ -121,6 +124,7 @@ class Schemas(object):
# Return with trailing comma and newline removed. # Return with trailing comma and newline removed.
return self._prettyPrintByName(name, seen=[], dent=1)[:-2] return self._prettyPrintByName(name, seen=[], dent=1)[:-2]
@util.positional(2)
def _prettyPrintSchema(self, schema, seen=None, dent=0): def _prettyPrintSchema(self, schema, seen=None, dent=0):
"""Get pretty printed object prototype of schema. """Get pretty printed object prototype of schema.
@@ -136,7 +140,7 @@ class Schemas(object):
if seen is None: if seen is None:
seen = [] seen = []
return _SchemaToStruct(schema, seen, dent).to_str(self._prettyPrintByName) return _SchemaToStruct(schema, seen, dent=dent).to_str(self._prettyPrintByName)
def prettyPrintSchema(self, schema): def prettyPrintSchema(self, schema):
"""Get pretty printed object prototype of schema. """Get pretty printed object prototype of schema.
@@ -163,6 +167,7 @@ class Schemas(object):
class _SchemaToStruct(object): class _SchemaToStruct(object):
"""Convert schema to a prototype object.""" """Convert schema to a prototype object."""
@util.positional(3)
def __init__(self, schema, seen, dent=0): def __init__(self, schema, seen, dent=0):
"""Constructor. """Constructor.
@@ -256,7 +261,7 @@ class _SchemaToStruct(object):
elif '$ref' in schema: elif '$ref' in schema:
schemaName = schema['$ref'] schemaName = schema['$ref']
description = schema.get('description', '') description = schema.get('description', '')
s = self.from_cache(schemaName, self.seen) s = self.from_cache(schemaName, seen=self.seen)
parts = s.splitlines() parts = s.splitlines()
self.emitEnd(parts[0], description) self.emitEnd(parts[0], description)
for line in parts[1:]: for line in parts[1:]:

View File

@@ -27,21 +27,20 @@ import time
import clientsecrets import clientsecrets
from anyjson import simplejson
from client import AccessTokenRefreshError
from client import AssertionCredentials
from client import Credentials
from client import Flow
from client import OAuth2WebServerFlow
from client import Storage
from google.appengine.api import memcache
from google.appengine.api import users
from google.appengine.api import app_identity from google.appengine.api import app_identity
from google.appengine.api import users
from google.appengine.ext import db from google.appengine.ext import db
from google.appengine.ext import webapp from google.appengine.ext import webapp
from google.appengine.ext.webapp.util import login_required from google.appengine.ext.webapp.util import login_required
from google.appengine.ext.webapp.util import run_wsgi_app from google.appengine.ext.webapp.util import run_wsgi_app
from oauth2client import util
from oauth2client.anyjson import simplejson
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import AssertionCredentials
from oauth2client.client import Credentials
from oauth2client.client import Flow
from oauth2client.client import OAuth2WebServerFlow
from oauth2client.client import Storage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -66,6 +65,7 @@ class AppAssertionCredentials(AssertionCredentials):
generate and refresh its own access tokens. generate and refresh its own access tokens.
""" """
@util.positional(2)
def __init__(self, scope, **kwargs): def __init__(self, scope, **kwargs):
"""Constructor for AppAssertionCredentials """Constructor for AppAssertionCredentials
@@ -77,9 +77,8 @@ class AppAssertionCredentials(AssertionCredentials):
self.scope = scope self.scope = scope
super(AppAssertionCredentials, self).__init__( super(AppAssertionCredentials, self).__init__(
None, 'ignored' # assertion_type is ignore in this subclass.
None, )
None)
@classmethod @classmethod
def from_json(cls, json): def from_json(cls, json):
@@ -195,6 +194,7 @@ class StorageByKeyName(Storage):
are stored by key_name. are stored by key_name.
""" """
@util.positional(4)
def __init__(self, model, key_name, property_name, cache=None): def __init__(self, model, key_name, property_name, cache=None):
"""Constructor for Storage. """Constructor for Storage.
@@ -286,11 +286,14 @@ class OAuth2Decorator(object):
""" """
@util.positional(4)
def __init__(self, client_id, client_secret, scope, def __init__(self, client_id, client_secret, scope,
auth_uri='https://accounts.google.com/o/oauth2/auth', auth_uri='https://accounts.google.com/o/oauth2/auth',
token_uri='https://accounts.google.com/o/oauth2/token', token_uri='https://accounts.google.com/o/oauth2/token',
user_agent=None, user_agent=None,
message=None, **kwargs): message=None,
callback_path='/oauth2callback',
**kwargs):
"""Constructor for OAuth2Decorator """Constructor for OAuth2Decorator
@@ -307,15 +310,24 @@ class OAuth2Decorator(object):
message: Message to display if there are problems with the OAuth 2.0 message: Message to display if there are problems with the OAuth 2.0
configuration. The message may contain HTML and will be presented on the configuration. The message may contain HTML and will be presented on the
web interface for any method that uses the decorator. web interface for any method that uses the decorator.
callback_path: string, The absolute path to use as the callback URI. Note
that this must match up with the URI given when registering the
application in the APIs Console.
**kwargs: dict, Keyword arguments are be passed along as kwargs to the **kwargs: dict, Keyword arguments are be passed along as kwargs to the
OAuth2WebServerFlow constructor. OAuth2WebServerFlow constructor.
""" """
self.flow = OAuth2WebServerFlow(client_id, client_secret, scope, user_agent, self.flow = None
auth_uri, token_uri, **kwargs)
self.credentials = None self.credentials = None
self._request_handler = None self._client_id = client_id
self._client_secret = client_secret
self._scope = scope
self._auth_uri = auth_uri
self._token_uri = token_uri
self._user_agent = user_agent
self._kwargs = kwargs
self._message = message self._message = message
self._in_error = False self._in_error = False
self._callback_path = callback_path
def _display_error_message(self, request_handler): def _display_error_message(self, request_handler):
request_handler.response.out.write('<html><body>') request_handler.response.out.write('<html><body>')
@@ -344,9 +356,11 @@ class OAuth2Decorator(object):
request_handler.redirect(users.create_login_url( request_handler.redirect(users.create_login_url(
request_handler.request.uri)) request_handler.request.uri))
return return
self._create_flow(request_handler)
# Store the request URI in 'state' so we can use it later # Store the request URI in 'state' so we can use it later
self.flow.params['state'] = request_handler.request.url self.flow.params['state'] = request_handler.request.url
self._request_handler = request_handler
self.credentials = StorageByKeyName( self.credentials = StorageByKeyName(
CredentialsModel, user.user_id(), 'credentials').get() CredentialsModel, user.user_id(), 'credentials').get()
@@ -359,6 +373,26 @@ class OAuth2Decorator(object):
return check_oauth return check_oauth
def _create_flow(self, request_handler):
"""Create the Flow object.
The Flow is calculated lazily since we don't know where this app is
running until it receives a request, at which point redirect_uri can be
calculated and then the Flow object can be constructed.
Args:
request_handler: webapp.RequestHandler, the request handler.
"""
if self.flow is None:
redirect_uri = request_handler.request.relative_url(
self._callback_path) # Usually /oauth2callback
self.flow = OAuth2WebServerFlow(self._client_id, self._client_secret,
self._scope, redirect_uri=redirect_uri,
user_agent=self._user_agent,
auth_uri=self._auth_uri,
token_uri=self._token_uri, **self._kwargs)
def oauth_aware(self, method): def oauth_aware(self, method):
"""Decorator that sets up for OAuth 2.0 dance, but doesn't do it. """Decorator that sets up for OAuth 2.0 dance, but doesn't do it.
@@ -385,9 +419,9 @@ class OAuth2Decorator(object):
request_handler.request.uri)) request_handler.request.uri))
return return
self._create_flow(request_handler)
self.flow.params['state'] = request_handler.request.url self.flow.params['state'] = request_handler.request.url
self._request_handler = request_handler
self.credentials = StorageByKeyName( self.credentials = StorageByKeyName(
CredentialsModel, user.user_id(), 'credentials').get() CredentialsModel, user.user_id(), 'credentials').get()
method(request_handler, *args, **kwargs) method(request_handler, *args, **kwargs)
@@ -407,11 +441,7 @@ class OAuth2Decorator(object):
Must only be called from with a webapp.RequestHandler subclassed method Must only be called from with a webapp.RequestHandler subclassed method
that had been decorated with either @oauth_required or @oauth_aware. that had been decorated with either @oauth_required or @oauth_aware.
""" """
callback = self._request_handler.request.relative_url('/oauth2callback') url = self.flow.step1_get_authorize_url()
url = self.flow.step1_get_authorize_url(callback)
user = users.get_current_user()
memcache.set(user.user_id(), pickle.dumps(self.flow),
namespace=OAUTH2CLIENT_NAMESPACE)
return str(url) return str(url)
def http(self): def http(self):
@@ -423,6 +453,70 @@ class OAuth2Decorator(object):
""" """
return self.credentials.authorize(httplib2.Http()) return self.credentials.authorize(httplib2.Http())
@property
def callback_path(self):
"""The absolute path where the callback will occur.
Note this is the absolute path, not the absolute URI, that will be
calculated by the decorator at runtime. See callback_handler() for how this
should be used.
Returns:
The callback path as a string.
"""
return self._callback_path
def callback_handler(self):
"""RequestHandler for the OAuth 2.0 redirect callback.
Usage:
app = webapp.WSGIApplication([
('/index', MyIndexHandler),
...,
(decorator.callback_path, decorator.callback_handler())
])
Returns:
A webapp.RequestHandler that handles the redirect back from the
server during the OAuth 2.0 dance.
"""
decorator = self
class OAuth2Handler(webapp.RequestHandler):
"""Handler for the redirect_uri of the OAuth 2.0 dance."""
@login_required
def get(self):
error = self.request.get('error')
if error:
errormsg = self.request.get('error_description', error)
self.response.out.write(
'The authorization request failed: %s' % errormsg)
else:
user = users.get_current_user()
decorator._create_flow(self)
credentials = decorator.flow.step2_exchange(self.request.params)
StorageByKeyName(
CredentialsModel, user.user_id(), 'credentials').put(credentials)
self.redirect(str(self.request.get('state')))
return OAuth2Handler
def callback_application(self):
"""WSGI application for handling the OAuth 2.0 redirect callback.
If you need finer grained control use `callback_handler` which returns just
the webapp.RequestHandler.
Returns:
A webapp.WSGIApplication that handles the redirect back from the
server during the OAuth 2.0 dance.
"""
return webapp.WSGIApplication([
(self.callback_path, self.callback_handler())
])
class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
"""An OAuth2Decorator that builds from a clientsecrets file. """An OAuth2Decorator that builds from a clientsecrets file.
@@ -446,6 +540,7 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
# in API calls # in API calls
""" """
@util.positional(3)
def __init__(self, filename, scope, message=None, cache=None): def __init__(self, filename, scope, message=None, cache=None):
"""Constructor """Constructor
@@ -457,7 +552,7 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
clientsecrets file is missing or invalid. The message may contain HTML and clientsecrets file is missing or invalid. The message may contain HTML and
will be presented on the web interface for any method that uses the will be presented on the web interface for any method that uses the
decorator. decorator.
cache: An optional cache service client that implements get() and set() cache: An optional cache service client that implements get() and set()
methods. See clientsecrets.loadfile() for details. methods. See clientsecrets.loadfile() for details.
""" """
try: try:
@@ -469,9 +564,9 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
client_info['client_id'], client_info['client_id'],
client_info['client_secret'], client_info['client_secret'],
scope, scope,
client_info['auth_uri'], auth_uri=client_info['auth_uri'],
client_info['token_uri'], token_uri=client_info['token_uri'],
message) message=message)
except clientsecrets.InvalidClientSecretsError: except clientsecrets.InvalidClientSecretsError:
self._in_error = True self._in_error = True
if message is not None: if message is not None:
@@ -480,7 +575,8 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
self._message = "Please configure your application for OAuth 2.0" self._message = "Please configure your application for OAuth 2.0"
def oauth2decorator_from_clientsecrets(filename, scope, @util.positional(2)
def oauth2decorator_from_clientsecrets(filename, scope,
message=None, cache=None): message=None, cache=None):
"""Creates an OAuth2Decorator populated from a clientsecrets file. """Creates an OAuth2Decorator populated from a clientsecrets file.
@@ -492,46 +588,11 @@ def oauth2decorator_from_clientsecrets(filename, scope,
clientsecrets file is missing or invalid. The message may contain HTML and clientsecrets file is missing or invalid. The message may contain HTML and
will be presented on the web interface for any method that uses the will be presented on the web interface for any method that uses the
decorator. decorator.
cache: An optional cache service client that implements get() and set() cache: An optional cache service client that implements get() and set()
methods. See clientsecrets.loadfile() for details. methods. See clientsecrets.loadfile() for details.
Returns: An OAuth2Decorator Returns: An OAuth2Decorator
""" """
return OAuth2DecoratorFromClientSecrets(filename, scope, return OAuth2DecoratorFromClientSecrets(filename, scope,
message=message, cache=cache) message=message, cache=cache)
class OAuth2Handler(webapp.RequestHandler):
"""Handler for the redirect_uri of the OAuth 2.0 dance."""
@login_required
def get(self):
error = self.request.get('error')
if error:
errormsg = self.request.get('error_description', error)
self.response.out.write(
'The authorization request failed: %s' % errormsg)
else:
user = users.get_current_user()
flow = pickle.loads(memcache.get(user.user_id(),
namespace=OAUTH2CLIENT_NAMESPACE))
# This code should be ammended with application specific error
# handling. The following cases should be considered:
# 1. What if the flow doesn't exist in memcache? Or is corrupt?
# 2. What if the step2_exchange fails?
if flow:
credentials = flow.step2_exchange(self.request.params)
StorageByKeyName(
CredentialsModel, user.user_id(), 'credentials').put(credentials)
self.redirect(str(self.request.get('state')))
else:
# TODO Add error handling here.
pass
application = webapp.WSGIApplication([('/oauth2callback', OAuth2Handler)])
def main():
run_wsgi_app(application)

View File

@@ -31,7 +31,8 @@ import time
import urllib import urllib
import urlparse import urlparse
from anyjson import simplejson from oauth2client import util
from oauth2client.anyjson import simplejson
HAS_OPENSSL = False HAS_OPENSSL = False
try: try:
@@ -327,6 +328,7 @@ class OAuth2Credentials(Credentials):
OAuth2Credentials objects may be safely pickled and unpickled. OAuth2Credentials objects may be safely pickled and unpickled.
""" """
@util.positional(8)
def __init__(self, access_token, client_id, client_secret, refresh_token, def __init__(self, access_token, client_id, client_secret, refresh_token,
token_expiry, token_uri, user_agent, id_token=None): token_expiry, token_uri, user_agent, id_token=None):
"""Create an instance of OAuth2Credentials. """Create an instance of OAuth2Credentials.
@@ -394,6 +396,7 @@ class OAuth2Credentials(Credentials):
request_orig = http.request request_orig = http.request
# The closure that will replace 'httplib2.Http.request'. # The closure that will replace 'httplib2.Http.request'.
@util.positional(1)
def new_request(uri, method='GET', body=None, headers=None, def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS, redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None): connection_type=None):
@@ -481,7 +484,7 @@ class OAuth2Credentials(Credentials):
data['token_expiry'], data['token_expiry'],
data['token_uri'], data['token_uri'],
data['user_agent'], data['user_agent'],
data.get('id_token', None)) id_token=data.get('id_token', None))
retval.invalid = data['invalid'] retval.invalid = data['invalid']
return retval return retval
@@ -699,7 +702,8 @@ class AssertionCredentials(OAuth2Credentials):
AssertionCredentials objects may be safely pickled and unpickled. AssertionCredentials objects may be safely pickled and unpickled.
""" """
def __init__(self, assertion_type, user_agent, @util.positional(2)
def __init__(self, assertion_type, user_agent=None,
token_uri='https://accounts.google.com/o/oauth2/token', token_uri='https://accounts.google.com/o/oauth2/token',
**unused_kwargs): **unused_kwargs):
"""Constructor for AssertionFlowCredentials. """Constructor for AssertionFlowCredentials.
@@ -757,6 +761,7 @@ if HAS_OPENSSL:
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
@util.positional(4)
def __init__(self, def __init__(self,
service_account_name, service_account_name,
private_key, private_key,
@@ -781,7 +786,7 @@ if HAS_OPENSSL:
super(SignedJwtAssertionCredentials, self).__init__( super(SignedJwtAssertionCredentials, self).__init__(
'http://oauth.net/grant_type/jwt/1.0/bearer', 'http://oauth.net/grant_type/jwt/1.0/bearer',
user_agent, user_agent=user_agent,
token_uri=token_uri, token_uri=token_uri,
) )
@@ -833,6 +838,7 @@ if HAS_OPENSSL:
# for the certs. # for the certs.
_cached_http = httplib2.Http(MemoryCache()) _cached_http = httplib2.Http(MemoryCache())
@util.positional(2)
def verify_id_token(id_token, audience, http=None, def verify_id_token(id_token, audience, http=None,
cert_uri=ID_TOKEN_VERIFICATON_CERTS): cert_uri=ID_TOKEN_VERIFICATON_CERTS):
"""Verifies a signed JWT id_token. """Verifies a signed JWT id_token.
@@ -892,6 +898,7 @@ def _extract_id_token(id_token):
return simplejson.loads(_urlsafe_b64decode(segments[1])) return simplejson.loads(_urlsafe_b64decode(segments[1]))
def _parse_exchange_token_response(content): def _parse_exchange_token_response(content):
"""Parses response of an exchange token request. """Parses response of an exchange token request.
@@ -919,10 +926,11 @@ def _parse_exchange_token_response(content):
return resp return resp
@util.positional(4)
def credentials_from_code(client_id, client_secret, scope, code, def credentials_from_code(client_id, client_secret, scope, code,
redirect_uri = 'postmessage', redirect_uri='postmessage', http=None, user_agent=None,
http=None, user_agent=None, token_uri='https://accounts.google.com/o/oauth2/token'):
token_uri='https://accounts.google.com/o/oauth2/token'):
"""Exchanges an authorization code for an OAuth2Credentials object. """Exchanges an authorization code for an OAuth2Credentials object.
Args: Args:
@@ -943,19 +951,19 @@ def credentials_from_code(client_id, client_secret, scope, code,
FlowExchangeError if the authorization code cannot be exchanged for an FlowExchangeError if the authorization code cannot be exchanged for an
access token access token
""" """
flow = OAuth2WebServerFlow(client_id, client_secret, scope, user_agent, flow = OAuth2WebServerFlow(client_id, client_secret, scope,
'https://accounts.google.com/o/oauth2/auth', redirect_uri=redirect_uri, user_agent=user_agent,
token_uri) auth_uri='https://accounts.google.com/o/oauth2/auth',
token_uri=token_uri)
# We primarily make this call to set up the redirect_uri in the flow object credentials = flow.step2_exchange(code, http=http)
uriThatWeDontReallyUse = flow.step1_get_authorize_url(redirect_uri)
credentials = flow.step2_exchange(code, http)
return credentials return credentials
@util.positional(3)
def credentials_from_clientsecrets_and_code(filename, scope, code, def credentials_from_clientsecrets_and_code(filename, scope, code,
message = None, message = None,
redirect_uri = 'postmessage', redirect_uri='postmessage',
http=None, http=None,
cache=None): cache=None):
"""Returns OAuth2Credentials from a clientsecrets file and an auth code. """Returns OAuth2Credentials from a clientsecrets file and an auth code.
@@ -966,7 +974,7 @@ def credentials_from_clientsecrets_and_code(filename, scope, code,
Args: Args:
filename: string, File name of clientsecrets. filename: string, File name of clientsecrets.
scope: string or list of strings, scope(s) to request. scope: string or list of strings, scope(s) to request.
code: string, An authroization code, most likely passed down from code: string, An authorization code, most likely passed down from
the client the client
message: string, A friendly string to display to the user if the message: string, A friendly string to display to the user if the
clientsecrets file is missing or invalid. If message is provided then clientsecrets file is missing or invalid. If message is provided then
@@ -975,7 +983,7 @@ def credentials_from_clientsecrets_and_code(filename, scope, code,
redirect_uri: string, this is generally set to 'postmessage' to match the redirect_uri: string, this is generally set to 'postmessage' to match the
redirect_uri that the client specified redirect_uri that the client specified
http: httplib2.Http, optional http instance to use to do the fetch http: httplib2.Http, optional http instance to use to do the fetch
cache: An optional cache service client that implements get() and set() cache: An optional cache service client that implements get() and set()
methods. See clientsecrets.loadfile() for details. methods. See clientsecrets.loadfile() for details.
Returns: Returns:
@@ -988,20 +996,22 @@ def credentials_from_clientsecrets_and_code(filename, scope, code,
clientsecrets.InvalidClientSecretsError if the clientsecrets file is clientsecrets.InvalidClientSecretsError if the clientsecrets file is
invalid. invalid.
""" """
flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache) flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache,
# We primarily make this call to set up the redirect_uri in the flow object redirect_uri=redirect_uri)
uriThatWeDontReallyUse = flow.step1_get_authorize_url(redirect_uri) credentials = flow.step2_exchange(code, http=http)
credentials = flow.step2_exchange(code, http)
return credentials return credentials
class OAuth2WebServerFlow(Flow): class OAuth2WebServerFlow(Flow):
"""Does the Web Server Flow for OAuth 2.0. """Does the Web Server Flow for OAuth 2.0.
OAuth2Credentials objects may be safely pickled and unpickled. OAuth2WebServerFlow objects may be safely pickled and unpickled.
""" """
def __init__(self, client_id, client_secret, scope, user_agent=None, @util.positional(4)
def __init__(self, client_id, client_secret, scope,
redirect_uri=None,
user_agent=None,
auth_uri='https://accounts.google.com/o/oauth2/auth', auth_uri='https://accounts.google.com/o/oauth2/auth',
token_uri='https://accounts.google.com/o/oauth2/token', token_uri='https://accounts.google.com/o/oauth2/token',
**kwargs): **kwargs):
@@ -1012,6 +1022,9 @@ class OAuth2WebServerFlow(Flow):
client_secret: string client secret. client_secret: string client secret.
scope: string or list of strings, scope(s) of the credentials being scope: string or list of strings, scope(s) of the credentials being
requested. requested.
redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for
a non-web-based application, or a URI that handles the callback from
the authorization server.
user_agent: string, HTTP User-Agent to provide for this application. user_agent: string, HTTP User-Agent to provide for this application.
auth_uri: string, URI for authorization endpoint. For convenience auth_uri: string, URI for authorization endpoint. For convenience
defaults to Google's endpoints but any OAuth 2.0 provider can be used. defaults to Google's endpoints but any OAuth 2.0 provider can be used.
@@ -1025,6 +1038,7 @@ class OAuth2WebServerFlow(Flow):
if type(scope) is list: if type(scope) is list:
scope = ' '.join(scope) scope = ' '.join(scope)
self.scope = scope self.scope = scope
self.redirect_uri = redirect_uri
self.user_agent = user_agent self.user_agent = user_agent
self.auth_uri = auth_uri self.auth_uri = auth_uri
self.token_uri = token_uri self.token_uri = token_uri
@@ -1032,27 +1046,33 @@ class OAuth2WebServerFlow(Flow):
'access_type': 'offline', 'access_type': 'offline',
} }
self.params.update(kwargs) self.params.update(kwargs)
self.redirect_uri = None
def step1_get_authorize_url(self, redirect_uri=OOB_CALLBACK_URN): @util.positional(1)
def step1_get_authorize_url(self, redirect_uri=None):
"""Returns a URI to redirect to the provider. """Returns a URI to redirect to the provider.
Args: Args:
redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for
a non-web-based application, or a URI that handles the callback from a non-web-based application, or a URI that handles the callback from
the authorization server. the authorization server. This parameter is deprecated, please move to
passing the redirect_uri in via the constructor.
If redirect_uri is 'urn:ietf:wg:oauth:2.0:oob' then pass in the Returns:
generated verification code to step2_exchange, A URI as a string to redirect the user to begin the authorization flow.
otherwise pass in the query parameters received
at the callback uri to step2_exchange.
""" """
if redirect_uri is not None:
logger.warning(('The redirect_uri parameter for'
'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. Please'
'move to passing the redirect_uri in via the constructor.'))
self.redirect_uri = redirect_uri
if self.redirect_uri is None:
raise ValueError('The value of redirect_uri must not be None.')
self.redirect_uri = redirect_uri
query = { query = {
'response_type': 'code', 'response_type': 'code',
'client_id': self.client_id, 'client_id': self.client_id,
'redirect_uri': redirect_uri, 'redirect_uri': self.redirect_uri,
'scope': self.scope, 'scope': self.scope,
} }
query.update(self.params) query.update(self.params)
@@ -1061,6 +1081,7 @@ class OAuth2WebServerFlow(Flow):
parts[4] = urllib.urlencode(query) parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(parts) return urlparse.urlunparse(parts)
@util.positional(2)
def step2_exchange(self, code, http=None): def step2_exchange(self, code, http=None):
"""Exhanges a code for OAuth2Credentials. """Exhanges a code for OAuth2Credentials.
@@ -1134,7 +1155,9 @@ class OAuth2WebServerFlow(Flow):
error_msg = 'Invalid response: %s.' % str(resp.status) error_msg = 'Invalid response: %s.' % str(resp.status)
raise FlowExchangeError(error_msg) raise FlowExchangeError(error_msg)
def flow_from_clientsecrets(filename, scope, message=None, cache=None):
@util.positional(2)
def flow_from_clientsecrets(filename, scope, redirect_uri=None, message=None, cache=None):
"""Create a Flow from a clientsecrets file. """Create a Flow from a clientsecrets file.
Will create the right kind of Flow based on the contents of the clientsecrets Will create the right kind of Flow based on the contents of the clientsecrets
@@ -1143,11 +1166,14 @@ def flow_from_clientsecrets(filename, scope, message=None, cache=None):
Args: Args:
filename: string, File name of client secrets. filename: string, File name of client secrets.
scope: string or list of strings, scope(s) to request. scope: string or list of strings, scope(s) to request.
redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for
a non-web-based application, or a URI that handles the callback from
the authorization server.
message: string, A friendly string to display to the user if the message: string, A friendly string to display to the user if the
clientsecrets file is missing or invalid. If message is provided then clientsecrets file is missing or invalid. If message is provided then
sys.exit will be called in the case of an error. If message in not sys.exit will be called in the case of an error. If message in not
provided then clientsecrets.InvalidClientSecretsError will be raised. provided then clientsecrets.InvalidClientSecretsError will be raised.
cache: An optional cache service client that implements get() and set() cache: An optional cache service client that implements get() and set()
methods. See clientsecrets.loadfile() for details. methods. See clientsecrets.loadfile() for details.
Returns: Returns:
@@ -1165,9 +1191,11 @@ def flow_from_clientsecrets(filename, scope, message=None, cache=None):
client_info['client_id'], client_info['client_id'],
client_info['client_secret'], client_info['client_secret'],
scope, scope,
None, # user_agent redirect_uri=redirect_uri,
client_info['auth_uri'], user_agent=None,
client_info['token_uri']) auth_uri=client_info['auth_uri'],
token_uri=client_info['token_uri'])
except clientsecrets.InvalidClientSecretsError: except clientsecrets.InvalidClientSecretsError:
if message: if message:
sys.exit(message) sys.exit(message)

View File

@@ -23,6 +23,8 @@ import logging
import os import os
import time import time
from oauth2client import util
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -292,6 +294,7 @@ except ImportError:
class LockedFile(object): class LockedFile(object):
"""Represent a file that has exclusive access.""" """Represent a file that has exclusive access."""
@util.positional(4)
def __init__(self, filename, mode, fallback_mode, use_native_locking=True): def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
"""Construct a LockedFile. """Construct a LockedFile.

View File

@@ -38,8 +38,9 @@ import os
import threading import threading
from anyjson import simplejson from anyjson import simplejson
from client import Storage as BaseStorage from oauth2client.client import Storage as BaseStorage
from client import Credentials from oauth2client.client import Credentials
from oauth2client import util
from locked_file import LockedFile from locked_file import LockedFile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,6 +60,7 @@ class NewerCredentialStoreError(Error):
pass pass
@util.positional(4)
def get_credential_storage(filename, client_id, user_agent, scope, def get_credential_storage(filename, client_id, user_agent, scope,
warn_on_readonly=True): warn_on_readonly=True):
"""Get a Storage instance for a credential. """Get a Storage instance for a credential.
@@ -78,7 +80,7 @@ def get_credential_storage(filename, client_id, user_agent, scope,
_multistores_lock.acquire() _multistores_lock.acquire()
try: try:
multistore = _multistores.setdefault( multistore = _multistores.setdefault(
filename, _MultiStore(filename, warn_on_readonly)) filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly))
finally: finally:
_multistores_lock.release() _multistores_lock.release()
if type(scope) is list: if type(scope) is list:
@@ -89,6 +91,7 @@ def get_credential_storage(filename, client_id, user_agent, scope,
class _MultiStore(object): class _MultiStore(object):
"""A file backed store for multiple credentials.""" """A file backed store for multiple credentials."""
@util.positional(2)
def __init__(self, filename, warn_on_readonly=True): def __init__(self, filename, warn_on_readonly=True):
"""Initialize the class. """Initialize the class.

View File

@@ -29,8 +29,9 @@ import socket
import sys import sys
import webbrowser import webbrowser
from client import FlowExchangeError from oauth2client.client import FlowExchangeError
from client import OOB_CALLBACK_URN from oauth2client.client import OOB_CALLBACK_URN
from oauth2client import util
try: try:
from urlparse import parse_qsl from urlparse import parse_qsl
@@ -91,6 +92,7 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
pass pass
@util.positional(2)
def run(flow, storage, http=None): def run(flow, storage, http=None):
"""Core code for a command-line application. """Core code for a command-line application.
@@ -130,7 +132,8 @@ def run(flow, storage, http=None):
oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number) oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
else: else:
oauth_callback = OOB_CALLBACK_URN oauth_callback = OOB_CALLBACK_URN
authorize_url = flow.step1_get_authorize_url(oauth_callback) flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
if FLAGS.auth_local_webserver: if FLAGS.auth_local_webserver:
webbrowser.open(authorize_url, new=1, autoraise=True) webbrowser.open(authorize_url, new=1, autoraise=True)
@@ -163,7 +166,7 @@ def run(flow, storage, http=None):
code = raw_input('Enter verification code: ').strip() code = raw_input('Enter verification code: ').strip()
try: try:
credential = flow.step2_exchange(code, http) credential = flow.step2_exchange(code, http=http)
except FlowExchangeError, e: except FlowExchangeError, e:
sys.exit('Authentication has failed: %s' % e) sys.exit('Authentication has failed: %s' % e)

127
oauth2client/util.py Normal file
View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common utility library."""
__author__ = ['rafek@google.com (Rafe Kaplan)',
'guido@google.com (Guido van Rossum)',
]
__all__ = [
'positional',
]
import gflags
import inspect
import logging
logger = logging.getLogger(__name__)
FLAGS = gflags.FLAGS
gflags.DEFINE_enum('positional_parameters_enforcement', 'WARNING',
['EXCEPTION', 'WARNING', 'IGNORE'],
'The action when an oauth2client.util.positional declaration is violated.')
def positional(max_positional_args):
"""A decorator to declare that only the first N arguments my be positional.
This decorator makes it easy to support Python 3 style key-word only
parameters. For example, in Python 3 it is possible to write:
def fn(pos1, *, kwonly1=None, kwonly1=None):
...
All named parameters after * must be a keyword:
fn(10, 'kw1', 'kw2') # Raises exception.
fn(10, kwonly1='kw1') # Ok.
Example:
To define a function like above, do:
@positional(1)
def fn(pos1, kwonly1=None, kwonly2=None):
...
If no default value is provided to a keyword argument, it becomes a required
keyword argument:
@positional(0)
def fn(required_kw):
...
This must be called with the keyword parameter:
fn() # Raises exception.
fn(10) # Raises exception.
fn(required_kw=10) # Ok.
When defining instance or class methods always remember to account for
'self' and 'cls':
class MyClass(object):
@positional(2)
def my_method(self, pos1, kwonly1=None):
...
@classmethod
@positional(2)
def my_method(cls, pos1, kwonly1=None):
...
The positional decorator behavior is controlled by the
--positional_parameters_enforcement flag. The flag may be set to 'EXCEPTION',
'WARNING' or 'IGNORE' to raise an exception, log a warning, or do nothing,
respectively, if a declaration is violated.
Args:
max_positional_arguments: Maximum number of positional arguments. All
parameters after the this index must be keyword only.
Returns:
A decorator that prevents using arguments after max_positional_args from
being used as positional parameters.
Raises:
TypeError if a key-word only argument is provided as a positional parameter,
but only if the --positional_parameters_enforcement flag is set to
'EXCEPTION'.
"""
def positional_decorator(wrapped):
def positional_wrapper(*args, **kwargs):
if len(args) > max_positional_args:
plural_s = ''
if max_positional_args != 1:
plural_s = 's'
message = '%s() takes at most %d positional argument%s (%d given)' % (
wrapped.__name__, max_positional_args, plural_s, len(args))
if FLAGS.positional_parameters_enforcement == 'EXCEPTION':
raise TypeError(message)
elif FLAGS.positional_parameters_enforcement == 'WARNING':
logger.warning(message)
else: # IGNORE
pass
return wrapped(*args, **kwargs)
return positional_wrapper
if isinstance(max_positional_args, (int, long)):
return positional_decorator
else:
args, _, _, defaults = inspect.getargspec(max_positional_args)
return positional(len(args) - len(defaults))(max_positional_args)

View File

@@ -1,14 +1,21 @@
#!/usr/bin/env python #!/usr/bin/env python
import gflags
import glob import glob
import imp import imp
import logging import logging
import os import os
import sys import sys
import unittest import unittest
# import oauth2client.util for its gflags.
import oauth2client.util
from trace import fullmodname from trace import fullmodname
logging.basicConfig(level=logging.CRITICAL) logging.basicConfig(level=logging.CRITICAL)
FLAGS = gflags.FLAGS
APP_ENGINE_PATH='../google_appengine' APP_ENGINE_PATH='../google_appengine'
# Conditional import of cleanup function # Conditional import of cleanup function
@@ -26,12 +33,13 @@ from google.appengine.dist import use_library
use_library('django', '1.2') use_library('django', '1.2')
def main(): def main(argv):
for t in sys.argv[1:]: argv = FLAGS(argv)
for t in argv[1:]:
module = imp.load_source('test', t) module = imp.load_source('test', t)
test = unittest.TestLoader().loadTestsFromModule(module) test = unittest.TestLoader().loadTestsFromModule(module)
result = unittest.TextTestRunner(verbosity=1).run(test) result = unittest.TextTestRunner(verbosity=1).run(test)
if __name__ == '__main__': if __name__ == '__main__':
main() main(sys.argv)

View File

@@ -4,18 +4,19 @@
# #
# The python interpreter to use is passed in on the command line. # The python interpreter to use is passed in on the command line.
$1 runtests.py tests/test_discovery.py FLAGS=--positional_parameters_enforcement=EXCEPTION
$1 runtests.py tests/test_errors.py $1 runtests.py $FLAGS tests/test_discovery.py
$1 runtests.py tests/test_http.py $1 runtests.py $FLAGS tests/test_errors.py
$1 runtests.py tests/test_json_model.py $1 runtests.py $FLAGS tests/test_http.py
$1 runtests.py tests/test_mocks.py $1 runtests.py $FLAGS tests/test_json_model.py
$1 runtests.py tests/test_model.py $1 runtests.py $FLAGS tests/test_mocks.py
$1 runtests.py tests/test_oauth2client_clientsecrets.py $1 runtests.py $FLAGS tests/test_model.py
$1 runtests.py tests/test_oauth2client_django_orm.py $1 runtests.py $FLAGS tests/test_oauth2client_clientsecrets.py
$1 runtests.py tests/test_oauth2client_file.py $1 runtests.py $FLAGS tests/test_oauth2client_django_orm.py
$1 runtests.py tests/test_oauth2client_jwt.py $1 runtests.py $FLAGS tests/test_oauth2client_file.py
$1 runtests.py tests/test_oauth2client.py $1 runtests.py $FLAGS tests/test_oauth2client_jwt.py
$1 runtests.py tests/test_protobuf_model.py $1 runtests.py $FLAGS tests/test_oauth2client.py
$1 runtests.py tests/test_schema.py $1 runtests.py $FLAGS tests/test_protobuf_model.py
$1 runtests.py tests/test_oauth2client_appengine.py $1 runtests.py $FLAGS tests/test_schema.py
$1 runtests.py tests/test_oauth2client_keyring.py $1 runtests.py $FLAGS tests/test_oauth2client_appengine.py
$1 runtests.py $FLAGS tests/test_oauth2client_keyring.py

View File

@@ -28,6 +28,7 @@ from apiclient.discovery import build
import gflags import gflags
import httplib2 import httplib2
from oauth2client.client import flow_from_clientsecrets from oauth2client.client import flow_from_clientsecrets
from oauth2client.client import OOB_CALLBACK_URN
from oauth2client.file import Storage from oauth2client.file import Storage
from oauth2client.tools import run from oauth2client.tools import run
@@ -58,6 +59,7 @@ with information from the APIs Console <https://code.google.com/apis/console>.
FLOW = flow_from_clientsecrets( FLOW = flow_from_clientsecrets(
CLIENT_SECRETS, CLIENT_SECRETS,
scope='https://www.googleapis.com/auth/adexchange.buyer', scope='https://www.googleapis.com/auth/adexchange.buyer',
redirect_uri=OOB_CALLBACK_URN,
message=MISSING_CLIENT_SECRETS_MESSAGE message=MISSING_CLIENT_SECRETS_MESSAGE
) )
@@ -108,4 +110,3 @@ def initialize_service():
# Construct a service object via the discovery service. # Construct a service object via the discovery service.
service = build('adexchangebuyer', 'v1', http=http) service = build('adexchangebuyer', 'v1', http=http)
return service return service

View File

@@ -28,6 +28,7 @@ from apiclient.discovery import build
import gflags import gflags
import httplib2 import httplib2
from oauth2client.client import flow_from_clientsecrets from oauth2client.client import flow_from_clientsecrets
from oauth2client.client import OOB_CALLBACK_URN
from oauth2client.file import Storage from oauth2client.file import Storage
from oauth2client.tools import run from oauth2client.tools import run
@@ -57,6 +58,7 @@ with information from the APIs Console <https://code.google.com/apis/console>.
# Set up a Flow object to be used if we need to authenticate. # Set up a Flow object to be used if we need to authenticate.
FLOW = flow_from_clientsecrets(CLIENT_SECRETS, FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
scope='https://www.googleapis.com/auth/adsense.readonly', scope='https://www.googleapis.com/auth/adsense.readonly',
redirect_uri=OOB_CALLBACK_URN,
message=MISSING_CLIENT_SECRETS_MESSAGE) message=MISSING_CLIENT_SECRETS_MESSAGE)
# The gflags module makes defining command-line options easy for applications. # The gflags module makes defining command-line options easy for applications.

View File

@@ -40,6 +40,7 @@ from apiclient.discovery import build
import gflags import gflags
import httplib2 import httplib2
from oauth2client.client import flow_from_clientsecrets from oauth2client.client import flow_from_clientsecrets
from oauth2client.client import OOB_CALLBACK_URN
from oauth2client.file import Storage from oauth2client.file import Storage
from oauth2client.tools import run from oauth2client.tools import run
@@ -70,6 +71,7 @@ with information from the APIs Console <https://code.google.com/apis/console>.
# Set up a Flow object to be used if we need to authenticate. # Set up a Flow object to be used if we need to authenticate.
FLOW = flow_from_clientsecrets(CLIENT_SECRETS, FLOW = flow_from_clientsecrets(CLIENT_SECRETS,
scope='https://www.googleapis.com/auth/analytics.readonly', scope='https://www.googleapis.com/auth/analytics.readonly',
redirect_uri=OOB_CALLBACK_URN,
message=MISSING_CLIENT_SECRETS_MESSAGE) message=MISSING_CLIENT_SECRETS_MESSAGE)
# The gflags module makes defining command-line options easy for applications. # The gflags module makes defining command-line options easy for applications.

View File

@@ -4,9 +4,6 @@ runtime: python
api_version: 1 api_version: 1
handlers: handlers:
- url: /oauth2callback
script: oauth2client/appengine.py
- url: .* - url: .*
script: main.py script: main.py

View File

@@ -66,8 +66,8 @@ http = httplib2.Http(memcache)
service = build("plus", "v1", http=http) service = build("plus", "v1", http=http)
decorator = oauth2decorator_from_clientsecrets( decorator = oauth2decorator_from_clientsecrets(
CLIENT_SECRETS, CLIENT_SECRETS,
'https://www.googleapis.com/auth/plus.me', scope='https://www.googleapis.com/auth/plus.me',
MISSING_CLIENT_SECRETS_MESSAGE) message=MISSING_CLIENT_SECRETS_MESSAGE)
class MainHandler(webapp.RequestHandler): class MainHandler(webapp.RequestHandler):
@@ -87,7 +87,7 @@ class AboutHandler(webapp.RequestHandler):
def get(self): def get(self):
try: try:
http = decorator.http() http = decorator.http()
user = service.people().get(userId='me').execute(http) user = service.people().get(userId='me').execute(http=http)
text = 'Hello, %s!' % user['displayName'] text = 'Hello, %s!' % user['displayName']
path = os.path.join(os.path.dirname(__file__), 'welcome.html') path = os.path.join(os.path.dirname(__file__), 'welcome.html')
@@ -101,6 +101,7 @@ def main():
[ [
('/', MainHandler), ('/', MainHandler),
('/about', AboutHandler), ('/about', AboutHandler),
(decorator.callback_path, decorator.callback_handler()),
], ],
debug=True) debug=True)
run_wsgi_app(application) run_wsgi_app(application)

View File

@@ -113,7 +113,7 @@ def main(argv):
users = service.users() users = service.users()
# Retrieve this user's profile information # Retrieve this user's profile information
thisuser = users.get(userId="self").execute(http) thisuser = users.get(userId="self").execute(http=http)
print "This user's display name is: %s" % thisuser['displayName'] print "This user's display name is: %s" % thisuser['displayName']
# Retrieve the list of Blogs this user has write privileges on # Retrieve the list of Blogs this user has write privileges on
@@ -128,7 +128,7 @@ def main(argv):
print "The posts for %s:" % blog['name'] print "The posts for %s:" % blog['name']
request = posts.list(blogId=blog['id']) request = posts.list(blogId=blog['id'])
while request != None: while request != None:
posts_doc = request.execute(http) posts_doc = request.execute(http=http)
if 'items' in posts_doc and not (posts_doc['items'] is None): if 'items' in posts_doc and not (posts_doc['items'] is None):
for post in posts_doc['items']: for post in posts_doc['items']:
print " %s (%s)" % (post['title'], post['url']) print " %s (%s)" % (post['title'], post['url'])

View File

@@ -128,7 +128,7 @@ def main(argv):
try: try:
# List all the jobs for a team # List all the jobs for a team
jobs_result = service.jobs().list(teamId=FLAGS.teamId).execute(http) jobs_result = service.jobs().list(teamId=FLAGS.teamId).execute(http=http)
print('List of Jobs:') print('List of Jobs:')
pprint.pprint(jobs_result) pprint.pprint(jobs_result)

View File

@@ -26,7 +26,6 @@ import pickle
from oauth2client.appengine import CredentialsProperty from oauth2client.appengine import CredentialsProperty
from oauth2client.appengine import StorageByKeyName from oauth2client.appengine import StorageByKeyName
from oauth2client.client import OAuth2WebServerFlow from oauth2client.client import OAuth2WebServerFlow
from google.appengine.api import memcache
from google.appengine.api import users from google.appengine.api import users
from google.appengine.ext import db from google.appengine.ext import db
from google.appengine.ext import webapp from google.appengine.ext import webapp
@@ -39,6 +38,7 @@ FLOW = OAuth2WebServerFlow(
client_id='2ad565600216d25d9cde', client_id='2ad565600216d25d9cde',
client_secret='03b56df2949a520be6049ff98b89813f17b467dc', client_secret='03b56df2949a520be6049ff98b89813f17b467dc',
scope='read', scope='read',
redirect_uri='https://dailymotoauth2test.appspot.com/auth_return',
user_agent='oauth2client-sample/1.0', user_agent='oauth2client-sample/1.0',
auth_uri='https://api.dailymotion.com/oauth/authorize', auth_uri='https://api.dailymotion.com/oauth/authorize',
token_uri='https://api.dailymotion.com/oauth/token' token_uri='https://api.dailymotion.com/oauth/token'
@@ -58,9 +58,7 @@ class MainHandler(webapp.RequestHandler):
Credentials, user.user_id(), 'credentials').get() Credentials, user.user_id(), 'credentials').get()
if credentials is None or credentials.invalid == True: if credentials is None or credentials.invalid == True:
callback = self.request.relative_url('/auth_return') authorize_url = FLOW.step1_get_authorize_url()
authorize_url = FLOW.step1_get_authorize_url(callback)
memcache.set(user.user_id(), pickle.dumps(FLOW))
self.redirect(authorize_url) self.redirect(authorize_url)
else: else:
http = httplib2.Http() http = httplib2.Http()
@@ -82,14 +80,10 @@ class OAuthHandler(webapp.RequestHandler):
@login_required @login_required
def get(self): def get(self):
user = users.get_current_user() user = users.get_current_user()
flow = pickle.loads(memcache.get(user.user_id())) credentials = FLOW.step2_exchange(self.request.params)
if flow: StorageByKeyName(
credentials = flow.step2_exchange(self.request.params) Credentials, user.user_id(), 'credentials').put(credentials)
StorageByKeyName( self.redirect("/")
Credentials, user.user_id(), 'credentials').put(credentials)
self.redirect("/")
else:
pass
def main(): def main():

View File

@@ -8,13 +8,6 @@ from django.db import models
from oauth2client.django_orm import FlowField from oauth2client.django_orm import FlowField
from oauth2client.django_orm import CredentialsField from oauth2client.django_orm import CredentialsField
# The Flow could also be stored in memcache since it is short lived.
class FlowModel(models.Model):
id = models.ForeignKey(User, primary_key=True)
flow = FlowField()
class CredentialsModel(models.Model): class CredentialsModel(models.Model):
id = models.ForeignKey(User, primary_key=True) id = models.ForeignKey(User, primary_key=True)
@@ -30,4 +23,3 @@ class FlowAdmin(admin.ModelAdmin):
admin.site.register(CredentialsModel, CredentialsAdmin) admin.site.register(CredentialsModel, CredentialsAdmin)
admin.site.register(FlowModel, FlowAdmin)

View File

@@ -9,7 +9,6 @@ from django.contrib.auth.decorators import login_required
from oauth2client.django_orm import Storage from oauth2client.django_orm import Storage
from oauth2client.client import OAuth2WebServerFlow from oauth2client.client import OAuth2WebServerFlow
from django_sample.plus.models import CredentialsModel from django_sample.plus.models import CredentialsModel
from django_sample.plus.models import FlowModel
from apiclient.discovery import build from apiclient.discovery import build
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
@@ -17,22 +16,21 @@ from django.shortcuts import render_to_response
STEP2_URI = 'http://localhost:8000/oauth2callback' STEP2_URI = 'http://localhost:8000/oauth2callback'
FLOW = OAuth2WebServerFlow(
client_id='[[Insert Client ID here.]]',
client_secret='[[Insert Client Secret here.]]',
scope='https://www.googleapis.com/auth/plus.me',
redirect_uri=STEP2_URI,
user_agent='plus-django-sample/1.0',
)
@login_required @login_required
def index(request): def index(request):
storage = Storage(CredentialsModel, 'id', request.user, 'credential') storage = Storage(CredentialsModel, 'id', request.user, 'credential')
credential = storage.get() credential = storage.get()
if credential is None or credential.invalid == True: if credential is None or credential.invalid == True:
flow = OAuth2WebServerFlow( authorize_url = FLOW.step1_get_authorize_url()
client_id='[[Insert Client ID here.]]',
client_secret='[[Insert Client Secret here.]]',
scope='https://www.googleapis.com/auth/plus.me',
user_agent='plus-django-sample/1.0',
)
authorize_url = flow.step1_get_authorize_url(STEP2_URI)
f = FlowModel(id=request.user, flow=flow)
f.save()
return HttpResponseRedirect(authorize_url) return HttpResponseRedirect(authorize_url)
else: else:
http = httplib2.Http() http = httplib2.Http()
@@ -50,12 +48,7 @@ def index(request):
@login_required @login_required
def auth_return(request): def auth_return(request):
try: credential = FLOW.step2_exchange(request.REQUEST)
f = FlowModel.objects.get(id=request.user) storage = Storage(CredentialsModel, 'id', request.user, 'credential')
credential = f.flow.step2_exchange(request.REQUEST) storage.put(credential)
storage = Storage(CredentialsModel, 'id', request.user, 'credential') return HttpResponseRedirect("/")
storage.put(credential)
f.delete()
return HttpResponseRedirect("/")
except FlowModel.DoesNotExist:
pass

View File

@@ -113,7 +113,7 @@ def main(argv):
service = build("plus", "v1", http=http) service = build("plus", "v1", http=http)
try: try:
person = service.people().get(userId='me').execute(http) person = service.people().get(userId='me').execute(http=http)
print "Got your ID: %s" % person['displayName'] print "Got your ID: %s" % person['displayName']
print print

View File

@@ -57,7 +57,7 @@ def main(argv):
service = build("tasks", "v1", http=http) service = build("tasks", "v1", http=http)
# List all the tasklists for the account. # List all the tasklists for the account.
lists = service.tasklists().list().execute(http) lists = service.tasklists().list().execute(http=http)
pprint.pprint(lists) pprint.pprint(lists)

View File

@@ -4,9 +4,8 @@ runtime: python
api_version: 1 api_version: 1
handlers: handlers:
- url: /oauth2callback
script: oauth2client/appengine.py
- url: /css - url: /css
static_dir: css static_dir: css
- url: .* - url: .*
script: main.py script: main.py

View File

@@ -50,7 +50,10 @@ class MainHandler(webapp.RequestHandler):
def truncate(s, l): def truncate(s, l):
return s[:l] + '...' if len(s) > l else s return s[:l] + '...' if len(s) > l else s
application = webapp.WSGIApplication([('/', MainHandler)], debug=True) application = webapp.WSGIApplication([
('/', MainHandler),
(decorator.callback_path, decorator.callback_handler()),
], debug=True)
def main(): def main():

View File

@@ -100,7 +100,7 @@ def start_threads(credentials):
backoff = Backoff() backoff = Backoff()
while backoff.loop(): while backoff.loop():
try: try:
response = request.execute(http) response = request.execute(http=http)
print "Processed: %s in thread %d" % (response['id'], n) print "Processed: %s in thread %d" % (response['id'], n)
break break
except HttpError, e: except HttpError, e:

View File

@@ -1,36 +0,0 @@
This is an example program that can run as a power
management hook to set the timezone on the computer
based on the user's location, as determined by Google
Latitude. To use this application you will need Google
Latitude running on a mobile device.
api: latitude
keywords: cmdline
Installation
============
The google-api-python-client library will need to
be installed.
$ sudo python setup.py install
Then you will need to install the tznever application:
$ sudo cp tznever /usr/sbin/tznever
And then add it in as a power management hook:
$ sudo ln -s /usr/sbin/tznever /etc/pm/sleep.d/45tznever
Once that is done you need to run tznever once from the
the command line to tie it to your Latitude account:
$ sudo tznever
After that, every time your laptop resumes it will
check you Latitude location and set the timezone
accordingly.
TODO
====
1. What about stale Latitude data?

View File

@@ -1,289 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2010 Google Inc. All Rights Reserved.
# Portions copyright PSF License
# http://code.activestate.com/recipes/278731-creating-a-daemon-the-python-way/
"""A pm-action hook for setting timezone.
Uses the Google Latitude API and the geonames.org
API to find your cellphones latitude and longitude
and from the determine the timezone you are in,
and then sets the computer's timezone to that.
"""
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
from apiclient.discovery import build
import httplib2
import os
import pickle
import pprint
import subprocess
import sys
import time
import uritemplate
from apiclient.anyjson import simplejson
from apiclient.discovery import build
from apiclient.oauth import FlowThreeLegged
from apiclient.ext.authtools import run
from apiclient.ext.file import Storage
# Uncomment to get detailed logging
# httplib2.debuglevel = 4
# URI Template to convert latitude and longitude into a timezone
GEONAMES = 'http://api.geonames.org/timezoneJSON?lat={lat}&lng={long}&username=jcgregorio'
PID_FILE = '/var/lock/tznever.pid'
CACHE = '/var/local/tznever/.cache'
# Default daemon parameters.
# File mode creation mask of the daemon.
UMASK = 0
# Default working directory for the daemon.
WORKDIR = "/"
# Default maximum for the number of available file descriptors.
MAXFD = 1024
# The standard I/O file descriptors are redirected to /dev/null by default.
if (hasattr(os, "devnull")):
REDIRECT_TO = os.devnull
else:
REDIRECT_TO = "/dev/null"
def main():
storage = Storage('/var/local/tznever/latitude_credentials.dat')
credentials = storage.get()
if len(sys.argv) == 1:
if credentials is None or credentials.invalid == True:
auth_discovery = build('latitude', 'v1').auth_discovery()
flow = FlowThreeLegged(auth_discovery,
consumer_key='m-buzz.appspot.com',
consumer_secret='NQEHb4eU6GkjjFGe1MD5W6IC',
user_agent='tz-never/1.0',
domain='m-buzz.appspot.com',
scope='https://www.googleapis.com/auth/latitude',
xoauth_displayname='TZ Never Again',
location='current',
granularity='city'
)
credentials = run(flow, storage)
else:
print "You are already authorized"
else:
if credentials is None or credentials.invalid == True:
print "This app, tznever, is not authorized. Run from the command-line to re-authorize."
os.exit(1)
if len(sys.argv) > 1 and sys.argv[1] in ['hibernate', 'suspend']:
print "Hibernating"
# Kill off the possibly still running process by its pid
if os.path.isfile(PID_FILE):
f = file(PID_FILE, 'r')
pid = f.read()
f.close()
cmdline = ['/bin/kill', '-2', pid]
subprocess.Popen(cmdline)
os.unlink(PID_FILE)
elif len(sys.argv) > 1 and sys.argv[1] in ['thaw', 'resume']:
print "Resuming"
# write our pid out
f = file(PID_FILE, 'w')
f.write(str(os.getpid()))
f.close()
success = False
first_time = True
while not success:
try:
if not first_time:
time.sleep(5)
else:
first_time = False
print "Daemonizing so as not to gum up the works."
createDaemon()
# rewrite the PID file with our new PID
f = file(PID_FILE, 'w')
f.write(str(os.getpid()))
f.close()
http = httplib2.Http(CACHE)
http = credentials.authorize(http)
service = build('latitude', 'v1', http=http)
location = service.currentLocation().get(granularity='city').execute()
position = {
'lat': str(location['latitude']),
'long': str(location['longitude'])
}
http2 = httplib2.Http(CACHE)
resp, content = http2.request(uritemplate.expand(GEONAMES, position))
geodata = simplejson.loads(content)
tz = geodata['timezoneId']
f = file('/etc/timezone', 'w')
f.write(tz)
f.close()
cmdline = 'dpkg-reconfigure -f noninteractive tzdata'.split(' ')
subprocess.Popen(cmdline)
success = True
except httplib2.ServerNotFoundError, e:
print "still not connected, sleeping"
except KeyboardInterrupt, e:
if os.path.isfile(PID_FILE):
os.unlink(PID_FILE)
success = True
# clean up pid file
if os.path.isfile(PID_FILE):
os.unlink(PID_FILE)
def createDaemon():
"""Detach a process from the controlling terminal and run it in the
background as a daemon.
"""
try:
# Fork a child process so the parent can exit. This returns control to
# the command-line or shell. It also guarantees that the child will not
# be a process group leader, since the child receives a new process ID
# and inherits the parent's process group ID. This step is required
# to insure that the next call to os.setsid is successful.
pid = os.fork()
except OSError, e:
raise Exception, "%s [%d]" % (e.strerror, e.errno)
if (pid == 0): # The first child.
# To become the session leader of this new session and the process group
# leader of the new process group, we call os.setsid(). The process is
# also guaranteed not to have a controlling terminal.
os.setsid()
# Is ignoring SIGHUP necessary?
#
# It's often suggested that the SIGHUP signal should be ignored before
# the second fork to avoid premature termination of the process. The
# reason is that when the first child terminates, all processes, e.g.
# the second child, in the orphaned group will be sent a SIGHUP.
#
# "However, as part of the session management system, there are exactly
# two cases where SIGHUP is sent on the death of a process:
#
# 1) When the process that dies is the session leader of a session that
# is attached to a terminal device, SIGHUP is sent to all processes
# in the foreground process group of that terminal device.
# 2) When the death of a process causes a process group to become
# orphaned, and one or more processes in the orphaned group are
# stopped, then SIGHUP and SIGCONT are sent to all members of the
# orphaned group." [2]
#
# The first case can be ignored since the child is guaranteed not to have
# a controlling terminal. The second case isn't so easy to dismiss.
# The process group is orphaned when the first child terminates and
# POSIX.1 requires that every STOPPED process in an orphaned process
# group be sent a SIGHUP signal followed by a SIGCONT signal. Since the
# second child is not STOPPED though, we can safely forego ignoring the
# SIGHUP signal. In any case, there are no ill-effects if it is ignored.
#
# import signal # Set handlers for asynchronous events.
# signal.signal(signal.SIGHUP, signal.SIG_IGN)
try:
# Fork a second child and exit immediately to prevent zombies. This
# causes the second child process to be orphaned, making the init
# process responsible for its cleanup. And, since the first child is
# a session leader without a controlling terminal, it's possible for
# it to acquire one by opening a terminal in the future (System V-
# based systems). This second fork guarantees that the child is no
# longer a session leader, preventing the daemon from ever acquiring
# a controlling terminal.
pid = os.fork() # Fork a second child.
except OSError, e:
raise Exception, "%s [%d]" % (e.strerror, e.errno)
if (pid == 0): # The second child.
# Since the current working directory may be a mounted filesystem, we
# avoid the issue of not being able to unmount the filesystem at
# shutdown time by changing it to the root directory.
os.chdir(WORKDIR)
# We probably don't want the file mode creation mask inherited from
# the parent, so we give the child complete control over permissions.
os.umask(UMASK)
else:
# exit() or _exit()? See below.
os._exit(0) # Exit parent (the first child) of the second child.
else:
# exit() or _exit()?
# _exit is like exit(), but it doesn't call any functions registered
# with atexit (and on_exit) or any registered signal handlers. It also
# closes any open file descriptors. Using exit() may cause all stdio
# streams to be flushed twice and any temporary files may be unexpectedly
# removed. It's therefore recommended that child branches of a fork()
# and the parent branch(es) of a daemon use _exit().
os._exit(0) # Exit parent of the first child.
# Close all open file descriptors. This prevents the child from keeping
# open any file descriptors inherited from the parent. There is a variety
# of methods to accomplish this task. Three are listed below.
#
# Try the system configuration variable, SC_OPEN_MAX, to obtain the maximum
# number of open file descriptors to close. If it doesn't exists, use
# the default value (configurable).
#
# try:
# maxfd = os.sysconf("SC_OPEN_MAX")
# except (AttributeError, ValueError):
# maxfd = MAXFD
#
# OR
#
# if (os.sysconf_names.has_key("SC_OPEN_MAX")):
# maxfd = os.sysconf("SC_OPEN_MAX")
# else:
# maxfd = MAXFD
#
# OR
#
# Use the getrlimit method to retrieve the maximum file descriptor number
# that can be opened by this process. If there is not limit on the
# resource, use the default value.
#
import resource # Resource usage information.
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if (maxfd == resource.RLIM_INFINITY):
maxfd = MAXFD
# Iterate through and close all file descriptors.
for fd in range(0, maxfd):
try:
os.close(fd)
except OSError: # ERROR, fd wasn't open to begin with (ignored)
pass
# Redirect the standard I/O file descriptors to the specified file. Since
# the daemon has no controlling terminal, most daemons redirect stdin,
# stdout, and stderr to /dev/null. This is done to prevent side-effects
# from reads and writes to the standard I/O file descriptors.
# This call to open is guaranteed to return the lowest file descriptor,
# which will be 0 (stdin), since it was closed above.
os.open(REDIRECT_TO, os.O_RDWR) # standard input (0)
# Duplicate standard input to standard output and standard error.
os.dup2(0, 1) # standard output (1)
os.dup2(0, 2) # standard error (2)
return(0)
if __name__ == '__main__':
main()

View File

@@ -64,10 +64,17 @@ class Utilities(unittest.TestCase):
class DiscoveryErrors(unittest.TestCase): class DiscoveryErrors(unittest.TestCase):
def test_tests_should_be_run_with_strict_positional_enforcement(self):
try:
plus = build('plus', 'v1', None)
self.fail("should have raised a TypeError exception over missing http=.")
except TypeError:
pass
def test_failed_to_parse_discovery_json(self): def test_failed_to_parse_discovery_json(self):
self.http = HttpMock(datafile('malformed.json'), {'status': '200'}) self.http = HttpMock(datafile('malformed.json'), {'status': '200'})
try: try:
plus = build('plus', 'v1', self.http) plus = build('plus', 'v1', http=self.http)
self.fail("should have raised an exception over malformed JSON.") self.fail("should have raised an exception over malformed JSON.")
except InvalidJsonError: except InvalidJsonError:
pass pass
@@ -103,7 +110,7 @@ class DiscoveryFromHttp(unittest.TestCase):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '400'}, file(datafile('zoo.json'), 'r').read()), ({'status': '400'}, file(datafile('zoo.json'), 'r').read()),
]) ])
zoo = build('zoo', 'v1', http, developerKey='foo', zoo = build('zoo', 'v1', http=http, developerKey='foo',
discoveryServiceUrl='http://example.com') discoveryServiceUrl='http://example.com')
self.fail('Should have raised an exception.') self.fail('Should have raised an exception.')
except HttpError, e: except HttpError, e:
@@ -116,7 +123,7 @@ class DiscoveryFromHttp(unittest.TestCase):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '400'}, file(datafile('zoo.json'), 'r').read()), ({'status': '400'}, file(datafile('zoo.json'), 'r').read()),
]) ])
zoo = build('zoo', 'v1', http, developerKey=None, zoo = build('zoo', 'v1', http=http, developerKey=None,
discoveryServiceUrl='http://example.com') discoveryServiceUrl='http://example.com')
self.fail('Should have raised an exception.') self.fail('Should have raised an exception.')
except HttpError, e: except HttpError, e:
@@ -127,7 +134,7 @@ class Discovery(unittest.TestCase):
def test_method_error_checking(self): def test_method_error_checking(self):
self.http = HttpMock(datafile('plus.json'), {'status': '200'}) self.http = HttpMock(datafile('plus.json'), {'status': '200'})
plus = build('plus', 'v1', self.http) plus = build('plus', 'v1', http=self.http)
# Missing required parameters # Missing required parameters
try: try:
@@ -170,7 +177,7 @@ class Discovery(unittest.TestCase):
def test_type_coercion(self): def test_type_coercion(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.query( request = zoo.query(
q="foo", i=1.0, n=1.0, b=0, a=[1,2,3], o={'a':1}, e='bar') q="foo", i=1.0, n=1.0, b=0, a=[1,2,3], o={'a':1}, e='bar')
@@ -192,7 +199,7 @@ class Discovery(unittest.TestCase):
def test_optional_stack_query_parameters(self): def test_optional_stack_query_parameters(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.query(trace='html', fields='description') request = zoo.query(trace='html', fields='description')
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -202,7 +209,7 @@ class Discovery(unittest.TestCase):
def test_string_params_value_of_none_get_dropped(self): def test_string_params_value_of_none_get_dropped(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.query(trace=None, fields='description') request = zoo.query(trace=None, fields='description')
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -211,7 +218,7 @@ class Discovery(unittest.TestCase):
def test_model_added_query_parameters(self): def test_model_added_query_parameters(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.animals().get(name='Lion') request = zoo.animals().get(name='Lion')
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -221,7 +228,7 @@ class Discovery(unittest.TestCase):
def test_fallback_to_raw_model(self): def test_fallback_to_raw_model(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.animals().getmedia(name='Lion') request = zoo.animals().getmedia(name='Lion')
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -231,7 +238,7 @@ class Discovery(unittest.TestCase):
def test_patch(self): def test_patch(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.animals().patch(name='lion', body='{"description": "foo"}') request = zoo.animals().patch(name='lion', body='{"description": "foo"}')
self.assertEqual(request.method, 'PATCH') self.assertEqual(request.method, 'PATCH')
@@ -242,7 +249,7 @@ class Discovery(unittest.TestCase):
({'status': '200'}, 'echo_request_headers_as_json'), ({'status': '200'}, 'echo_request_headers_as_json'),
]) ])
http = tunnel_patch(http) http = tunnel_patch(http)
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
resp = zoo.animals().patch( resp = zoo.animals().patch(
name='lion', body='{"description": "foo"}').execute() name='lion', body='{"description": "foo"}').execute()
@@ -250,7 +257,7 @@ class Discovery(unittest.TestCase):
def test_plus_resources(self): def test_plus_resources(self):
self.http = HttpMock(datafile('plus.json'), {'status': '200'}) self.http = HttpMock(datafile('plus.json'), {'status': '200'})
plus = build('plus', 'v1', self.http) plus = build('plus', 'v1', http=self.http)
self.assertTrue(getattr(plus, 'activities')) self.assertTrue(getattr(plus, 'activities'))
self.assertTrue(getattr(plus, 'people')) self.assertTrue(getattr(plus, 'people'))
@@ -258,7 +265,7 @@ class Discovery(unittest.TestCase):
# Zoo should exercise all discovery facets # Zoo should exercise all discovery facets
# and should also have no future.json file. # and should also have no future.json file.
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
self.assertTrue(getattr(zoo, 'animals')) self.assertTrue(getattr(zoo, 'animals'))
request = zoo.animals().list(name='bat', projection="full") request = zoo.animals().list(name='bat', projection="full")
@@ -269,7 +276,7 @@ class Discovery(unittest.TestCase):
def test_nested_resources(self): def test_nested_resources(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
self.assertTrue(getattr(zoo, 'animals')) self.assertTrue(getattr(zoo, 'animals'))
request = zoo.my().favorites().list(max_results="5") request = zoo.my().favorites().list(max_results="5")
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -278,7 +285,7 @@ class Discovery(unittest.TestCase):
def test_methods_with_reserved_names(self): def test_methods_with_reserved_names(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
self.assertTrue(getattr(zoo, 'animals')) self.assertTrue(getattr(zoo, 'animals'))
request = zoo.global_().print_().assert_(max_results="5") request = zoo.global_().print_().assert_(max_results="5")
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -286,7 +293,7 @@ class Discovery(unittest.TestCase):
def test_top_level_functions(self): def test_top_level_functions(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
self.assertTrue(getattr(zoo, 'query')) self.assertTrue(getattr(zoo, 'query'))
request = zoo.query(q="foo") request = zoo.query(q="foo")
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -295,20 +302,20 @@ class Discovery(unittest.TestCase):
def test_simple_media_uploads(self): def test_simple_media_uploads(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
doc = getattr(zoo.animals().insert, '__doc__') doc = getattr(zoo.animals().insert, '__doc__')
self.assertTrue('media_body' in doc) self.assertTrue('media_body' in doc)
def test_simple_media_upload_no_max_size_provided(self): def test_simple_media_upload_no_max_size_provided(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
request = zoo.animals().crossbreed(media_body=datafile('small.png')) request = zoo.animals().crossbreed(media_body=datafile('small.png'))
self.assertEquals('image/png', request.headers['content-type']) self.assertEquals('image/png', request.headers['content-type'])
self.assertEquals('PNG', request.body[1:4]) self.assertEquals('PNG', request.body[1:4])
def test_simple_media_raise_correct_exceptions(self): def test_simple_media_raise_correct_exceptions(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
try: try:
zoo.animals().insert(media_body=datafile('smiley.png')) zoo.animals().insert(media_body=datafile('smiley.png'))
@@ -324,7 +331,7 @@ class Discovery(unittest.TestCase):
def test_simple_media_good_upload(self): def test_simple_media_good_upload(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
request = zoo.animals().insert(media_body=datafile('small.png')) request = zoo.animals().insert(media_body=datafile('small.png'))
self.assertEquals('image/png', request.headers['content-type']) self.assertEquals('image/png', request.headers['content-type'])
@@ -335,7 +342,7 @@ class Discovery(unittest.TestCase):
def test_multipart_media_raise_correct_exceptions(self): def test_multipart_media_raise_correct_exceptions(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
try: try:
zoo.animals().insert(media_body=datafile('smiley.png'), body={}) zoo.animals().insert(media_body=datafile('smiley.png'), body={})
@@ -351,7 +358,7 @@ class Discovery(unittest.TestCase):
def test_multipart_media_good_upload(self): def test_multipart_media_good_upload(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
request = zoo.animals().insert(media_body=datafile('small.png'), body={}) request = zoo.animals().insert(media_body=datafile('small.png'), body={})
self.assertTrue(request.headers['content-type'].startswith( self.assertTrue(request.headers['content-type'].startswith(
@@ -363,14 +370,14 @@ class Discovery(unittest.TestCase):
def test_media_capable_method_without_media(self): def test_media_capable_method_without_media(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
request = zoo.animals().insert(body={}) request = zoo.animals().insert(body={})
self.assertTrue(request.headers['content-type'], 'application/json') self.assertTrue(request.headers['content-type'], 'application/json')
def test_resumable_multipart_media_good_upload(self): def test_resumable_multipart_media_good_upload(self):
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
media_upload = MediaFileUpload(datafile('small.png'), resumable=True) media_upload = MediaFileUpload(datafile('small.png'), resumable=True)
request = zoo.animals().insert(media_body=media_upload, body={}) request = zoo.animals().insert(media_body=media_upload, body={})
@@ -396,7 +403,7 @@ class Discovery(unittest.TestCase):
({'status': '200'}, '{"foo": "bar"}'), ({'status': '200'}, '{"foo": "bar"}'),
]) ])
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(None, body) self.assertEquals(None, body)
self.assertTrue(isinstance(status, MediaUploadProgress)) self.assertTrue(isinstance(status, MediaUploadProgress))
self.assertEquals(13, status.resumable_progress) self.assertEquals(13, status.resumable_progress)
@@ -408,13 +415,13 @@ class Discovery(unittest.TestCase):
self.assertEquals(media_upload, request.resumable) self.assertEquals(media_upload, request.resumable)
self.assertEquals(13, request.resumable_progress) self.assertEquals(13, request.resumable_progress)
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(request.resumable_uri, 'http://upload.example.com/3') self.assertEquals(request.resumable_uri, 'http://upload.example.com/3')
self.assertEquals(media_upload.size()-1, request.resumable_progress) self.assertEquals(media_upload.size()-1, request.resumable_progress)
self.assertEquals('{"data": {}}', request.body) self.assertEquals('{"data": {}}', request.body)
# Final call to next_chunk should complete the upload. # Final call to next_chunk should complete the upload.
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(body, {"foo": "bar"}) self.assertEquals(body, {"foo": "bar"})
self.assertEquals(status, None) self.assertEquals(status, None)
@@ -422,7 +429,7 @@ class Discovery(unittest.TestCase):
def test_resumable_media_good_upload(self): def test_resumable_media_good_upload(self):
"""Not a multipart upload.""" """Not a multipart upload."""
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
media_upload = MediaFileUpload(datafile('small.png'), resumable=True) media_upload = MediaFileUpload(datafile('small.png'), resumable=True)
request = zoo.animals().insert(media_body=media_upload, body=None) request = zoo.animals().insert(media_body=media_upload, body=None)
@@ -445,7 +452,7 @@ class Discovery(unittest.TestCase):
({'status': '200'}, '{"foo": "bar"}'), ({'status': '200'}, '{"foo": "bar"}'),
]) ])
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(None, body) self.assertEquals(None, body)
self.assertTrue(isinstance(status, MediaUploadProgress)) self.assertTrue(isinstance(status, MediaUploadProgress))
self.assertEquals(13, status.resumable_progress) self.assertEquals(13, status.resumable_progress)
@@ -457,13 +464,13 @@ class Discovery(unittest.TestCase):
self.assertEquals(media_upload, request.resumable) self.assertEquals(media_upload, request.resumable)
self.assertEquals(13, request.resumable_progress) self.assertEquals(13, request.resumable_progress)
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(request.resumable_uri, 'http://upload.example.com/3') self.assertEquals(request.resumable_uri, 'http://upload.example.com/3')
self.assertEquals(media_upload.size()-1, request.resumable_progress) self.assertEquals(media_upload.size()-1, request.resumable_progress)
self.assertEquals(request.body, None) self.assertEquals(request.body, None)
# Final call to next_chunk should complete the upload. # Final call to next_chunk should complete the upload.
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(body, {"foo": "bar"}) self.assertEquals(body, {"foo": "bar"})
self.assertEquals(status, None) self.assertEquals(status, None)
@@ -471,7 +478,7 @@ class Discovery(unittest.TestCase):
def test_resumable_media_good_upload_from_execute(self): def test_resumable_media_good_upload_from_execute(self):
"""Not a multipart upload.""" """Not a multipart upload."""
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
media_upload = MediaFileUpload(datafile('small.png'), resumable=True) media_upload = MediaFileUpload(datafile('small.png'), resumable=True)
request = zoo.animals().insert(media_body=media_upload, body=None) request = zoo.animals().insert(media_body=media_upload, body=None)
@@ -491,13 +498,13 @@ class Discovery(unittest.TestCase):
({'status': '200'}, '{"foo": "bar"}'), ({'status': '200'}, '{"foo": "bar"}'),
]) ])
body = request.execute(http) body = request.execute(http=http)
self.assertEquals(body, {"foo": "bar"}) self.assertEquals(body, {"foo": "bar"})
def test_resumable_media_fail_unknown_response_code_first_request(self): def test_resumable_media_fail_unknown_response_code_first_request(self):
"""Not a multipart upload.""" """Not a multipart upload."""
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
media_upload = MediaFileUpload(datafile('small.png'), resumable=True) media_upload = MediaFileUpload(datafile('small.png'), resumable=True)
request = zoo.animals().insert(media_body=media_upload, body=None) request = zoo.animals().insert(media_body=media_upload, body=None)
@@ -507,12 +514,12 @@ class Discovery(unittest.TestCase):
'location': 'http://upload.example.com'}, ''), 'location': 'http://upload.example.com'}, ''),
]) ])
self.assertRaises(ResumableUploadError, request.execute, http) self.assertRaises(ResumableUploadError, request.execute, http=http)
def test_resumable_media_fail_unknown_response_code_subsequent_request(self): def test_resumable_media_fail_unknown_response_code_subsequent_request(self):
"""Not a multipart upload.""" """Not a multipart upload."""
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
media_upload = MediaFileUpload(datafile('small.png'), resumable=True) media_upload = MediaFileUpload(datafile('small.png'), resumable=True)
request = zoo.animals().insert(media_body=media_upload, body=None) request = zoo.animals().insert(media_body=media_upload, body=None)
@@ -523,7 +530,7 @@ class Discovery(unittest.TestCase):
({'status': '400'}, ''), ({'status': '400'}, ''),
]) ])
self.assertRaises(HttpError, request.execute, http) self.assertRaises(HttpError, request.execute, http=http)
self.assertTrue(request._in_error_state) self.assertTrue(request._in_error_state)
http = HttpMockSequence([ http = HttpMockSequence([
@@ -533,7 +540,7 @@ class Discovery(unittest.TestCase):
'range': '0-6'}, ''), 'range': '0-6'}, ''),
]) ])
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEquals(status.resumable_progress, 7, self.assertEquals(status.resumable_progress, 7,
'Should have first checked length and then tried to PUT more.') 'Should have first checked length and then tried to PUT more.')
self.assertFalse(request._in_error_state) self.assertFalse(request._in_error_state)
@@ -542,14 +549,14 @@ class Discovery(unittest.TestCase):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '400'}, ''), ({'status': '400'}, ''),
]) ])
self.assertRaises(HttpError, request.execute, http) self.assertRaises(HttpError, request.execute, http=http)
self.assertTrue(request._in_error_state) self.assertTrue(request._in_error_state)
# Pretend the last request that 400'd actually succeeded. # Pretend the last request that 400'd actually succeeded.
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, '{"foo": "bar"}'), ({'status': '200'}, '{"foo": "bar"}'),
]) ])
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEqual(body, {'foo': 'bar'}) self.assertEqual(body, {'foo': 'bar'})
def test_resumable_media_handle_uploads_of_unknown_size(self): def test_resumable_media_handle_uploads_of_unknown_size(self):
@@ -560,7 +567,7 @@ class Discovery(unittest.TestCase):
]) ])
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
fd = StringIO.StringIO('data goes here') fd = StringIO.StringIO('data goes here')
@@ -569,7 +576,7 @@ class Discovery(unittest.TestCase):
fd=fd, mimetype='image/png', chunksize=10, resumable=True) fd=fd, mimetype='image/png', chunksize=10, resumable=True)
request = zoo.animals().insert(media_body=upload, body=None) request = zoo.animals().insert(media_body=upload, body=None)
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEqual(body, {'Content-Range': 'bytes 0-9/*'}, self.assertEqual(body, {'Content-Range': 'bytes 0-9/*'},
'Should be 10 out of * bytes.') 'Should be 10 out of * bytes.')
@@ -581,7 +588,7 @@ class Discovery(unittest.TestCase):
]) ])
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
fd = StringIO.StringIO('data goes here') fd = StringIO.StringIO('data goes here')
@@ -590,7 +597,7 @@ class Discovery(unittest.TestCase):
fd=fd, mimetype='image/png', chunksize=15, resumable=True) fd=fd, mimetype='image/png', chunksize=15, resumable=True)
request = zoo.animals().insert(media_body=upload, body=None) request = zoo.animals().insert(media_body=upload, body=None)
status, body = request.next_chunk(http) status, body = request.next_chunk(http=http)
self.assertEqual(body, {'Content-Range': 'bytes 0-13/14'}) self.assertEqual(body, {'Content-Range': 'bytes 0-13/14'})
def test_resumable_media_handle_resume_of_upload_of_unknown_size(self): def test_resumable_media_handle_resume_of_upload_of_unknown_size(self):
@@ -601,7 +608,7 @@ class Discovery(unittest.TestCase):
]) ])
self.http = HttpMock(datafile('zoo.json'), {'status': '200'}) self.http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', self.http) zoo = build('zoo', 'v1', http=self.http)
# Create an upload that doesn't know the full size of the media. # Create an upload that doesn't know the full size of the media.
fd = StringIO.StringIO('data goes here') fd = StringIO.StringIO('data goes here')
@@ -612,7 +619,7 @@ class Discovery(unittest.TestCase):
request = zoo.animals().insert(media_body=upload, body=None) request = zoo.animals().insert(media_body=upload, body=None)
# Put it in an error state. # Put it in an error state.
self.assertRaises(HttpError, request.next_chunk, http) self.assertRaises(HttpError, request.next_chunk, http=http)
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '400', ({'status': '400',
@@ -620,7 +627,7 @@ class Discovery(unittest.TestCase):
]) ])
try: try:
# Should resume the upload by first querying the status of the upload. # Should resume the upload by first querying the status of the upload.
request.next_chunk(http) request.next_chunk(http=http)
except HttpError, e: except HttpError, e:
expected = { expected = {
'Content-Range': 'bytes */*', 'Content-Range': 'bytes */*',
@@ -634,13 +641,13 @@ class Next(unittest.TestCase):
def test_next_successful_none_on_no_next_page_token(self): def test_next_successful_none_on_no_next_page_token(self):
self.http = HttpMock(datafile('tasks.json'), {'status': '200'}) self.http = HttpMock(datafile('tasks.json'), {'status': '200'})
tasks = build('tasks', 'v1', self.http) tasks = build('tasks', 'v1', http=self.http)
request = tasks.tasklists().list() request = tasks.tasklists().list()
self.assertEqual(None, tasks.tasklists().list_next(request, {})) self.assertEqual(None, tasks.tasklists().list_next(request, {}))
def test_next_successful_with_next_page_token(self): def test_next_successful_with_next_page_token(self):
self.http = HttpMock(datafile('tasks.json'), {'status': '200'}) self.http = HttpMock(datafile('tasks.json'), {'status': '200'})
tasks = build('tasks', 'v1', self.http) tasks = build('tasks', 'v1', http=self.http)
request = tasks.tasklists().list() request = tasks.tasklists().list()
next_request = tasks.tasklists().list_next( next_request = tasks.tasklists().list_next(
request, {'nextPageToken': '123abc'}) request, {'nextPageToken': '123abc'})
@@ -650,7 +657,7 @@ class Next(unittest.TestCase):
def test_next_with_method_with_no_properties(self): def test_next_with_method_with_no_properties(self):
self.http = HttpMock(datafile('latitude.json'), {'status': '200'}) self.http = HttpMock(datafile('latitude.json'), {'status': '200'})
service = build('latitude', 'v1', self.http) service = build('latitude', 'v1', http=self.http)
request = service.currentLocation().get() request = service.currentLocation().get()
@@ -658,7 +665,7 @@ class MediaGet(unittest.TestCase):
def test_get_media(self): def test_get_media(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
request = zoo.animals().get_media(name='Lion') request = zoo.animals().get_media(name='Lion')
parsed = urlparse.urlparse(request.uri) parsed = urlparse.urlparse(request.uri)
@@ -669,7 +676,7 @@ class MediaGet(unittest.TestCase):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, 'standing in for media'), ({'status': '200'}, 'standing in for media'),
]) ])
response = request.execute(http) response = request.execute(http=http)
self.assertEqual('standing in for media', response) self.assertEqual('standing in for media', response)

View File

@@ -59,7 +59,7 @@ class Error(unittest.TestCase):
resp, content = fake_response(JSON_ERROR_CONTENT, resp, content = fake_response(JSON_ERROR_CONTENT,
{'status':'400', 'content-type': 'application/json'}, {'status':'400', 'content-type': 'application/json'},
reason='Failed') reason='Failed')
error = HttpError(resp, content, 'http://example.org') error = HttpError(resp, content, uri='http://example.org')
self.assertEqual(str(error), '<HttpError 400 when requesting http://example.org returned "country is required">') self.assertEqual(str(error), '<HttpError 400 when requesting http://example.org returned "country is required">')
def test_bad_json_body(self): def test_bad_json_body(self):
@@ -75,7 +75,7 @@ class Error(unittest.TestCase):
resp, content = fake_response('{', resp, content = fake_response('{',
{'status':'400', 'content-type': 'application/json'}, {'status':'400', 'content-type': 'application/json'},
reason='Failure') reason='Failure')
error = HttpError(resp, content, 'http://example.org') error = HttpError(resp, content, uri='http://example.org')
self.assertEqual(str(error), '<HttpError 400 when requesting http://example.org returned "Failure">') self.assertEqual(str(error), '<HttpError 400 when requesting http://example.org returned "Failure">')
def test_missing_message_json_body(self): def test_missing_message_json_body(self):

View File

@@ -140,7 +140,7 @@ class TestMediaUpload(unittest.TestCase):
self.assertEqual('PNG', new_upload.getbytes(1, 3)) self.assertEqual('PNG', new_upload.getbytes(1, 3))
def test_media_inmemory_upload(self): def test_media_inmemory_upload(self):
media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10, media = MediaInMemoryUpload('abcdef', mimetype='text/plain', chunksize=10,
resumable=True) resumable=True)
self.assertEqual('text/plain', media.mimetype()) self.assertEqual('text/plain', media.mimetype())
self.assertEqual(10, media.chunksize()) self.assertEqual(10, media.chunksize())
@@ -149,7 +149,7 @@ class TestMediaUpload(unittest.TestCase):
self.assertEqual(6, media.size()) self.assertEqual(6, media.size())
def test_media_inmemory_upload_json_roundtrip(self): def test_media_inmemory_upload_json_roundtrip(self):
media = MediaInMemoryUpload(os.urandom(64), 'text/plain', chunksize=10, media = MediaInMemoryUpload(os.urandom(64), mimetype='text/plain', chunksize=10,
resumable=True) resumable=True)
data = media.to_json() data = media.to_json()
newmedia = MediaInMemoryUpload.new_from_json(data) newmedia = MediaInMemoryUpload.new_from_json(data)
@@ -261,7 +261,7 @@ class TestMediaIoBaseDownload(unittest.TestCase):
def setUp(self): def setUp(self):
http = HttpMock(datafile('zoo.json'), {'status': '200'}) http = HttpMock(datafile('zoo.json'), {'status': '200'})
zoo = build('zoo', 'v1', http) zoo = build('zoo', 'v1', http=http)
self.request = zoo.animals().get_media(name='Lion') self.request = zoo.animals().get_media(name='Lion')
self.fd = StringIO.StringIO() self.fd = StringIO.StringIO()
@@ -587,7 +587,7 @@ class TestBatch(unittest.TestCase):
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_RESPONSE), BATCH_RESPONSE),
]) ])
batch.execute(http) batch.execute(http=http)
self.assertEqual({'foo': 42}, callbacks.responses['1']) self.assertEqual({'foo': 42}, callbacks.responses['1'])
self.assertEqual(None, callbacks.exceptions['1']) self.assertEqual(None, callbacks.exceptions['1'])
self.assertEqual({'baz': 'qux'}, callbacks.responses['2']) self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
@@ -604,7 +604,7 @@ class TestBatch(unittest.TestCase):
'echo_request_body'), 'echo_request_body'),
]) ])
try: try:
batch.execute(http) batch.execute(http=http)
self.fail('Should raise exception') self.fail('Should raise exception')
except BatchError, e: except BatchError, e:
boundary, _ = e.content.split(None, 1) boundary, _ = e.content.split(None, 1)
@@ -642,7 +642,7 @@ class TestBatch(unittest.TestCase):
batch.add(self.request1, callback=callbacks.f) batch.add(self.request1, callback=callbacks.f)
batch.add(self.request2, callback=callbacks.f) batch.add(self.request2, callback=callbacks.f)
batch.execute(http) batch.execute(http=http)
self.assertEqual({'foo': 42}, callbacks.responses['1']) self.assertEqual({'foo': 42}, callbacks.responses['1'])
self.assertEqual(None, callbacks.exceptions['1']) self.assertEqual(None, callbacks.exceptions['1'])
@@ -684,7 +684,7 @@ class TestBatch(unittest.TestCase):
batch.add(self.request1, callback=callbacks.f) batch.add(self.request1, callback=callbacks.f)
batch.add(self.request2, callback=callbacks.f) batch.add(self.request2, callback=callbacks.f)
batch.execute(http) batch.execute(http=http)
self.assertEqual(None, callbacks.responses['1']) self.assertEqual(None, callbacks.responses['1'])
self.assertEqual(401, callbacks.exceptions['1'].resp.status) self.assertEqual(401, callbacks.exceptions['1'].resp.status)
@@ -704,7 +704,7 @@ class TestBatch(unittest.TestCase):
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_RESPONSE), BATCH_RESPONSE),
]) ])
batch.execute(http) batch.execute(http=http)
self.assertEqual({'foo': 42}, callbacks.responses['1']) self.assertEqual({'foo': 42}, callbacks.responses['1'])
self.assertEqual({'baz': 'qux'}, callbacks.responses['2']) self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
@@ -719,7 +719,7 @@ class TestBatch(unittest.TestCase):
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_ERROR_RESPONSE), BATCH_ERROR_RESPONSE),
]) ])
batch.execute(http) batch.execute(http=http)
self.assertEqual({'foo': 42}, callbacks.responses['1']) self.assertEqual({'foo': 42}, callbacks.responses['1'])
expected = ('<HttpError 403 when requesting ' expected = ('<HttpError 403 when requesting '
'https://www.googleapis.com/someapi/v1/collection/?foo=bar returned ' 'https://www.googleapis.com/someapi/v1/collection/?foo=bar returned '

View File

@@ -55,13 +55,16 @@ from oauth2client.client import flow_from_clientsecrets
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
def datafile(filename): def datafile(filename):
return os.path.join(DATA_DIR, filename) return os.path.join(DATA_DIR, filename)
def load_and_cache(existing_file, fakename, cache_mock): def load_and_cache(existing_file, fakename, cache_mock):
client_type, client_info = _loadfile(datafile(existing_file)) client_type, client_info = _loadfile(datafile(existing_file))
cache_mock.cache[fakename] = {client_type: client_info} cache_mock.cache[fakename] = {client_type: client_info}
class CacheMock(object): class CacheMock(object):
def __init__(self): def __init__(self):
self.cache = {} self.cache = {}
@@ -195,7 +198,7 @@ class TestAssertionCredentials(unittest.TestCase):
def setUp(self): def setUp(self):
user_agent = "fun/2.0" user_agent = "fun/2.0"
self.credentials = self.AssertionCredentialsTestImpl(self.assertion_type, self.credentials = self.AssertionCredentialsTestImpl(self.assertion_type,
user_agent) user_agent=user_agent)
def test_assertion_body(self): def test_assertion_body(self):
body = urlparse.parse_qs(self.credentials._generate_refresh_request_body()) body = urlparse.parse_qs(self.credentials._generate_refresh_request_body())
@@ -230,6 +233,7 @@ class ExtractIdTokenText(unittest.TestCase):
self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt) self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt)
class OAuth2WebServerFlowTest(unittest.TestCase): class OAuth2WebServerFlowTest(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -237,18 +241,19 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
client_id='client_id+1', client_id='client_id+1',
client_secret='secret+1', client_secret='secret+1',
scope='foo', scope='foo',
redirect_uri=OOB_CALLBACK_URN,
user_agent='unittest-sample/1.0', user_agent='unittest-sample/1.0',
) )
def test_construct_authorize_url(self): def test_construct_authorize_url(self):
authorize_url = self.flow.step1_get_authorize_url('OOB_CALLBACK_URN') authorize_url = self.flow.step1_get_authorize_url()
parsed = urlparse.urlparse(authorize_url) parsed = urlparse.urlparse(authorize_url)
q = parse_qs(parsed[4]) q = parse_qs(parsed[4])
self.assertEqual('client_id+1', q['client_id'][0]) self.assertEqual('client_id+1', q['client_id'][0])
self.assertEqual('code', q['response_type'][0]) self.assertEqual('code', q['response_type'][0])
self.assertEqual('foo', q['scope'][0]) self.assertEqual('foo', q['scope'][0])
self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0]) self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
self.assertEqual('offline', q['access_type'][0]) self.assertEqual('offline', q['access_type'][0])
def test_override_flow_access_type(self): def test_override_flow_access_type(self):
@@ -257,17 +262,18 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
client_id='client_id+1', client_id='client_id+1',
client_secret='secret+1', client_secret='secret+1',
scope='foo', scope='foo',
redirect_uri=OOB_CALLBACK_URN,
user_agent='unittest-sample/1.0', user_agent='unittest-sample/1.0',
access_type='online' access_type='online'
) )
authorize_url = flow.step1_get_authorize_url('OOB_CALLBACK_URN') authorize_url = flow.step1_get_authorize_url()
parsed = urlparse.urlparse(authorize_url) parsed = urlparse.urlparse(authorize_url)
q = parse_qs(parsed[4]) q = parse_qs(parsed[4])
self.assertEqual('client_id+1', q['client_id'][0]) self.assertEqual('client_id+1', q['client_id'][0])
self.assertEqual('code', q['response_type'][0]) self.assertEqual('code', q['response_type'][0])
self.assertEqual('foo', q['scope'][0]) self.assertEqual('foo', q['scope'][0])
self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0]) self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
self.assertEqual('online', q['access_type'][0]) self.assertEqual('online', q['access_type'][0])
def test_exchange_failure(self): def test_exchange_failure(self):
@@ -276,7 +282,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
try: try:
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.fail("should raise exception if exchange doesn't get 200") self.fail("should raise exception if exchange doesn't get 200")
except FlowExchangeError: except FlowExchangeError:
pass pass
@@ -287,7 +293,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
try: try:
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.fail("should raise exception if exchange doesn't get 200") self.fail("should raise exception if exchange doesn't get 200")
except FlowExchangeError, e: except FlowExchangeError, e:
self.assertEquals('invalid_request', str(e)) self.assertEquals('invalid_request', str(e))
@@ -305,7 +311,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
try: try:
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.fail("should raise exception if exchange doesn't get 200") self.fail("should raise exception if exchange doesn't get 200")
except FlowExchangeError, e: except FlowExchangeError, e:
pass pass
@@ -318,7 +324,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
"refresh_token":"8xLOxBtZp8" }"""), "refresh_token":"8xLOxBtZp8" }"""),
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.assertEqual('SlAV32hkKG', credentials.access_token) self.assertEqual('SlAV32hkKG', credentials.access_token)
self.assertNotEqual(None, credentials.token_expiry) self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual('8xLOxBtZp8', credentials.refresh_token) self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
@@ -328,7 +334,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
({'status': '200'}, "access_token=SlAV32hkKG&expires_in=3600"), ({'status': '200'}, "access_token=SlAV32hkKG&expires_in=3600"),
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.assertEqual('SlAV32hkKG', credentials.access_token) self.assertEqual('SlAV32hkKG', credentials.access_token)
self.assertNotEqual(None, credentials.token_expiry) self.assertNotEqual(None, credentials.token_expiry)
@@ -339,7 +345,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
({'status': '200'}, "access_token=SlAV32hkKG&expires=3600"), ({'status': '200'}, "access_token=SlAV32hkKG&expires=3600"),
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.assertNotEqual(None, credentials.token_expiry) self.assertNotEqual(None, credentials.token_expiry)
def test_exchange_no_expires_in(self): def test_exchange_no_expires_in(self):
@@ -348,7 +354,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
"refresh_token":"8xLOxBtZp8" }"""), "refresh_token":"8xLOxBtZp8" }"""),
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.assertEqual(None, credentials.token_expiry) self.assertEqual(None, credentials.token_expiry)
def test_urlencoded_exchange_no_expires_in(self): def test_urlencoded_exchange_no_expires_in(self):
@@ -358,7 +364,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
({'status': '200'}, "access_token=SlAV32hkKG"), ({'status': '200'}, "access_token=SlAV32hkKG"),
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.assertEqual(None, credentials.token_expiry) self.assertEqual(None, credentials.token_expiry)
def test_exchange_fails_if_no_code(self): def test_exchange_fails_if_no_code(self):
@@ -369,7 +375,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
code = {'error': 'thou shall not pass'} code = {'error': 'thou shall not pass'}
try: try:
credentials = self.flow.step2_exchange(code, http) credentials = self.flow.step2_exchange(code, http=http)
self.fail('should raise exception if no code in dictionary.') self.fail('should raise exception if no code in dictionary.')
except FlowExchangeError, e: except FlowExchangeError, e:
self.assertTrue('shall not pass' in str(e)) self.assertTrue('shall not pass' in str(e))
@@ -382,7 +388,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
self.assertRaises(VerifyJwtTokenError, self.flow.step2_exchange, self.assertRaises(VerifyJwtTokenError, self.flow.step2_exchange,
'some random code', http) 'some random code', http=http)
def test_exchange_id_token_fail(self): def test_exchange_id_token_fail(self):
body = {'foo': 'bar'} body = {'foo': 'bar'}
@@ -396,19 +402,21 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
"id_token": "%s"}""" % jwt), "id_token": "%s"}""" % jwt),
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http=http)
self.assertEqual(credentials.id_token, body) self.assertEqual(credentials.id_token, body)
class FlowFromCachedClientsecrets(unittest.TestCase):
class FlowFromCachedClientsecrets(unittest.TestCase):
def test_flow_from_clientsecrets_cached(self): def test_flow_from_clientsecrets_cached(self):
cache_mock = CacheMock() cache_mock = CacheMock()
load_and_cache('client_secrets.json', 'some_secrets', cache_mock) load_and_cache('client_secrets.json', 'some_secrets', cache_mock)
# flow_from_clientsecrets(filename, scope, message=None, cache=None) flow = flow_from_clientsecrets(
flow = flow_from_clientsecrets('some_secrets', '', cache=cache_mock) 'some_secrets', '', redirect_uri='oob', cache=cache_mock)
self.assertEquals('foo_client_secret', flow.client_secret) self.assertEquals('foo_client_secret', flow.client_secret)
class CredentialsFromCodeTests(unittest.TestCase): class CredentialsFromCodeTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.client_id = 'client_id_abc' self.client_id = 'client_id_abc'
@@ -424,8 +432,8 @@ class CredentialsFromCodeTests(unittest.TestCase):
"expires_in":3600 }"""), "expires_in":3600 }"""),
]) ])
credentials = credentials_from_code(self.client_id, self.client_secret, credentials = credentials_from_code(self.client_id, self.client_secret,
self.scope, self.code, self.redirect_uri, self.scope, self.code, redirect_uri=self.redirect_uri,
http) http=http)
self.assertEquals(credentials.access_token, 'asdfghjkl') self.assertEquals(credentials.access_token, 'asdfghjkl')
self.assertNotEqual(None, credentials.token_expiry) self.assertNotEqual(None, credentials.token_expiry)
@@ -436,13 +444,12 @@ class CredentialsFromCodeTests(unittest.TestCase):
try: try:
credentials = credentials_from_code(self.client_id, self.client_secret, credentials = credentials_from_code(self.client_id, self.client_secret,
self.scope, self.code, self.redirect_uri, self.scope, self.code, redirect_uri=self.redirect_uri,
http) http=http)
self.fail("should raise exception if exchange doesn't get 200") self.fail("should raise exception if exchange doesn't get 200")
except FlowExchangeError: except FlowExchangeError:
pass pass
def test_exchange_code_and_file_for_token(self): def test_exchange_code_and_file_for_token(self):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, ({'status': '200'},
@@ -481,7 +488,6 @@ class CredentialsFromCodeTests(unittest.TestCase):
pass pass
class MemoryCacheTests(unittest.TestCase): class MemoryCacheTests(unittest.TestCase):
def test_get_set_delete(self): def test_get_set_delete(self):

View File

@@ -55,7 +55,6 @@ from oauth2client.appengine import AppAssertionCredentials
from oauth2client.appengine import CredentialsModel from oauth2client.appengine import CredentialsModel
from oauth2client.appengine import FlowProperty from oauth2client.appengine import FlowProperty
from oauth2client.appengine import OAuth2Decorator from oauth2client.appengine import OAuth2Decorator
from oauth2client.appengine import OAuth2Handler
from oauth2client.appengine import StorageByKeyName from oauth2client.appengine import StorageByKeyName
from oauth2client.appengine import oauth2decorator_from_clientsecrets from oauth2client.appengine import oauth2decorator_from_clientsecrets
from oauth2client.client import AccessTokenRefreshError from oauth2client.client import AccessTokenRefreshError
@@ -200,7 +199,8 @@ class FlowPropertyTest(unittest.TestCase):
def test_flow_get_put(self): def test_flow_get_put(self):
instance = TestFlowModel( instance = TestFlowModel(
flow=flow_from_clientsecrets(datafile('client_secrets.json'), 'foo'), flow=flow_from_clientsecrets(datafile('client_secrets.json'), 'foo',
redirect_uri='oob'),
key_name='foo' key_name='foo'
) )
instance.put() instance.put()
@@ -276,6 +276,14 @@ class StorageByKeyNameTest(unittest.TestCase):
self.assertEqual(None, credentials) self.assertEqual(None, credentials)
self.assertEqual(None, memcache.get('foo')) self.assertEqual(None, memcache.get('foo'))
class MockRequest(object):
url = 'https://example.org'
def relative_url(self, rel):
return self.url + rel
class MockRequestHandler(object):
request = MockRequest()
class DecoratorTests(unittest.TestCase): class DecoratorTests(unittest.TestCase):
@@ -312,7 +320,7 @@ class DecoratorTests(unittest.TestCase):
application = webapp2.WSGIApplication([ application = webapp2.WSGIApplication([
('/oauth2callback', OAuth2Handler), ('/oauth2callback', self.decorator.callback_handler()),
('/foo_path', TestRequiredHandler), ('/foo_path', TestRequiredHandler),
webapp2.Route(r'/bar_path/<year:\d{4}>/<month:\d{2}>', webapp2.Route(r'/bar_path/<year:\d{4}>/<month:\d{2}>',
handler=TestAwareHandler, name='bar')], handler=TestAwareHandler, name='bar')],
@@ -441,6 +449,11 @@ class DecoratorTests(unittest.TestCase):
scope=['foo_scope', 'bar_scope'], scope=['foo_scope', 'bar_scope'],
access_type='offline', access_type='offline',
approval_prompt='force') approval_prompt='force')
request_handler = MockRequestHandler()
decorator._create_flow(request_handler)
self.assertEqual('https://example.org/oauth2callback',
decorator.flow.redirect_uri)
self.assertEqual('offline', decorator.flow.params['access_type']) self.assertEqual('offline', decorator.flow.params['access_type'])
self.assertEqual('force', decorator.flow.params['approval_prompt']) self.assertEqual('force', decorator.flow.params['approval_prompt'])
self.assertEqual('foo_user_agent', decorator.flow.user_agent) self.assertEqual('foo_user_agent', decorator.flow.user_agent)

View File

@@ -114,7 +114,7 @@ class CryptTests(unittest.TestCase):
]) ])
contents = verify_id_token(jwt, contents = verify_id_token(jwt,
'some_audience_address@testing.gserviceaccount.com', http) 'some_audience_address@testing.gserviceaccount.com', http=http)
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
@@ -126,7 +126,7 @@ class CryptTests(unittest.TestCase):
]) ])
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt, self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
'some_audience_address@testing.gserviceaccount.com', http) 'some_audience_address@testing.gserviceaccount.com', http=http)
def test_verify_id_token_bad_tokens(self): def test_verify_id_token_bad_tokens(self):
private_key = datafile('privatekey.p12') private_key = datafile('privatekey.p12')