Add fancy locking to oauth2client.
Reviewed in http://codereview.appspot.com/4919049/
This commit is contained in:
@@ -23,7 +23,6 @@ import httplib2
|
||||
import pickle
|
||||
import time
|
||||
import base64
|
||||
import logging
|
||||
|
||||
try: # pragma: no cover
|
||||
import simplejson
|
||||
@@ -222,7 +221,7 @@ class StorageByKeyName(Storage):
|
||||
entity = self._model.get_or_insert(self._key_name)
|
||||
credential = getattr(entity, self._property_name)
|
||||
if credential and hasattr(credential, 'set_store'):
|
||||
credential.set_store(self.put)
|
||||
credential.set_store(self)
|
||||
if self._cache:
|
||||
self._cache.set(self._key_name, pickle.dumps(credentials))
|
||||
|
||||
|
||||
@@ -12,10 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An OAuth 2.0 client
|
||||
"""An OAuth 2.0 client.
|
||||
|
||||
Tools for interacting with OAuth 2.0 protected
|
||||
resources.
|
||||
Tools for interacting with OAuth 2.0 protected resources.
|
||||
"""
|
||||
|
||||
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
||||
@@ -27,9 +26,9 @@ import logging
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
try: # pragma: no cover
|
||||
try: # pragma: no cover
|
||||
import simplejson
|
||||
except ImportError: # pragma: no cover
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
# Try to import from django, should work on App Engine
|
||||
from django.utils import simplejson
|
||||
@@ -38,9 +37,11 @@ except ImportError: # pragma: no cover
|
||||
import json as simplejson
|
||||
|
||||
try:
|
||||
from urlparse import parse_qsl
|
||||
from urlparse import parse_qsl
|
||||
except ImportError:
|
||||
from cgi import parse_qsl
|
||||
from cgi import parse_qsl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
@@ -92,28 +93,76 @@ class Flow(object):
|
||||
class Storage(object):
|
||||
"""Base class for all Storage objects.
|
||||
|
||||
Store and retrieve a single credential.
|
||||
Store and retrieve a single credential. This class supports locking
|
||||
such that multiple processes and threads can operate on a single
|
||||
store.
|
||||
"""
|
||||
|
||||
def get(self):
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
|
||||
This lock is not reentrant."""
|
||||
pass
|
||||
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
|
||||
Trying to release a lock that isn't held will result in a
|
||||
RuntimeError.
|
||||
"""
|
||||
pass
|
||||
|
||||
def locked_get(self):
|
||||
"""Retrieve credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Returns:
|
||||
oauth2client.client.Credentials
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def put(self, credentials):
|
||||
def locked_put(self, credentials):
|
||||
"""Write a credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
_abstract()
|
||||
|
||||
def get(self):
|
||||
"""Retrieve credential.
|
||||
|
||||
The Storage lock must *not* be held when this is called.
|
||||
|
||||
Returns:
|
||||
oauth2client.client.Credentials
|
||||
"""
|
||||
self.acquire_lock()
|
||||
try:
|
||||
return self.locked_get()
|
||||
finally:
|
||||
self.release_lock()
|
||||
|
||||
def put(self, credentials):
|
||||
"""Write a credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
self.acquire_lock()
|
||||
try:
|
||||
self.locked_put(credentials)
|
||||
finally:
|
||||
self.release_lock()
|
||||
|
||||
|
||||
class OAuth2Credentials(Credentials):
|
||||
"""Credentials object for OAuth 2.0
|
||||
"""Credentials object for OAuth 2.0.
|
||||
|
||||
Credentials can be applied to an httplib2.Http object using the authorize()
|
||||
method, which then signs each request from that object with the OAuth 2.0
|
||||
@@ -123,22 +172,21 @@ class OAuth2Credentials(Credentials):
|
||||
"""
|
||||
|
||||
def __init__(self, access_token, client_id, client_secret, refresh_token,
|
||||
token_expiry, token_uri, user_agent):
|
||||
"""Create an instance of OAuth2Credentials
|
||||
token_expiry, token_uri, user_agent):
|
||||
"""Create an instance of OAuth2Credentials.
|
||||
|
||||
This constructor is not usually called by the user, instead
|
||||
OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow.
|
||||
|
||||
Args:
|
||||
token_uri: string, URI of token endpoint.
|
||||
access_token: string, access token.
|
||||
client_id: string, client identifier.
|
||||
client_secret: string, client secret.
|
||||
access_token: string, access token.
|
||||
token_expiry: datetime, when the access_token expires.
|
||||
refresh_token: string, refresh token.
|
||||
token_expiry: datetime, when the access_token expires.
|
||||
token_uri: string, URI of token endpoint.
|
||||
user_agent: string, The HTTP User-Agent to provide for this application.
|
||||
|
||||
|
||||
Notes:
|
||||
store: callable, a callable that when passed a Credential
|
||||
will store the credential back to where it came from.
|
||||
@@ -156,51 +204,66 @@ class OAuth2Credentials(Credentials):
|
||||
|
||||
# True if the credentials have been revoked or expired and can't be
|
||||
# refreshed.
|
||||
self._invalid = False
|
||||
self.invalid = False
|
||||
|
||||
@property
|
||||
def invalid(self):
|
||||
"""True if the credentials are invalid, such as being revoked."""
|
||||
return getattr(self, '_invalid', False)
|
||||
def access_token_expired(self):
|
||||
"""True if the credential is expired or invalid.
|
||||
|
||||
If the token_expiry isn't set, we assume the token doesn't expire.
|
||||
"""
|
||||
if self.invalid:
|
||||
return True
|
||||
|
||||
if not self.token_expiry:
|
||||
return False
|
||||
|
||||
now = datetime.datetime.now()
|
||||
if now >= self.token_expiry:
|
||||
logger.info('access_token is expired. Now: %s, token_expiry: %s',
|
||||
now, self.token_expiry)
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_store(self, store):
|
||||
"""Set the storage for the credential.
|
||||
"""Set the Storage for the credential.
|
||||
|
||||
Args:
|
||||
store: callable, a callable that when passed a Credential
|
||||
will store the credential back to where it came from.
|
||||
store: Storage, an implementation of Stroage object.
|
||||
This is needed to store the latest access_token if it
|
||||
has expired and been refreshed.
|
||||
has expired and been refreshed. This implementation uses
|
||||
locking to check for updates before updating the
|
||||
access_token.
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def _updateFromCredential(self, other):
|
||||
"""Update this Credential from another instance."""
|
||||
self.__dict__.update(other.__getstate__())
|
||||
|
||||
def __getstate__(self):
|
||||
"""Trim the state down to something that can be pickled.
|
||||
"""
|
||||
"""Trim the state down to something that can be pickled."""
|
||||
d = copy.copy(self.__dict__)
|
||||
del d['store']
|
||||
return d
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Reconstitute the state of the object from being pickled.
|
||||
"""
|
||||
"""Reconstitute the state of the object from being pickled."""
|
||||
self.__dict__.update(state)
|
||||
self.store = None
|
||||
|
||||
def _generate_refresh_request_body(self):
|
||||
"""Generate the body that will be used in the refresh request
|
||||
"""
|
||||
"""Generate the body that will be used in the refresh request."""
|
||||
body = urllib.urlencode({
|
||||
'grant_type': 'refresh_token',
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'refresh_token': self.refresh_token,
|
||||
})
|
||||
'grant_type': 'refresh_token',
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'refresh_token': self.refresh_token,
|
||||
})
|
||||
return body
|
||||
|
||||
def _generate_refresh_request_headers(self):
|
||||
"""Generate the headers that will be used in the refresh request
|
||||
"""
|
||||
"""Generate the headers that will be used in the refresh request."""
|
||||
headers = {
|
||||
'content-type': 'application/x-www-form-urlencoded',
|
||||
}
|
||||
@@ -211,16 +274,41 @@ class OAuth2Credentials(Credentials):
|
||||
return headers
|
||||
|
||||
def _refresh(self, http_request):
|
||||
"""Refreshes the access_token.
|
||||
|
||||
This method first checks by reading the Storage object if available.
|
||||
If a refresh is still needed, it holds the Storage lock until the
|
||||
refresh is completed.
|
||||
"""
|
||||
if not self.store:
|
||||
self._do_refresh_request(http_request)
|
||||
else:
|
||||
self.store.acquire_lock()
|
||||
try:
|
||||
new_cred = self.store.locked_get()
|
||||
if (new_cred and not new_cred.invalid and
|
||||
new_cred.access_token != self.access_token):
|
||||
logger.info('Updated access_token read from Storage')
|
||||
self._updateFromCredential(new_cred)
|
||||
else:
|
||||
self._do_refresh_request(http_request)
|
||||
finally:
|
||||
self.store.release_lock()
|
||||
|
||||
def _do_refresh_request(self, http_request):
|
||||
"""Refresh the access_token using the refresh_token.
|
||||
|
||||
Args:
|
||||
http: An instance of httplib2.Http.request
|
||||
or something that acts like it.
|
||||
|
||||
Raises:
|
||||
AccessTokenRefreshError: When the refresh fails.
|
||||
"""
|
||||
body = self._generate_refresh_request_body()
|
||||
headers = self._generate_refresh_request_headers()
|
||||
|
||||
logging.info("Refresing access_token")
|
||||
logger.info('Refresing access_token')
|
||||
resp, content = http_request(
|
||||
self.token_uri, method='POST', body=body, headers=headers)
|
||||
if resp.status == 200:
|
||||
@@ -233,23 +321,20 @@ class OAuth2Credentials(Credentials):
|
||||
seconds=int(d['expires_in'])) + datetime.datetime.now()
|
||||
else:
|
||||
self.token_expiry = None
|
||||
if self.store is not None:
|
||||
self.store(self)
|
||||
if self.store:
|
||||
self.store.locked_put(self)
|
||||
else:
|
||||
# An {'error':...} response body means the token is expired or revoked,
|
||||
# so we flag the credentials as such.
|
||||
logging.error('Failed to retrieve access token: %s' % content)
|
||||
logger.error('Failed to retrieve access token: %s' % content)
|
||||
error_msg = 'Invalid response %s.' % resp['status']
|
||||
try:
|
||||
d = simplejson.loads(content)
|
||||
if 'error' in d:
|
||||
error_msg = d['error']
|
||||
self._invalid = True
|
||||
if self.store is not None:
|
||||
self.store(self)
|
||||
else:
|
||||
logging.warning(
|
||||
"Unable to store refreshed credentials, no Storage provided.")
|
||||
self.invalid = True
|
||||
if self.store:
|
||||
self.store.locked_put(self)
|
||||
except:
|
||||
pass
|
||||
raise AccessTokenRefreshError(error_msg)
|
||||
@@ -269,13 +354,11 @@ class OAuth2Credentials(Credentials):
|
||||
h = httplib2.Http()
|
||||
h = credentials.authorize(h)
|
||||
|
||||
You can't create a new OAuth
|
||||
subclass of httplib2.Authenication because
|
||||
it never gets passed the absolute URI, which is
|
||||
needed for signing. So instead we have to overload
|
||||
'request' with a closure that adds in the
|
||||
Authorization header and then calls the original version
|
||||
of 'request()'.
|
||||
You can't create a new OAuth subclass of httplib2.Authenication
|
||||
because it never gets passed the absolute URI, which is needed for
|
||||
signing. So instead we have to overload 'request' with a closure
|
||||
that adds in the Authorization header and then calls the original
|
||||
version of 'request()'.
|
||||
"""
|
||||
request_orig = http.request
|
||||
|
||||
@@ -284,12 +367,12 @@ class OAuth2Credentials(Credentials):
|
||||
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
|
||||
connection_type=None):
|
||||
if not self.access_token:
|
||||
logging.info("Attempting refresh to obtain initial access_token")
|
||||
logger.info('Attempting refresh to obtain initial access_token')
|
||||
self._refresh(request_orig)
|
||||
|
||||
"""Modify the request headers to add the appropriate
|
||||
Authorization header."""
|
||||
if headers == None:
|
||||
# Modify the request headers to add the appropriate
|
||||
# Authorization header.
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers['authorization'] = 'OAuth ' + self.access_token
|
||||
|
||||
@@ -303,7 +386,7 @@ class OAuth2Credentials(Credentials):
|
||||
redirections, connection_type)
|
||||
|
||||
if resp.status == 401:
|
||||
logging.info("Refreshing because we got a 401")
|
||||
logger.info('Refreshing due to a 401')
|
||||
self._refresh(request_orig)
|
||||
headers['authorization'] = 'OAuth ' + self.access_token
|
||||
return request_orig(uri, method, body, headers,
|
||||
@@ -316,14 +399,15 @@ class OAuth2Credentials(Credentials):
|
||||
|
||||
|
||||
class AccessTokenCredentials(OAuth2Credentials):
|
||||
"""Credentials object for OAuth 2.0
|
||||
"""Credentials object for OAuth 2.0.
|
||||
|
||||
Credentials can be applied to an httplib2.Http object using the authorize()
|
||||
method, which then signs each request from that object with the OAuth 2.0
|
||||
access token. This set of credentials is for the use case where you have
|
||||
acquired an OAuth 2.0 access_token from another place such as a JavaScript
|
||||
client or another web application, and wish to use it from Python. Because
|
||||
only the access_token is present it can not be refreshed and will in time
|
||||
Credentials can be applied to an httplib2.Http object using the
|
||||
authorize() method, which then signs each request from that object
|
||||
with the OAuth 2.0 access token. This set of credentials is for the
|
||||
use case where you have acquired an OAuth 2.0 access_token from
|
||||
another place such as a JavaScript client or another web
|
||||
application, and wish to use it from Python. Because only the
|
||||
access_token is present it can not be refreshed and will in time
|
||||
expire.
|
||||
|
||||
AccessTokenCredentials objects may be safely pickled and unpickled.
|
||||
@@ -368,19 +452,20 @@ class AccessTokenCredentials(OAuth2Credentials):
|
||||
|
||||
|
||||
class AssertionCredentials(OAuth2Credentials):
|
||||
"""Abstract Credentials object used for OAuth 2.0 assertion grants
|
||||
"""Abstract Credentials object used for OAuth 2.0 assertion grants.
|
||||
|
||||
This credential does not require a flow to instantiate because it represents
|
||||
a two legged flow, and therefore has all of the required information to
|
||||
generate and refresh its own access tokens. It must be subclassed to
|
||||
generate the appropriate assertion string.
|
||||
This credential does not require a flow to instantiate because it
|
||||
represents a two legged flow, and therefore has all of the required
|
||||
information to generate and refresh its own access tokens. It must
|
||||
be subclassed to generate the appropriate assertion string.
|
||||
|
||||
AssertionCredentials objects may be safely pickled and unpickled.
|
||||
"""
|
||||
|
||||
def __init__(self, assertion_type, user_agent,
|
||||
token_uri='https://accounts.google.com/o/oauth2/token', **kwargs):
|
||||
"""Constructor for AssertionFlowCredentials
|
||||
token_uri='https://accounts.google.com/o/oauth2/token',
|
||||
**unused_kwargs):
|
||||
"""Constructor for AssertionFlowCredentials.
|
||||
|
||||
Args:
|
||||
assertion_type: string, assertion type that will be declared to the auth
|
||||
@@ -403,10 +488,10 @@ class AssertionCredentials(OAuth2Credentials):
|
||||
assertion = self._generate_assertion()
|
||||
|
||||
body = urllib.urlencode({
|
||||
'assertion_type': self.assertion_type,
|
||||
'assertion': assertion,
|
||||
'grant_type': "assertion",
|
||||
})
|
||||
'assertion_type': self.assertion_type,
|
||||
'assertion': assertion,
|
||||
'grant_type': 'assertion',
|
||||
})
|
||||
|
||||
return body
|
||||
|
||||
@@ -424,10 +509,10 @@ class OAuth2WebServerFlow(Flow):
|
||||
"""
|
||||
|
||||
def __init__(self, client_id, client_secret, scope, user_agent,
|
||||
auth_uri='https://accounts.google.com/o/oauth2/auth',
|
||||
token_uri='https://accounts.google.com/o/oauth2/token',
|
||||
**kwargs):
|
||||
"""Constructor for OAuth2WebServerFlow
|
||||
auth_uri='https://accounts.google.com/o/oauth2/auth',
|
||||
token_uri='https://accounts.google.com/o/oauth2/token',
|
||||
**kwargs):
|
||||
"""Constructor for OAuth2WebServerFlow.
|
||||
|
||||
Args:
|
||||
client_id: string, client identifier.
|
||||
@@ -466,11 +551,11 @@ class OAuth2WebServerFlow(Flow):
|
||||
|
||||
self.redirect_uri = redirect_uri
|
||||
query = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': self.scope,
|
||||
}
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': self.scope,
|
||||
}
|
||||
query.update(self.params)
|
||||
parts = list(urlparse.urlparse(self.auth_uri))
|
||||
query.update(dict(parse_qsl(parts[4]))) # 4 is the index of the query part
|
||||
@@ -491,15 +576,16 @@ class OAuth2WebServerFlow(Flow):
|
||||
code = code['code']
|
||||
|
||||
body = urllib.urlencode({
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'scope': self.scope,
|
||||
})
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'scope': self.scope,
|
||||
})
|
||||
headers = {
|
||||
'content-type': 'application/x-www-form-urlencoded',
|
||||
'user-agent': self.user_agent,
|
||||
'content-type': 'application/x-www-form-urlencoded',
|
||||
}
|
||||
|
||||
if self.user_agent is not None:
|
||||
@@ -519,12 +605,12 @@ class OAuth2WebServerFlow(Flow):
|
||||
token_expiry = datetime.datetime.now() + datetime.timedelta(
|
||||
seconds=int(d['expires_in']))
|
||||
|
||||
logging.info('Successfully retrieved access token: %s' % content)
|
||||
logger.info('Successfully retrieved access token: %s' % content)
|
||||
return OAuth2Credentials(access_token, self.client_id,
|
||||
self.client_secret, refresh_token, token_expiry,
|
||||
self.token_uri, self.user_agent)
|
||||
else:
|
||||
logging.error('Failed to retrieve access token: %s' % content)
|
||||
logger.error('Failed to retrieve access token: %s' % content)
|
||||
error_msg = 'Invalid response %s.' % resp['status']
|
||||
try:
|
||||
d = simplejson.loads(content)
|
||||
|
||||
@@ -99,7 +99,7 @@ class Storage(BaseStorage):
|
||||
if len(entities) > 0:
|
||||
credential = getattr(entities[0], self.property_name)
|
||||
if credential and hasattr(credential, 'set_store'):
|
||||
credential.set_store(self.put)
|
||||
credential.set_store(self)
|
||||
return credential
|
||||
|
||||
def put(self, credentials):
|
||||
|
||||
@@ -44,7 +44,7 @@ class Storage(BaseStorage):
|
||||
f = open(self._filename, 'r')
|
||||
credentials = pickle.loads(f.read())
|
||||
f.close()
|
||||
credentials.set_store(self.put)
|
||||
credentials.set_store(self)
|
||||
except:
|
||||
credentials = None
|
||||
self._lock.release()
|
||||
|
||||
361
oauth2client/multistore_file.py
Normal file
361
oauth2client/multistore_file.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# Copyright 2011 Google Inc. All Rights Reserved.
|
||||
|
||||
"""Multi-credential file store with lock support.
|
||||
|
||||
This module implements a JSON credential store where multiple
|
||||
credentials can be stored in one file. That file supports locking
|
||||
both in a single process and across processes.
|
||||
|
||||
The credential themselves are keyed off of:
|
||||
* client_id
|
||||
* user_agent
|
||||
* scope
|
||||
|
||||
The format of the stored data is like so:
|
||||
{
|
||||
'file_version': 1,
|
||||
'data': [
|
||||
{
|
||||
'key': {
|
||||
'clientId': '<client id>',
|
||||
'userAgent': '<user agent>',
|
||||
'scope': '<scope>'
|
||||
},
|
||||
'credential': '<base64 encoding of pickeled Credential object>'
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
__author__ = 'jbeda@google.com (Joe Beda)'
|
||||
|
||||
import base64
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import threading
|
||||
|
||||
try: # pragma: no cover
|
||||
import simplejson
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
# Try to import from django, should work on App Engine
|
||||
from django.utils import simplejson
|
||||
except ImportError:
|
||||
# Should work for Python2.6 and higher.
|
||||
import json as simplejson
|
||||
|
||||
from client import Storage as BaseStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A dict from 'filename'->_MultiStore instances
|
||||
_multistores = {}
|
||||
_multistores_lock = threading.Lock()
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
"""Base error for this module."""
|
||||
pass
|
||||
|
||||
|
||||
class NewerCredentialStoreError(Error):
|
||||
"""The credential store is a newer version that supported."""
|
||||
pass
|
||||
|
||||
|
||||
def get_credential_storage(filename, client_id, user_agent, scope,
|
||||
warn_on_readonly=True):
|
||||
"""Get a Storage instance for a credential.
|
||||
|
||||
Args:
|
||||
filename: The JSON file storing a set of credentials
|
||||
client_id: The client_id for the credential
|
||||
user_agent: The user agent for the credential
|
||||
scope: A string for the scope being requested
|
||||
warn_on_readonly: if True, log a warning if the store is readonly
|
||||
|
||||
Returns:
|
||||
An object derived from client.Storage for getting/setting the
|
||||
credential.
|
||||
"""
|
||||
filename = os.path.realpath(os.path.expanduser(filename))
|
||||
_multistores_lock.acquire()
|
||||
try:
|
||||
multistore = _multistores.setdefault(
|
||||
filename, _MultiStore(filename, warn_on_readonly))
|
||||
finally:
|
||||
_multistores_lock.release()
|
||||
return multistore._get_storage(client_id, user_agent, scope)
|
||||
|
||||
|
||||
class _MultiStore(object):
|
||||
"""A file backed store for multiple credentials."""
|
||||
|
||||
def __init__(self, filename, warn_on_readonly=True):
|
||||
"""Initialize the class.
|
||||
|
||||
This will create the file if necessary.
|
||||
"""
|
||||
self._filename = filename
|
||||
self._thread_lock = threading.Lock()
|
||||
self._file_handle = None
|
||||
self._read_only = False
|
||||
self._warn_on_readonly = warn_on_readonly
|
||||
|
||||
self._create_file_if_needed()
|
||||
|
||||
# Cache of deserialized store. This is only valid after the
|
||||
# _MultiStore is locked or _refresh_data_cache is called. This is
|
||||
# of the form of:
|
||||
#
|
||||
# (client_id, user_agent, scope) -> OAuth2Credential
|
||||
#
|
||||
# If this is None, then the store hasn't been read yet.
|
||||
self._data = None
|
||||
|
||||
class _Storage(BaseStorage):
|
||||
"""A Storage object that knows how to read/write a single credential."""
|
||||
|
||||
def __init__(self, multistore, client_id, user_agent, scope):
|
||||
self._multistore = multistore
|
||||
self._client_id = client_id
|
||||
self._user_agent = user_agent
|
||||
self._scope = scope
|
||||
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
|
||||
This lock is not reentrant.
|
||||
"""
|
||||
self._multistore._lock()
|
||||
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
|
||||
Trying to release a lock that isn't held will result in a
|
||||
RuntimeError.
|
||||
"""
|
||||
self._multistore._unlock()
|
||||
|
||||
def locked_get(self):
|
||||
"""Retrieve credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Returns:
|
||||
oauth2client.client.Credentials
|
||||
"""
|
||||
credential = self._multistore._get_credential(
|
||||
self._client_id, self._user_agent, self._scope)
|
||||
if credential:
|
||||
credential.set_store(self)
|
||||
return credential
|
||||
|
||||
def locked_put(self, credentials):
|
||||
"""Write a credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
self._multistore._update_credential(credentials, self._scope)
|
||||
|
||||
def _create_file_if_needed(self):
|
||||
"""Create an empty file if necessary.
|
||||
|
||||
This method will not initialize the file. Instead it implements a
|
||||
simple version of "touch" to ensure the file has been created.
|
||||
"""
|
||||
if not os.path.exists(self._filename):
|
||||
old_umask = os.umask(0177)
|
||||
try:
|
||||
open(self._filename, 'a+').close()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
def _lock(self):
|
||||
"""Lock the entire multistore."""
|
||||
self._thread_lock.acquire()
|
||||
# Check to see if the file is writeable.
|
||||
if os.access(self._filename, os.W_OK):
|
||||
self._file_handle = open(self._filename, 'r+')
|
||||
fcntl.lockf(self._file_handle.fileno(), fcntl.LOCK_EX)
|
||||
else:
|
||||
# Cannot open in read/write mode. Open only in read mode.
|
||||
self._file_handle = open(self._filename, 'r')
|
||||
self._read_only = True
|
||||
if self._warn_on_readonly:
|
||||
logger.warn('The credentials file (%s) is not writable. Opening in '
|
||||
'read-only mode. Any refreshed credentials will only be '
|
||||
'valid for this run.' % self._filename)
|
||||
if os.path.getsize(self._filename) == 0:
|
||||
logger.debug('Initializing empty multistore file')
|
||||
# The multistore is empty so write out an empty file.
|
||||
self._data = {}
|
||||
self._write()
|
||||
elif not self._read_only or self._data is None:
|
||||
# Only refresh the data if we are read/write or we haven't
|
||||
# cached the data yet. If we are readonly, we assume is isn't
|
||||
# changing out from under us and that we only have to read it
|
||||
# once. This prevents us from whacking any new access keys that
|
||||
# we have cached in memory but were unable to write out.
|
||||
self._refresh_data_cache()
|
||||
|
||||
def _unlock(self):
|
||||
"""Release the lock on the multistore."""
|
||||
if not self._read_only:
|
||||
fcntl.lockf(self._file_handle.fileno(), fcntl.LOCK_UN)
|
||||
self._file_handle.close()
|
||||
self._thread_lock.release()
|
||||
|
||||
def _locked_json_read(self):
|
||||
"""Get the raw content of the multistore file.
|
||||
|
||||
The multistore must be locked when this is called.
|
||||
|
||||
Returns:
|
||||
The contents of the multistore decoded as JSON.
|
||||
"""
|
||||
assert self._thread_lock.locked()
|
||||
self._file_handle.seek(0)
|
||||
return simplejson.load(self._file_handle)
|
||||
|
||||
def _locked_json_write(self, data):
|
||||
"""Write a JSON serializable data structure to the multistore.
|
||||
|
||||
The multistore must be locked when this is called.
|
||||
|
||||
Args:
|
||||
data: The data to be serialized and written.
|
||||
"""
|
||||
assert self._thread_lock.locked()
|
||||
if self._read_only:
|
||||
return
|
||||
self._file_handle.seek(0)
|
||||
simplejson.dump(data, self._file_handle, sort_keys=True, indent=2)
|
||||
self._file_handle.truncate()
|
||||
|
||||
def _refresh_data_cache(self):
|
||||
"""Refresh the contents of the multistore.
|
||||
|
||||
The multistore must be locked when this is called.
|
||||
|
||||
Raises:
|
||||
NewerCredentialStoreError: Raised when a newer client has written the
|
||||
store.
|
||||
"""
|
||||
self._data = {}
|
||||
try:
|
||||
raw_data = self._locked_json_read()
|
||||
except Exception:
|
||||
logger.warn('Credential data store could not be loaded. '
|
||||
'Will ignore and overwrite.')
|
||||
return
|
||||
|
||||
version = 0
|
||||
try:
|
||||
version = raw_data['file_version']
|
||||
except Exception:
|
||||
logger.warn('Missing version for credential data store. It may be '
|
||||
'corrupt or an old version. Overwriting.')
|
||||
if version > 1:
|
||||
raise NewerCredentialStoreError(
|
||||
'Credential file has file_version of %d. '
|
||||
'Only file_version of 1 is supported.' % version)
|
||||
|
||||
credentials = []
|
||||
try:
|
||||
credentials = raw_data['data']
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
|
||||
for cred_entry in credentials:
|
||||
try:
|
||||
(key, credential) = self._decode_credential_from_json(cred_entry)
|
||||
self._data[key] = credential
|
||||
except:
|
||||
# If something goes wrong loading a credential, just ignore it
|
||||
logger.info('Error decoding credential, skipping', exc_info=True)
|
||||
|
||||
def _decode_credential_from_json(self, cred_entry):
|
||||
"""Load a credential from our JSON serialization.
|
||||
|
||||
Args:
|
||||
cred_entry: A dict entry from the data member of our format
|
||||
|
||||
Returns:
|
||||
(key, cred) where the key is the key tuple and the cred is the
|
||||
OAuth2Credential object.
|
||||
"""
|
||||
raw_key = cred_entry['key']
|
||||
client_id = raw_key['clientId']
|
||||
user_agent = raw_key['userAgent']
|
||||
scope = raw_key['scope']
|
||||
key = (client_id, user_agent, scope)
|
||||
credential = pickle.loads(base64.b64decode(cred_entry['credential']))
|
||||
return (key, credential)
|
||||
|
||||
def _write(self):
|
||||
"""Write the cached data back out.
|
||||
|
||||
The multistore must be locked.
|
||||
"""
|
||||
raw_data = {'file_version': 1}
|
||||
raw_creds = []
|
||||
raw_data['data'] = raw_creds
|
||||
for (cred_key, cred) in self._data.items():
|
||||
raw_key = {
|
||||
'clientId': cred_key[0],
|
||||
'userAgent': cred_key[1],
|
||||
'scope': cred_key[2]
|
||||
}
|
||||
raw_cred = base64.b64encode(pickle.dumps(cred))
|
||||
raw_creds.append({'key': raw_key, 'credential': raw_cred})
|
||||
self._locked_json_write(raw_data)
|
||||
|
||||
def _get_credential(self, client_id, user_agent, scope):
|
||||
"""Get a credential from the multistore.
|
||||
|
||||
The multistore must be locked.
|
||||
|
||||
Args:
|
||||
client_id: The client_id for the credential
|
||||
user_agent: The user agent for the credential
|
||||
scope: A string for the scope being requested
|
||||
|
||||
Returns:
|
||||
The credential specified or None if not present
|
||||
"""
|
||||
key = (client_id, user_agent, scope)
|
||||
return self._data.get(key, None)
|
||||
|
||||
def _update_credential(self, cred, scope):
|
||||
"""Update a credential and write the multistore.
|
||||
|
||||
This must be called when the multistore is locked.
|
||||
|
||||
Args:
|
||||
cred: The OAuth2Credential to update/set
|
||||
scope: The scope that this credential covers
|
||||
"""
|
||||
key = (cred.client_id, cred.user_agent, scope)
|
||||
self._data[key] = cred
|
||||
self._write()
|
||||
|
||||
def _get_storage(self, client_id, user_agent, scope):
|
||||
"""Get a Storage object to get/set a credential.
|
||||
|
||||
This Storage is a 'view' into the multistore.
|
||||
|
||||
Args:
|
||||
client_id: The client_id for the credential
|
||||
user_agent: The user agent for the credential
|
||||
scope: A string for the scope being requested
|
||||
|
||||
Returns:
|
||||
A Storage object that can be used to get/set this cred
|
||||
"""
|
||||
return self._Storage(self, client_id, user_agent, scope)
|
||||
@@ -25,31 +25,30 @@ __all__ = ['run']
|
||||
|
||||
import BaseHTTPServer
|
||||
import gflags
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from client import FlowExchangeError
|
||||
|
||||
try:
|
||||
from urlparse import parse_qsl
|
||||
from urlparse import parse_qsl
|
||||
except ImportError:
|
||||
from cgi import parse_qsl
|
||||
from cgi import parse_qsl
|
||||
|
||||
|
||||
FLAGS = gflags.FLAGS
|
||||
|
||||
gflags.DEFINE_boolean('auth_local_webserver', True,
|
||||
('Run a local web server to handle redirects during '
|
||||
('Run a local web server to handle redirects during '
|
||||
'OAuth authorization.'))
|
||||
|
||||
gflags.DEFINE_string('auth_host_name', 'localhost',
|
||||
('Host name to use when running a local web server to '
|
||||
'handle redirects during OAuth authorization.'))
|
||||
'handle redirects during OAuth authorization.'))
|
||||
|
||||
gflags.DEFINE_multi_int('auth_host_port', [8080, 8090],
|
||||
('Port to use when running a local web server to '
|
||||
'handle redirects during OAuth authorization.'))
|
||||
('Port to use when running a local web server to '
|
||||
'handle redirects during OAuth authorization.'))
|
||||
|
||||
|
||||
class ClientRedirectServer(BaseHTTPServer.HTTPServer):
|
||||
@@ -69,7 +68,7 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
"""
|
||||
|
||||
def do_GET(s):
|
||||
"""Handle a GET request
|
||||
"""Handle a GET request.
|
||||
|
||||
Parses the query parameters and prints a message
|
||||
if the flow has completed. Note that we can't detect
|
||||
@@ -106,8 +105,8 @@ def run(flow, storage):
|
||||
for port in FLAGS.auth_host_port:
|
||||
port_number = port
|
||||
try:
|
||||
httpd = BaseHTTPServer.HTTPServer((FLAGS.auth_host_name, port),
|
||||
ClientRedirectHandler)
|
||||
httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
|
||||
ClientRedirectHandler)
|
||||
except socket.error, e:
|
||||
pass
|
||||
else:
|
||||
@@ -126,10 +125,10 @@ def run(flow, storage):
|
||||
print
|
||||
if FLAGS.auth_local_webserver:
|
||||
print 'If your browser is on a different machine then exit and re-run this'
|
||||
print 'application with the command-line parameter --noauth_local_webserver.'
|
||||
print 'application with the command-line parameter '
|
||||
print '--noauth_local_webserver.'
|
||||
print
|
||||
|
||||
|
||||
if FLAGS.auth_local_webserver:
|
||||
httpd.handle_request()
|
||||
if 'error' in httpd.query_params:
|
||||
@@ -137,18 +136,15 @@ def run(flow, storage):
|
||||
if 'code' in httpd.query_params:
|
||||
code = httpd.query_params['code']
|
||||
else:
|
||||
accepted = 'n'
|
||||
while accepted.lower() == 'n':
|
||||
accepted = raw_input('Have you authorized me? (y/n) ')
|
||||
code = raw_input('What is the verification code? ').strip()
|
||||
code = raw_input('Enter verification code: ').strip()
|
||||
|
||||
try:
|
||||
credentials = flow.step2_exchange(code)
|
||||
except FlowExchangeError:
|
||||
sys.exit('The authentication has failed.')
|
||||
credential = flow.step2_exchange(code)
|
||||
except FlowExchangeError, e:
|
||||
sys.exit('Authentication has failed: %s' % e)
|
||||
|
||||
storage.put(credentials)
|
||||
credentials.set_store(storage.put)
|
||||
print "You have successfully authenticated."
|
||||
storage.put(credential)
|
||||
credential.set_store(storage)
|
||||
print 'Authentication successful.'
|
||||
|
||||
return credentials
|
||||
return credential
|
||||
|
||||
@@ -191,8 +191,8 @@ class DecoratorTests(unittest.TestCase):
|
||||
self.decorator.credentials.access_token)
|
||||
|
||||
# Invalidate the stored Credentials
|
||||
self.decorator.credentials._invalid = True
|
||||
self.decorator.credentials.store(self.decorator.credentials)
|
||||
self.decorator.credentials.invalid = True
|
||||
self.decorator.credentials.store.put(self.decorator.credentials)
|
||||
|
||||
# Invalid Credentials should start the OAuth dance again
|
||||
response = self.app.get('/foo_path')
|
||||
|
||||
Reference in New Issue
Block a user