From 9da2ad89f12b180ea0c1f18e2f36c7dde02a86f4 Mon Sep 17 00:00:00 2001 From: Joe Gregorio Date: Sun, 11 Sep 2011 14:04:44 -0400 Subject: [PATCH] Add fancy locking to oauth2client. Reviewed in http://codereview.appspot.com/4919049/ --- oauth2client/appengine.py | 3 +- oauth2client/client.py | 280 ++++++++++++++------- oauth2client/django_orm.py | 2 +- oauth2client/file.py | 2 +- oauth2client/multistore_file.py | 361 +++++++++++++++++++++++++++ oauth2client/tools.py | 42 ++-- tests/test_oauth2client_appengine.py | 4 +- 7 files changed, 568 insertions(+), 126 deletions(-) create mode 100644 oauth2client/multistore_file.py diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py index 439a579..64fd3ac 100644 --- a/oauth2client/appengine.py +++ b/oauth2client/appengine.py @@ -23,7 +23,6 @@ import httplib2 import pickle import time import base64 -import logging try: # pragma: no cover import simplejson @@ -222,7 +221,7 @@ class StorageByKeyName(Storage): entity = self._model.get_or_insert(self._key_name) credential = getattr(entity, self._property_name) if credential and hasattr(credential, 'set_store'): - credential.set_store(self.put) + credential.set_store(self) if self._cache: self._cache.set(self._key_name, pickle.dumps(credentials)) diff --git a/oauth2client/client.py b/oauth2client/client.py index f547428..894bfb4 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An OAuth 2.0 client +"""An OAuth 2.0 client. -Tools for interacting with OAuth 2.0 protected -resources. +Tools for interacting with OAuth 2.0 protected resources. """ __author__ = 'jcgregorio@google.com (Joe Gregorio)' @@ -27,9 +26,9 @@ import logging import urllib import urlparse -try: # pragma: no cover +try: # pragma: no cover import simplejson -except ImportError: # pragma: no cover +except ImportError: # pragma: no cover try: # Try to import from django, should work on App Engine from django.utils import simplejson @@ -38,9 +37,11 @@ except ImportError: # pragma: no cover import json as simplejson try: - from urlparse import parse_qsl + from urlparse import parse_qsl except ImportError: - from cgi import parse_qsl + from cgi import parse_qsl + +logger = logging.getLogger(__name__) class Error(Exception): @@ -92,28 +93,76 @@ class Flow(object): class Storage(object): """Base class for all Storage objects. - Store and retrieve a single credential. + Store and retrieve a single credential. This class supports locking + such that multiple processes and threads can operate on a single + store. """ - def get(self): + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. + + This lock is not reentrant.""" + pass + + def release_lock(self): + """Release the Storage lock. + + Trying to release a lock that isn't held will result in a + RuntimeError. + """ + pass + + def locked_get(self): """Retrieve credential. + The Storage lock must be held when this is called. + Returns: oauth2client.client.Credentials """ _abstract() - def put(self, credentials): + def locked_put(self, credentials): """Write a credential. + The Storage lock must be held when this is called. + Args: credentials: Credentials, the credentials to store. """ _abstract() + def get(self): + """Retrieve credential. + + The Storage lock must *not* be held when this is called. + + Returns: + oauth2client.client.Credentials + """ + self.acquire_lock() + try: + return self.locked_get() + finally: + self.release_lock() + + def put(self, credentials): + """Write a credential. + + The Storage lock must be held when this is called. + + Args: + credentials: Credentials, the credentials to store. + """ + self.acquire_lock() + try: + self.locked_put(credentials) + finally: + self.release_lock() + class OAuth2Credentials(Credentials): - """Credentials object for OAuth 2.0 + """Credentials object for OAuth 2.0. Credentials can be applied to an httplib2.Http object using the authorize() method, which then signs each request from that object with the OAuth 2.0 @@ -123,22 +172,21 @@ class OAuth2Credentials(Credentials): """ def __init__(self, access_token, client_id, client_secret, refresh_token, - token_expiry, token_uri, user_agent): - """Create an instance of OAuth2Credentials + token_expiry, token_uri, user_agent): + """Create an instance of OAuth2Credentials. This constructor is not usually called by the user, instead OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow. Args: - token_uri: string, URI of token endpoint. + access_token: string, access token. client_id: string, client identifier. client_secret: string, client secret. - access_token: string, access token. - token_expiry: datetime, when the access_token expires. refresh_token: string, refresh token. + token_expiry: datetime, when the access_token expires. + token_uri: string, URI of token endpoint. user_agent: string, The HTTP User-Agent to provide for this application. - Notes: store: callable, a callable that when passed a Credential will store the credential back to where it came from. @@ -156,51 +204,66 @@ class OAuth2Credentials(Credentials): # True if the credentials have been revoked or expired and can't be # refreshed. - self._invalid = False + self.invalid = False @property - def invalid(self): - """True if the credentials are invalid, such as being revoked.""" - return getattr(self, '_invalid', False) + def access_token_expired(self): + """True if the credential is expired or invalid. + + If the token_expiry isn't set, we assume the token doesn't expire. + """ + if self.invalid: + return True + + if not self.token_expiry: + return False + + now = datetime.datetime.now() + if now >= self.token_expiry: + logger.info('access_token is expired. Now: %s, token_expiry: %s', + now, self.token_expiry) + return True + return False def set_store(self, store): - """Set the storage for the credential. + """Set the Storage for the credential. Args: - store: callable, a callable that when passed a Credential - will store the credential back to where it came from. + store: Storage, an implementation of Stroage object. This is needed to store the latest access_token if it - has expired and been refreshed. + has expired and been refreshed. This implementation uses + locking to check for updates before updating the + access_token. """ self.store = store + def _updateFromCredential(self, other): + """Update this Credential from another instance.""" + self.__dict__.update(other.__getstate__()) + def __getstate__(self): - """Trim the state down to something that can be pickled. - """ + """Trim the state down to something that can be pickled.""" d = copy.copy(self.__dict__) del d['store'] return d def __setstate__(self, state): - """Reconstitute the state of the object from being pickled. - """ + """Reconstitute the state of the object from being pickled.""" self.__dict__.update(state) self.store = None def _generate_refresh_request_body(self): - """Generate the body that will be used in the refresh request - """ + """Generate the body that will be used in the refresh request.""" body = urllib.urlencode({ - 'grant_type': 'refresh_token', - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'refresh_token': self.refresh_token, - }) + 'grant_type': 'refresh_token', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token, + }) return body def _generate_refresh_request_headers(self): - """Generate the headers that will be used in the refresh request - """ + """Generate the headers that will be used in the refresh request.""" headers = { 'content-type': 'application/x-www-form-urlencoded', } @@ -211,16 +274,41 @@ class OAuth2Credentials(Credentials): return headers def _refresh(self, http_request): + """Refreshes the access_token. + + This method first checks by reading the Storage object if available. + If a refresh is still needed, it holds the Storage lock until the + refresh is completed. + """ + if not self.store: + self._do_refresh_request(http_request) + else: + self.store.acquire_lock() + try: + new_cred = self.store.locked_get() + if (new_cred and not new_cred.invalid and + new_cred.access_token != self.access_token): + logger.info('Updated access_token read from Storage') + self._updateFromCredential(new_cred) + else: + self._do_refresh_request(http_request) + finally: + self.store.release_lock() + + def _do_refresh_request(self, http_request): """Refresh the access_token using the refresh_token. Args: http: An instance of httplib2.Http.request or something that acts like it. + + Raises: + AccessTokenRefreshError: When the refresh fails. """ body = self._generate_refresh_request_body() headers = self._generate_refresh_request_headers() - logging.info("Refresing access_token") + logger.info('Refresing access_token') resp, content = http_request( self.token_uri, method='POST', body=body, headers=headers) if resp.status == 200: @@ -233,23 +321,20 @@ class OAuth2Credentials(Credentials): seconds=int(d['expires_in'])) + datetime.datetime.now() else: self.token_expiry = None - if self.store is not None: - self.store(self) + if self.store: + self.store.locked_put(self) else: # An {'error':...} response body means the token is expired or revoked, # so we flag the credentials as such. - logging.error('Failed to retrieve access token: %s' % content) + logger.error('Failed to retrieve access token: %s' % content) error_msg = 'Invalid response %s.' % resp['status'] try: d = simplejson.loads(content) if 'error' in d: error_msg = d['error'] - self._invalid = True - if self.store is not None: - self.store(self) - else: - logging.warning( - "Unable to store refreshed credentials, no Storage provided.") + self.invalid = True + if self.store: + self.store.locked_put(self) except: pass raise AccessTokenRefreshError(error_msg) @@ -269,13 +354,11 @@ class OAuth2Credentials(Credentials): h = httplib2.Http() h = credentials.authorize(h) - You can't create a new OAuth - subclass of httplib2.Authenication because - it never gets passed the absolute URI, which is - needed for signing. So instead we have to overload - 'request' with a closure that adds in the - Authorization header and then calls the original version - of 'request()'. + You can't create a new OAuth subclass of httplib2.Authenication + because it never gets passed the absolute URI, which is needed for + signing. So instead we have to overload 'request' with a closure + that adds in the Authorization header and then calls the original + version of 'request()'. """ request_orig = http.request @@ -284,12 +367,12 @@ class OAuth2Credentials(Credentials): redirections=httplib2.DEFAULT_MAX_REDIRECTS, connection_type=None): if not self.access_token: - logging.info("Attempting refresh to obtain initial access_token") + logger.info('Attempting refresh to obtain initial access_token') self._refresh(request_orig) - """Modify the request headers to add the appropriate - Authorization header.""" - if headers == None: + # Modify the request headers to add the appropriate + # Authorization header. + if headers is None: headers = {} headers['authorization'] = 'OAuth ' + self.access_token @@ -303,7 +386,7 @@ class OAuth2Credentials(Credentials): redirections, connection_type) if resp.status == 401: - logging.info("Refreshing because we got a 401") + logger.info('Refreshing due to a 401') self._refresh(request_orig) headers['authorization'] = 'OAuth ' + self.access_token return request_orig(uri, method, body, headers, @@ -316,14 +399,15 @@ class OAuth2Credentials(Credentials): class AccessTokenCredentials(OAuth2Credentials): - """Credentials object for OAuth 2.0 + """Credentials object for OAuth 2.0. - Credentials can be applied to an httplib2.Http object using the authorize() - method, which then signs each request from that object with the OAuth 2.0 - access token. This set of credentials is for the use case where you have - acquired an OAuth 2.0 access_token from another place such as a JavaScript - client or another web application, and wish to use it from Python. Because - only the access_token is present it can not be refreshed and will in time + Credentials can be applied to an httplib2.Http object using the + authorize() method, which then signs each request from that object + with the OAuth 2.0 access token. This set of credentials is for the + use case where you have acquired an OAuth 2.0 access_token from + another place such as a JavaScript client or another web + application, and wish to use it from Python. Because only the + access_token is present it can not be refreshed and will in time expire. AccessTokenCredentials objects may be safely pickled and unpickled. @@ -368,19 +452,20 @@ class AccessTokenCredentials(OAuth2Credentials): class AssertionCredentials(OAuth2Credentials): - """Abstract Credentials object used for OAuth 2.0 assertion grants + """Abstract Credentials object used for OAuth 2.0 assertion grants. - This credential does not require a flow to instantiate because it represents - a two legged flow, and therefore has all of the required information to - generate and refresh its own access tokens. It must be subclassed to - generate the appropriate assertion string. + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. It must + be subclassed to generate the appropriate assertion string. AssertionCredentials objects may be safely pickled and unpickled. """ def __init__(self, assertion_type, user_agent, - token_uri='https://accounts.google.com/o/oauth2/token', **kwargs): - """Constructor for AssertionFlowCredentials + token_uri='https://accounts.google.com/o/oauth2/token', + **unused_kwargs): + """Constructor for AssertionFlowCredentials. Args: assertion_type: string, assertion type that will be declared to the auth @@ -403,10 +488,10 @@ class AssertionCredentials(OAuth2Credentials): assertion = self._generate_assertion() body = urllib.urlencode({ - 'assertion_type': self.assertion_type, - 'assertion': assertion, - 'grant_type': "assertion", - }) + 'assertion_type': self.assertion_type, + 'assertion': assertion, + 'grant_type': 'assertion', + }) return body @@ -424,10 +509,10 @@ class OAuth2WebServerFlow(Flow): """ def __init__(self, client_id, client_secret, scope, user_agent, - auth_uri='https://accounts.google.com/o/oauth2/auth', - token_uri='https://accounts.google.com/o/oauth2/token', - **kwargs): - """Constructor for OAuth2WebServerFlow + auth_uri='https://accounts.google.com/o/oauth2/auth', + token_uri='https://accounts.google.com/o/oauth2/token', + **kwargs): + """Constructor for OAuth2WebServerFlow. Args: client_id: string, client identifier. @@ -466,11 +551,11 @@ class OAuth2WebServerFlow(Flow): self.redirect_uri = redirect_uri query = { - 'response_type': 'code', - 'client_id': self.client_id, - 'redirect_uri': redirect_uri, - 'scope': self.scope, - } + 'response_type': 'code', + 'client_id': self.client_id, + 'redirect_uri': redirect_uri, + 'scope': self.scope, + } query.update(self.params) parts = list(urlparse.urlparse(self.auth_uri)) query.update(dict(parse_qsl(parts[4]))) # 4 is the index of the query part @@ -491,15 +576,16 @@ class OAuth2WebServerFlow(Flow): code = code['code'] body = urllib.urlencode({ - 'grant_type': 'authorization_code', - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'redirect_uri': self.redirect_uri, - 'scope': self.scope, - }) + 'grant_type': 'authorization_code', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'code': code, + 'redirect_uri': self.redirect_uri, + 'scope': self.scope, + }) headers = { - 'content-type': 'application/x-www-form-urlencoded', + 'user-agent': self.user_agent, + 'content-type': 'application/x-www-form-urlencoded', } if self.user_agent is not None: @@ -519,12 +605,12 @@ class OAuth2WebServerFlow(Flow): token_expiry = datetime.datetime.now() + datetime.timedelta( seconds=int(d['expires_in'])) - logging.info('Successfully retrieved access token: %s' % content) + logger.info('Successfully retrieved access token: %s' % content) return OAuth2Credentials(access_token, self.client_id, self.client_secret, refresh_token, token_expiry, self.token_uri, self.user_agent) else: - logging.error('Failed to retrieve access token: %s' % content) + logger.error('Failed to retrieve access token: %s' % content) error_msg = 'Invalid response %s.' % resp['status'] try: d = simplejson.loads(content) diff --git a/oauth2client/django_orm.py b/oauth2client/django_orm.py index c818ea2..581fe8e 100644 --- a/oauth2client/django_orm.py +++ b/oauth2client/django_orm.py @@ -99,7 +99,7 @@ class Storage(BaseStorage): if len(entities) > 0: credential = getattr(entities[0], self.property_name) if credential and hasattr(credential, 'set_store'): - credential.set_store(self.put) + credential.set_store(self) return credential def put(self, credentials): diff --git a/oauth2client/file.py b/oauth2client/file.py index da666c4..b7f9c7d 100644 --- a/oauth2client/file.py +++ b/oauth2client/file.py @@ -44,7 +44,7 @@ class Storage(BaseStorage): f = open(self._filename, 'r') credentials = pickle.loads(f.read()) f.close() - credentials.set_store(self.put) + credentials.set_store(self) except: credentials = None self._lock.release() diff --git a/oauth2client/multistore_file.py b/oauth2client/multistore_file.py new file mode 100644 index 0000000..8841194 --- /dev/null +++ b/oauth2client/multistore_file.py @@ -0,0 +1,361 @@ +# Copyright 2011 Google Inc. All Rights Reserved. + +"""Multi-credential file store with lock support. + +This module implements a JSON credential store where multiple +credentials can be stored in one file. That file supports locking +both in a single process and across processes. + +The credential themselves are keyed off of: +* client_id +* user_agent +* scope + +The format of the stored data is like so: +{ + 'file_version': 1, + 'data': [ + { + 'key': { + 'clientId': '', + 'userAgent': '', + 'scope': '' + }, + 'credential': '' + } + ] +} +""" + +__author__ = 'jbeda@google.com (Joe Beda)' + +import base64 +import fcntl +import logging +import os +import pickle +import threading + +try: # pragma: no cover + import simplejson +except ImportError: # pragma: no cover + try: + # Try to import from django, should work on App Engine + from django.utils import simplejson + except ImportError: + # Should work for Python2.6 and higher. + import json as simplejson + +from client import Storage as BaseStorage + +logger = logging.getLogger(__name__) + +# A dict from 'filename'->_MultiStore instances +_multistores = {} +_multistores_lock = threading.Lock() + + +class Error(Exception): + """Base error for this module.""" + pass + + +class NewerCredentialStoreError(Error): + """The credential store is a newer version that supported.""" + pass + + +def get_credential_storage(filename, client_id, user_agent, scope, + warn_on_readonly=True): + """Get a Storage instance for a credential. + + Args: + filename: The JSON file storing a set of credentials + client_id: The client_id for the credential + user_agent: The user agent for the credential + scope: A string for the scope being requested + warn_on_readonly: if True, log a warning if the store is readonly + + Returns: + An object derived from client.Storage for getting/setting the + credential. + """ + filename = os.path.realpath(os.path.expanduser(filename)) + _multistores_lock.acquire() + try: + multistore = _multistores.setdefault( + filename, _MultiStore(filename, warn_on_readonly)) + finally: + _multistores_lock.release() + return multistore._get_storage(client_id, user_agent, scope) + + +class _MultiStore(object): + """A file backed store for multiple credentials.""" + + def __init__(self, filename, warn_on_readonly=True): + """Initialize the class. + + This will create the file if necessary. + """ + self._filename = filename + self._thread_lock = threading.Lock() + self._file_handle = None + self._read_only = False + self._warn_on_readonly = warn_on_readonly + + self._create_file_if_needed() + + # Cache of deserialized store. This is only valid after the + # _MultiStore is locked or _refresh_data_cache is called. This is + # of the form of: + # + # (client_id, user_agent, scope) -> OAuth2Credential + # + # If this is None, then the store hasn't been read yet. + self._data = None + + class _Storage(BaseStorage): + """A Storage object that knows how to read/write a single credential.""" + + def __init__(self, multistore, client_id, user_agent, scope): + self._multistore = multistore + self._client_id = client_id + self._user_agent = user_agent + self._scope = scope + + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. + + This lock is not reentrant. + """ + self._multistore._lock() + + def release_lock(self): + """Release the Storage lock. + + Trying to release a lock that isn't held will result in a + RuntimeError. + """ + self._multistore._unlock() + + def locked_get(self): + """Retrieve credential. + + The Storage lock must be held when this is called. + + Returns: + oauth2client.client.Credentials + """ + credential = self._multistore._get_credential( + self._client_id, self._user_agent, self._scope) + if credential: + credential.set_store(self) + return credential + + def locked_put(self, credentials): + """Write a credential. + + The Storage lock must be held when this is called. + + Args: + credentials: Credentials, the credentials to store. + """ + self._multistore._update_credential(credentials, self._scope) + + def _create_file_if_needed(self): + """Create an empty file if necessary. + + This method will not initialize the file. Instead it implements a + simple version of "touch" to ensure the file has been created. + """ + if not os.path.exists(self._filename): + old_umask = os.umask(0177) + try: + open(self._filename, 'a+').close() + finally: + os.umask(old_umask) + + def _lock(self): + """Lock the entire multistore.""" + self._thread_lock.acquire() + # Check to see if the file is writeable. + if os.access(self._filename, os.W_OK): + self._file_handle = open(self._filename, 'r+') + fcntl.lockf(self._file_handle.fileno(), fcntl.LOCK_EX) + else: + # Cannot open in read/write mode. Open only in read mode. + self._file_handle = open(self._filename, 'r') + self._read_only = True + if self._warn_on_readonly: + logger.warn('The credentials file (%s) is not writable. Opening in ' + 'read-only mode. Any refreshed credentials will only be ' + 'valid for this run.' % self._filename) + if os.path.getsize(self._filename) == 0: + logger.debug('Initializing empty multistore file') + # The multistore is empty so write out an empty file. + self._data = {} + self._write() + elif not self._read_only or self._data is None: + # Only refresh the data if we are read/write or we haven't + # cached the data yet. If we are readonly, we assume is isn't + # changing out from under us and that we only have to read it + # once. This prevents us from whacking any new access keys that + # we have cached in memory but were unable to write out. + self._refresh_data_cache() + + def _unlock(self): + """Release the lock on the multistore.""" + if not self._read_only: + fcntl.lockf(self._file_handle.fileno(), fcntl.LOCK_UN) + self._file_handle.close() + self._thread_lock.release() + + def _locked_json_read(self): + """Get the raw content of the multistore file. + + The multistore must be locked when this is called. + + Returns: + The contents of the multistore decoded as JSON. + """ + assert self._thread_lock.locked() + self._file_handle.seek(0) + return simplejson.load(self._file_handle) + + def _locked_json_write(self, data): + """Write a JSON serializable data structure to the multistore. + + The multistore must be locked when this is called. + + Args: + data: The data to be serialized and written. + """ + assert self._thread_lock.locked() + if self._read_only: + return + self._file_handle.seek(0) + simplejson.dump(data, self._file_handle, sort_keys=True, indent=2) + self._file_handle.truncate() + + def _refresh_data_cache(self): + """Refresh the contents of the multistore. + + The multistore must be locked when this is called. + + Raises: + NewerCredentialStoreError: Raised when a newer client has written the + store. + """ + self._data = {} + try: + raw_data = self._locked_json_read() + except Exception: + logger.warn('Credential data store could not be loaded. ' + 'Will ignore and overwrite.') + return + + version = 0 + try: + version = raw_data['file_version'] + except Exception: + logger.warn('Missing version for credential data store. It may be ' + 'corrupt or an old version. Overwriting.') + if version > 1: + raise NewerCredentialStoreError( + 'Credential file has file_version of %d. ' + 'Only file_version of 1 is supported.' % version) + + credentials = [] + try: + credentials = raw_data['data'] + except (TypeError, KeyError): + pass + + for cred_entry in credentials: + try: + (key, credential) = self._decode_credential_from_json(cred_entry) + self._data[key] = credential + except: + # If something goes wrong loading a credential, just ignore it + logger.info('Error decoding credential, skipping', exc_info=True) + + def _decode_credential_from_json(self, cred_entry): + """Load a credential from our JSON serialization. + + Args: + cred_entry: A dict entry from the data member of our format + + Returns: + (key, cred) where the key is the key tuple and the cred is the + OAuth2Credential object. + """ + raw_key = cred_entry['key'] + client_id = raw_key['clientId'] + user_agent = raw_key['userAgent'] + scope = raw_key['scope'] + key = (client_id, user_agent, scope) + credential = pickle.loads(base64.b64decode(cred_entry['credential'])) + return (key, credential) + + def _write(self): + """Write the cached data back out. + + The multistore must be locked. + """ + raw_data = {'file_version': 1} + raw_creds = [] + raw_data['data'] = raw_creds + for (cred_key, cred) in self._data.items(): + raw_key = { + 'clientId': cred_key[0], + 'userAgent': cred_key[1], + 'scope': cred_key[2] + } + raw_cred = base64.b64encode(pickle.dumps(cred)) + raw_creds.append({'key': raw_key, 'credential': raw_cred}) + self._locked_json_write(raw_data) + + def _get_credential(self, client_id, user_agent, scope): + """Get a credential from the multistore. + + The multistore must be locked. + + Args: + client_id: The client_id for the credential + user_agent: The user agent for the credential + scope: A string for the scope being requested + + Returns: + The credential specified or None if not present + """ + key = (client_id, user_agent, scope) + return self._data.get(key, None) + + def _update_credential(self, cred, scope): + """Update a credential and write the multistore. + + This must be called when the multistore is locked. + + Args: + cred: The OAuth2Credential to update/set + scope: The scope that this credential covers + """ + key = (cred.client_id, cred.user_agent, scope) + self._data[key] = cred + self._write() + + def _get_storage(self, client_id, user_agent, scope): + """Get a Storage object to get/set a credential. + + This Storage is a 'view' into the multistore. + + Args: + client_id: The client_id for the credential + user_agent: The user agent for the credential + scope: A string for the scope being requested + + Returns: + A Storage object that can be used to get/set this cred + """ + return self._Storage(self, client_id, user_agent, scope) diff --git a/oauth2client/tools.py b/oauth2client/tools.py index f04d4c8..dc779b4 100644 --- a/oauth2client/tools.py +++ b/oauth2client/tools.py @@ -25,31 +25,30 @@ __all__ = ['run'] import BaseHTTPServer import gflags -import logging import socket import sys from client import FlowExchangeError try: - from urlparse import parse_qsl + from urlparse import parse_qsl except ImportError: - from cgi import parse_qsl + from cgi import parse_qsl FLAGS = gflags.FLAGS gflags.DEFINE_boolean('auth_local_webserver', True, - ('Run a local web server to handle redirects during ' + ('Run a local web server to handle redirects during ' 'OAuth authorization.')) gflags.DEFINE_string('auth_host_name', 'localhost', ('Host name to use when running a local web server to ' - 'handle redirects during OAuth authorization.')) + 'handle redirects during OAuth authorization.')) gflags.DEFINE_multi_int('auth_host_port', [8080, 8090], - ('Port to use when running a local web server to ' - 'handle redirects during OAuth authorization.')) + ('Port to use when running a local web server to ' + 'handle redirects during OAuth authorization.')) class ClientRedirectServer(BaseHTTPServer.HTTPServer): @@ -69,7 +68,7 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): """ def do_GET(s): - """Handle a GET request + """Handle a GET request. Parses the query parameters and prints a message if the flow has completed. Note that we can't detect @@ -106,8 +105,8 @@ def run(flow, storage): for port in FLAGS.auth_host_port: port_number = port try: - httpd = BaseHTTPServer.HTTPServer((FLAGS.auth_host_name, port), - ClientRedirectHandler) + httpd = ClientRedirectServer((FLAGS.auth_host_name, port), + ClientRedirectHandler) except socket.error, e: pass else: @@ -126,10 +125,10 @@ def run(flow, storage): print if FLAGS.auth_local_webserver: print 'If your browser is on a different machine then exit and re-run this' - print 'application with the command-line parameter --noauth_local_webserver.' + print 'application with the command-line parameter ' + print '--noauth_local_webserver.' print - if FLAGS.auth_local_webserver: httpd.handle_request() if 'error' in httpd.query_params: @@ -137,18 +136,15 @@ def run(flow, storage): if 'code' in httpd.query_params: code = httpd.query_params['code'] else: - accepted = 'n' - while accepted.lower() == 'n': - accepted = raw_input('Have you authorized me? (y/n) ') - code = raw_input('What is the verification code? ').strip() + code = raw_input('Enter verification code: ').strip() try: - credentials = flow.step2_exchange(code) - except FlowExchangeError: - sys.exit('The authentication has failed.') + credential = flow.step2_exchange(code) + except FlowExchangeError, e: + sys.exit('Authentication has failed: %s' % e) - storage.put(credentials) - credentials.set_store(storage.put) - print "You have successfully authenticated." + storage.put(credential) + credential.set_store(storage) + print 'Authentication successful.' - return credentials + return credential diff --git a/tests/test_oauth2client_appengine.py b/tests/test_oauth2client_appengine.py index fd06668..f9b5094 100644 --- a/tests/test_oauth2client_appengine.py +++ b/tests/test_oauth2client_appengine.py @@ -191,8 +191,8 @@ class DecoratorTests(unittest.TestCase): self.decorator.credentials.access_token) # Invalidate the stored Credentials - self.decorator.credentials._invalid = True - self.decorator.credentials.store(self.decorator.credentials) + self.decorator.credentials.invalid = True + self.decorator.credentials.store.put(self.decorator.credentials) # Invalid Credentials should start the OAuth dance again response = self.app.get('/foo_path')