diff --git a/oauth2client/_helpers.py b/oauth2client/_helpers.py
index 2789cdf..89921d8 100644
--- a/oauth2client/_helpers.py
+++ b/oauth2client/_helpers.py
@@ -19,7 +19,7 @@ import six
def _parse_pem_key(raw_key_input):
- """Identify and extract PEM keys.
+ """Identify and extract PEM keys.
Determines whether the given key is in the format of PEM key, and extracts
the relevant part of the key if it is.
@@ -30,17 +30,17 @@ def _parse_pem_key(raw_key_input):
Returns:
string, The actual key if the contents are from a PEM file, or else None.
"""
- offset = raw_key_input.find(b'-----BEGIN ')
- if offset != -1:
- return raw_key_input[offset:]
+ offset = raw_key_input.find(b'-----BEGIN ')
+ if offset != -1:
+ return raw_key_input[offset:]
def _json_encode(data):
- return json.dumps(data, separators=(',', ':'))
+ return json.dumps(data, separators=(',', ':'))
def _to_bytes(value, encoding='ascii'):
- """Converts a string value to bytes, if necessary.
+ """Converts a string value to bytes, if necessary.
Unfortunately, ``six.b`` is insufficient for this task since in
Python2 it does not modify ``unicode`` objects.
@@ -60,16 +60,16 @@ def _to_bytes(value, encoding='ascii'):
Raises:
ValueError if the value could not be converted to bytes.
"""
- result = (value.encode(encoding)
+ result = (value.encode(encoding)
if isinstance(value, six.text_type) else value)
- if isinstance(result, six.binary_type):
- return result
- else:
- raise ValueError('%r could not be converted to bytes' % (value,))
+ if isinstance(result, six.binary_type):
+ return result
+ else:
+ raise ValueError('%r could not be converted to bytes' % (value, ))
def _from_bytes(value):
- """Converts bytes to a string value, if necessary.
+ """Converts bytes to a string value, if necessary.
Args:
value: The string/bytes value to be converted.
@@ -81,21 +81,21 @@ def _from_bytes(value):
Raises:
ValueError if the value could not be converted to unicode.
"""
- result = (value.decode('utf-8')
+ result = (value.decode('utf-8')
if isinstance(value, six.binary_type) else value)
- if isinstance(result, six.text_type):
- return result
- else:
- raise ValueError('%r could not be converted to unicode' % (value,))
+ if isinstance(result, six.text_type):
+ return result
+ else:
+ raise ValueError('%r could not be converted to unicode' % (value, ))
def _urlsafe_b64encode(raw_bytes):
- raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
- return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
+ raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
+ return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
def _urlsafe_b64decode(b64string):
- # Guard against unicode strings, which base64 can't handle.
- b64string = _to_bytes(b64string)
- padded = b64string + b'=' * (4 - len(b64string) % 4)
- return base64.urlsafe_b64decode(padded)
+ # Guard against unicode strings, which base64 can't handle.
+ b64string = _to_bytes(b64string)
+ padded = b64string + b'=' * (4 - len(b64string) % 4)
+ return base64.urlsafe_b64decode(padded)
diff --git a/oauth2client/_openssl_crypt.py b/oauth2client/_openssl_crypt.py
index 9fcd996..34f64ec 100644
--- a/oauth2client/_openssl_crypt.py
+++ b/oauth2client/_openssl_crypt.py
@@ -22,18 +22,18 @@ from oauth2client._helpers import _to_bytes
class OpenSSLVerifier(object):
- """Verifies the signature on a message."""
+ """Verifies the signature on a message."""
- def __init__(self, pubkey):
- """Constructor.
+ def __init__(self, pubkey):
+ """Constructor.
Args:
pubkey, OpenSSL.crypto.PKey, The public key to verify with.
"""
- self._pubkey = pubkey
+ self._pubkey = pubkey
- def verify(self, message, signature):
- """Verifies a message against a signature.
+ def verify(self, message, signature):
+ """Verifies a message against a signature.
Args:
message: string or bytes, The message to verify. If string, will be
@@ -45,17 +45,17 @@ class OpenSSLVerifier(object):
True if message was signed by the private key associated with the public
key that this object was constructed with.
"""
- message = _to_bytes(message, encoding='utf-8')
- signature = _to_bytes(signature, encoding='utf-8')
- try:
- crypto.verify(self._pubkey, signature, message, 'sha256')
- return True
- except crypto.Error:
- return False
+ message = _to_bytes(message, encoding='utf-8')
+ signature = _to_bytes(signature, encoding='utf-8')
+ try:
+ crypto.verify(self._pubkey, signature, message, 'sha256')
+ return True
+ except crypto.Error:
+ return False
- @staticmethod
+ @staticmethod
def from_string(key_pem, is_x509_cert):
- """Construct a Verified instance from a string.
+ """Construct a Verified instance from a string.
Args:
key_pem: string, public key in PEM format.
@@ -68,26 +68,26 @@ class OpenSSLVerifier(object):
Raises:
OpenSSL.crypto.Error if the key_pem can't be parsed.
"""
- if is_x509_cert:
- pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
- else:
- pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
- return OpenSSLVerifier(pubkey)
+ if is_x509_cert:
+ pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
+ else:
+ pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
+ return OpenSSLVerifier(pubkey)
class OpenSSLSigner(object):
- """Signs messages with a private key."""
+ """Signs messages with a private key."""
- def __init__(self, pkey):
- """Constructor.
+ def __init__(self, pkey):
+ """Constructor.
Args:
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
"""
- self._key = pkey
+ self._key = pkey
- def sign(self, message):
- """Signs a message.
+ def sign(self, message):
+ """Signs a message.
Args:
message: bytes, Message to be signed.
@@ -95,12 +95,12 @@ class OpenSSLSigner(object):
Returns:
string, The signature of the message for the given key.
"""
- message = _to_bytes(message, encoding='utf-8')
- return crypto.sign(self._key, message, 'sha256')
+ message = _to_bytes(message, encoding='utf-8')
+ return crypto.sign(self._key, message, 'sha256')
- @staticmethod
- def from_string(key, password=b'notasecret'):
- """Construct a Signer instance from a string.
+ @staticmethod
+ def from_string(key, password=b'notasecret'):
+ """Construct a Signer instance from a string.
Args:
key: string, private key in PKCS12 or PEM format.
@@ -112,17 +112,17 @@ class OpenSSLSigner(object):
Raises:
OpenSSL.crypto.Error if the key can't be parsed.
"""
- parsed_pem_key = _parse_pem_key(key)
- if parsed_pem_key:
- pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
- else:
- password = _to_bytes(password, encoding='utf-8')
- pkey = crypto.load_pkcs12(key, password).get_privatekey()
- return OpenSSLSigner(pkey)
+ parsed_pem_key = _parse_pem_key(key)
+ if parsed_pem_key:
+ pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
+ else:
+ password = _to_bytes(password, encoding='utf-8')
+ pkey = crypto.load_pkcs12(key, password).get_privatekey()
+ return OpenSSLSigner(pkey)
def pkcs12_key_as_pem(private_key_text, private_key_password):
- """Convert the contents of a PKCS12 key to PEM using OpenSSL.
+ """Convert the contents of a PKCS12 key to PEM using OpenSSL.
Args:
private_key_text: String. Private key.
@@ -131,9 +131,9 @@ def pkcs12_key_as_pem(private_key_text, private_key_password):
Returns:
String. PEM contents of ``private_key_text``.
"""
- decoded_body = base64.b64decode(private_key_text)
- private_key_password = _to_bytes(private_key_password)
+ decoded_body = base64.b64decode(private_key_text)
+ private_key_password = _to_bytes(private_key_password)
- pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
- return crypto.dump_privatekey(crypto.FILETYPE_PEM,
+ pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
+ return crypto.dump_privatekey(crypto.FILETYPE_PEM,
pkcs12.get_privatekey())
diff --git a/oauth2client/_pycrypto_crypt.py b/oauth2client/_pycrypto_crypt.py
index 9576438..0f13b0a 100644
--- a/oauth2client/_pycrypto_crypt.py
+++ b/oauth2client/_pycrypto_crypt.py
@@ -25,18 +25,18 @@ from oauth2client._helpers import _urlsafe_b64decode
class PyCryptoVerifier(object):
- """Verifies the signature on a message."""
+ """Verifies the signature on a message."""
- def __init__(self, pubkey):
- """Constructor.
+ def __init__(self, pubkey):
+ """Constructor.
Args:
pubkey, OpenSSL.crypto.PKey (or equiv), The public key to verify with.
"""
- self._pubkey = pubkey
+ self._pubkey = pubkey
- def verify(self, message, signature):
- """Verifies a message against a signature.
+ def verify(self, message, signature):
+ """Verifies a message against a signature.
Args:
message: string or bytes, The message to verify. If string, will be
@@ -47,13 +47,13 @@ class PyCryptoVerifier(object):
True if message was signed by the private key associated with the public
key that this object was constructed with.
"""
- message = _to_bytes(message, encoding='utf-8')
- return PKCS1_v1_5.new(self._pubkey).verify(
+ message = _to_bytes(message, encoding='utf-8')
+ return PKCS1_v1_5.new(self._pubkey).verify(
SHA256.new(message), signature)
- @staticmethod
- def from_string(key_pem, is_x509_cert):
- """Construct a Verified instance from a string.
+ @staticmethod
+ def from_string(key_pem, is_x509_cert):
+ """Construct a Verified instance from a string.
Args:
key_pem: string, public key in PEM format.
@@ -63,33 +63,33 @@ class PyCryptoVerifier(object):
Returns:
Verifier instance.
"""
- if is_x509_cert:
- key_pem = _to_bytes(key_pem)
- pemLines = key_pem.replace(b' ', b'').split()
- certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
- certSeq = DerSequence()
- certSeq.decode(certDer)
- tbsSeq = DerSequence()
- tbsSeq.decode(certSeq[0])
- pubkey = RSA.importKey(tbsSeq[6])
- else:
- pubkey = RSA.importKey(key_pem)
- return PyCryptoVerifier(pubkey)
+ if is_x509_cert:
+ key_pem = _to_bytes(key_pem)
+ pemLines = key_pem.replace(b' ', b'').split()
+ certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
+ certSeq = DerSequence()
+ certSeq.decode(certDer)
+ tbsSeq = DerSequence()
+ tbsSeq.decode(certSeq[0])
+ pubkey = RSA.importKey(tbsSeq[6])
+ else:
+ pubkey = RSA.importKey(key_pem)
+ return PyCryptoVerifier(pubkey)
class PyCryptoSigner(object):
- """Signs messages with a private key."""
+ """Signs messages with a private key."""
- def __init__(self, pkey):
- """Constructor.
+ def __init__(self, pkey):
+ """Constructor.
Args:
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
"""
- self._key = pkey
+ self._key = pkey
- def sign(self, message):
- """Signs a message.
+ def sign(self, message):
+ """Signs a message.
Args:
message: string, Message to be signed.
@@ -97,12 +97,12 @@ class PyCryptoSigner(object):
Returns:
string, The signature of the message for the given key.
"""
- message = _to_bytes(message, encoding='utf-8')
- return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
+ message = _to_bytes(message, encoding='utf-8')
+ return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
- @staticmethod
- def from_string(key, password='notasecret'):
- """Construct a Signer instance from a string.
+ @staticmethod
+ def from_string(key, password='notasecret'):
+ """Construct a Signer instance from a string.
Args:
key: string, private key in PEM format.
@@ -114,13 +114,13 @@ class PyCryptoSigner(object):
Raises:
NotImplementedError if the key isn't in PEM format.
"""
- parsed_pem_key = _parse_pem_key(key)
- if parsed_pem_key:
- pkey = RSA.importKey(parsed_pem_key)
- else:
- raise NotImplementedError(
+ parsed_pem_key = _parse_pem_key(key)
+ if parsed_pem_key:
+ pkey = RSA.importKey(parsed_pem_key)
+ else:
+ raise NotImplementedError(
'PKCS12 format is not supported by the PyCrypto library. '
'Try converting to a "PEM" '
'(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) '
'or using PyOpenSSL if native code is an option.')
- return PyCryptoSigner(pkey)
+ return PyCryptoSigner(pkey)
diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py
index 7888a0d..e434675 100644
--- a/oauth2client/appengine.py
+++ b/oauth2client/appengine.py
@@ -51,10 +51,9 @@ from oauth2client.client import Storage
# TODO(dhermes): Resolve import issue.
# This is a temporary fix for a Google internal issue.
try:
- from google.appengine.ext import ndb
+ from google.appengine.ext import ndb
except ImportError:
- ndb = None
-
+ ndb = None
logger = logging.getLogger(__name__)
@@ -64,7 +63,7 @@ XSRF_MEMCACHE_ID = 'xsrf_secret_key'
def _safe_html(s):
- """Escape text to make it safe to display.
+ """Escape text to make it safe to display.
Args:
s: string, The text to escape.
@@ -72,28 +71,28 @@ def _safe_html(s):
Returns:
The escaped text as a string.
"""
- return cgi.escape(s, quote=1).replace("'", ''')
+ return cgi.escape(s, quote=1).replace("'", ''')
class InvalidClientSecretsError(Exception):
- """The client_secrets.json file is malformed or missing required fields."""
+ """The client_secrets.json file is malformed or missing required fields."""
class InvalidXsrfTokenError(Exception):
- """The XSRF token is invalid or expired."""
+ """The XSRF token is invalid or expired."""
class SiteXsrfSecretKey(db.Model):
- """Storage for the sites XSRF secret key.
+ """Storage for the sites XSRF secret key.
There will only be one instance stored of this model, the one used for the
site.
"""
- secret = db.StringProperty()
+ secret = db.StringProperty()
if ndb is not None:
- class SiteXsrfSecretKeyNDB(ndb.Model):
- """NDB Model for storage for the sites XSRF secret key.
+ class SiteXsrfSecretKeyNDB(ndb.Model):
+ """NDB Model for storage for the sites XSRF secret key.
Since this model uses the same kind as SiteXsrfSecretKey, it can be used
interchangeably. This simply provides an NDB model for interacting with the
@@ -102,22 +101,22 @@ if ndb is not None:
There should only be one instance stored of this model, the one used for the
site.
"""
- secret = ndb.StringProperty()
+ secret = ndb.StringProperty()
- @classmethod
- def _get_kind(cls):
- """Return the kind name for this class."""
- return 'SiteXsrfSecretKey'
+ @classmethod
+ def _get_kind(cls):
+ """Return the kind name for this class."""
+ return 'SiteXsrfSecretKey'
def _generate_new_xsrf_secret_key():
- """Returns a random XSRF secret key.
+ """Returns a random XSRF secret key.
"""
- return os.urandom(16).encode("hex")
+ return os.urandom(16).encode("hex")
def xsrf_secret_key():
- """Return the secret key for use for XSRF protection.
+ """Return the secret key for use for XSRF protection.
If the Site entity does not have a secret key, this method will also create
one and persist it.
@@ -125,21 +124,21 @@ def xsrf_secret_key():
Returns:
The secret key.
"""
- secret = memcache.get(XSRF_MEMCACHE_ID, namespace=OAUTH2CLIENT_NAMESPACE)
- if not secret:
- # Load the one and only instance of SiteXsrfSecretKey.
- model = SiteXsrfSecretKey.get_or_insert(key_name='site')
- if not model.secret:
- model.secret = _generate_new_xsrf_secret_key()
- model.put()
- secret = model.secret
- memcache.add(XSRF_MEMCACHE_ID, secret, namespace=OAUTH2CLIENT_NAMESPACE)
+ secret = memcache.get(XSRF_MEMCACHE_ID, namespace=OAUTH2CLIENT_NAMESPACE)
+ if not secret:
+ # Load the one and only instance of SiteXsrfSecretKey.
+ model = SiteXsrfSecretKey.get_or_insert(key_name='site')
+ if not model.secret:
+ model.secret = _generate_new_xsrf_secret_key()
+ model.put()
+ secret = model.secret
+ memcache.add(XSRF_MEMCACHE_ID, secret, namespace=OAUTH2CLIENT_NAMESPACE)
- return str(secret)
+ return str(secret)
class AppAssertionCredentials(AssertionCredentials):
- """Credentials object for App Engine Assertion Grants
+ """Credentials object for App Engine Assertion Grants
This object will allow an App Engine application to identify itself to Google
and other OAuth 2.0 servers that can verify assertions. It can be used for the
@@ -151,9 +150,9 @@ class AppAssertionCredentials(AssertionCredentials):
generate and refresh its own access tokens.
"""
- @util.positional(2)
- def __init__(self, scope, **kwargs):
- """Constructor for AppAssertionCredentials
+ @util.positional(2)
+ def __init__(self, scope, **kwargs):
+ """Constructor for AppAssertionCredentials
Args:
scope: string or iterable of strings, scope(s) of the credentials being
@@ -162,20 +161,20 @@ class AppAssertionCredentials(AssertionCredentials):
service_account_id: service account id of the application. If None or
unspecified, the default service account for the app is used.
"""
- self.scope = util.scopes_to_string(scope)
- self._kwargs = kwargs
- self.service_account_id = kwargs.get('service_account_id', None)
+ self.scope = util.scopes_to_string(scope)
+ self._kwargs = kwargs
+ self.service_account_id = kwargs.get('service_account_id', None)
- # Assertion type is no longer used, but still in the parent class signature.
- super(AppAssertionCredentials, self).__init__(None)
+ # Assertion type is no longer used, but still in the parent class signature.
+ super(AppAssertionCredentials, self).__init__(None)
- @classmethod
- def from_json(cls, json_data):
- data = json.loads(json_data)
- return AppAssertionCredentials(data['scope'])
+ @classmethod
+ def from_json(cls, json_data):
+ data = json.loads(json_data)
+ return AppAssertionCredentials(data['scope'])
- def _refresh(self, http_request):
- """Refreshes the access_token.
+ def _refresh(self, http_request):
+ """Refreshes the access_token.
Since the underlying App Engine app_identity implementation does its own
caching we can skip all the storage hoops and just to a refresh using the
@@ -188,60 +187,60 @@ class AppAssertionCredentials(AssertionCredentials):
Raises:
AccessTokenRefreshError: When the refresh fails.
"""
- try:
- scopes = self.scope.split()
- (token, _) = app_identity.get_access_token(
+ try:
+ scopes = self.scope.split()
+ (token, _) = app_identity.get_access_token(
scopes, service_account_id=self.service_account_id)
- except app_identity.Error as e:
- raise AccessTokenRefreshError(str(e))
- self.access_token = token
+ except app_identity.Error as e:
+ raise AccessTokenRefreshError(str(e))
+ self.access_token = token
- @property
- def serialization_data(self):
- raise NotImplementedError('Cannot serialize credentials for AppEngine.')
+ @property
+ def serialization_data(self):
+ raise NotImplementedError('Cannot serialize credentials for AppEngine.')
- def create_scoped_required(self):
- return not self.scope
+ def create_scoped_required(self):
+ return not self.scope
- def create_scoped(self, scopes):
- return AppAssertionCredentials(scopes, **self._kwargs)
+ def create_scoped(self, scopes):
+ return AppAssertionCredentials(scopes, **self._kwargs)
class FlowProperty(db.Property):
- """App Engine datastore Property for Flow.
+ """App Engine datastore Property for Flow.
Utility property that allows easy storage and retrieval of an
oauth2client.Flow"""
- # Tell what the user type is.
- data_type = Flow
+ # Tell what the user type is.
+ data_type = Flow
- # For writing to datastore.
- def get_value_for_datastore(self, model_instance):
- flow = super(FlowProperty,
+ # For writing to datastore.
+ def get_value_for_datastore(self, model_instance):
+ flow = super(FlowProperty,
self).get_value_for_datastore(model_instance)
- return db.Blob(pickle.dumps(flow))
+ return db.Blob(pickle.dumps(flow))
- # For reading from datastore.
- def make_value_from_datastore(self, value):
- if value is None:
- return None
- return pickle.loads(value)
+ # For reading from datastore.
+ def make_value_from_datastore(self, value):
+ if value is None:
+ return None
+ return pickle.loads(value)
- def validate(self, value):
- if value is not None and not isinstance(value, Flow):
- raise db.BadValueError('Property %s must be convertible '
+ def validate(self, value):
+ if value is not None and not isinstance(value, Flow):
+ raise db.BadValueError('Property %s must be convertible '
'to a FlowThreeLegged instance (%s)' %
(self.name, value))
- return super(FlowProperty, self).validate(value)
+ return super(FlowProperty, self).validate(value)
- def empty(self, value):
- return not value
+ def empty(self, value):
+ return not value
if ndb is not None:
- class FlowNDBProperty(ndb.PickleProperty):
- """App Engine NDB datastore Property for Flow.
+ class FlowNDBProperty(ndb.PickleProperty):
+ """App Engine NDB datastore Property for Flow.
Serves the same purpose as the DB FlowProperty, but for NDB models. Since
PickleProperty inherits from BlobProperty, the underlying representation of
@@ -251,8 +250,8 @@ if ndb is not None:
oauth2client.Flow
"""
- def _validate(self, value):
- """Validates a value as a proper Flow object.
+ def _validate(self, value):
+ """Validates a value as a proper Flow object.
Args:
value: A value to be set on the property.
@@ -260,64 +259,66 @@ if ndb is not None:
Raises:
TypeError if the value is not an instance of Flow.
"""
- logger.info('validate: Got type %s', type(value))
- if value is not None and not isinstance(value, Flow):
- raise TypeError('Property %s must be convertible to a flow '
+ logger.info('validate: Got type %s', type(value))
+ if value is not None and not isinstance(value, Flow):
+ raise TypeError('Property %s must be convertible to a flow '
'instance; received: %s.' % (self._name, value))
class CredentialsProperty(db.Property):
- """App Engine datastore Property for Credentials.
+ """App Engine datastore Property for Credentials.
Utility property that allows easy storage and retrieval of
oath2client.Credentials
"""
- # Tell what the user type is.
- data_type = Credentials
+ # Tell what the user type is.
+ data_type = Credentials
- # For writing to datastore.
- def get_value_for_datastore(self, model_instance):
- logger.info("get: Got type " + str(type(model_instance)))
- cred = super(CredentialsProperty,
+ # For writing to datastore.
+ def get_value_for_datastore(self, model_instance):
+ logger.info("get: Got type " + str(type(model_instance)))
+ cred = super(CredentialsProperty,
self).get_value_for_datastore(model_instance)
- if cred is None:
- cred = ''
- else:
- cred = cred.to_json()
- return db.Blob(cred)
+ if cred is None:
+ cred = ''
+ else:
+ cred = cred.to_json()
+ return db.Blob(cred)
- # For reading from datastore.
- def make_value_from_datastore(self, value):
- logger.info("make: Got type " + str(type(value)))
- if value is None:
- return None
- if len(value) == 0:
- return None
- try:
- credentials = Credentials.new_from_json(value)
- except ValueError:
- credentials = None
- return credentials
+ # For reading from datastore.
+ def make_value_from_datastore(self, value):
+ logger.info("make: Got type " + str(type(value)))
+ if value is None:
+ return None
+ if len(value) == 0:
+ return None
+ try:
+ credentials = Credentials.new_from_json(value)
+ except ValueError:
+ credentials = None
+ return credentials
- def validate(self, value):
- value = super(CredentialsProperty, self).validate(value)
- logger.info("validate: Got type " + str(type(value)))
- if value is not None and not isinstance(value, Credentials):
- raise db.BadValueError('Property %s must be convertible '
+ def validate(self, value):
+ value = super(CredentialsProperty, self).validate(value)
+ logger.info("validate: Got type " + str(type(value)))
+ if value is not None and not isinstance(value, Credentials):
+ raise db.BadValueError('Property %s must be convertible '
'to a Credentials instance (%s)' %
(self.name, value))
- #if value is not None and not isinstance(value, Credentials):
- # return None
- return value
+ #if value is not None and not isinstance(value, Credentials):
+ # return None
+ return value
if ndb is not None:
- # TODO(dhermes): Turn this into a JsonProperty and overhaul the Credentials
- # and subclass mechanics to use new_from_dict, to_dict,
- # from_dict, etc.
- class CredentialsNDBProperty(ndb.BlobProperty):
- """App Engine NDB datastore Property for Credentials.
+
+
+ # TODO(dhermes): Turn this into a JsonProperty and overhaul the Credentials
+ # and subclass mechanics to use new_from_dict, to_dict,
+ # from_dict, etc.
+ class CredentialsNDBProperty(ndb.BlobProperty):
+ """App Engine NDB datastore Property for Credentials.
Serves the same purpose as the DB CredentialsProperty, but for NDB models.
Since CredentialsProperty stores data as a blob and this inherits from
@@ -326,8 +327,9 @@ if ndb is not None:
Utility property that allows easy storage and retrieval of Credentials and
subclasses.
"""
- def _validate(self, value):
- """Validates a value as a proper credentials object.
+
+ def _validate(self, value):
+ """Validates a value as a proper credentials object.
Args:
value: A value to be set on the property.
@@ -335,13 +337,13 @@ if ndb is not None:
Raises:
TypeError if the value is not an instance of Credentials.
"""
- logger.info('validate: Got type %s', type(value))
- if value is not None and not isinstance(value, Credentials):
- raise TypeError('Property %s must be convertible to a credentials '
+ logger.info('validate: Got type %s', type(value))
+ if value is not None and not isinstance(value, Credentials):
+ raise TypeError('Property %s must be convertible to a credentials '
'instance; received: %s.' % (self._name, value))
- def _to_base_type(self, value):
- """Converts our validated value to a JSON serialized string.
+ def _to_base_type(self, value):
+ """Converts our validated value to a JSON serialized string.
Args:
value: A value to be set in the datastore.
@@ -349,13 +351,13 @@ if ndb is not None:
Returns:
A JSON serialized version of the credential, else '' if value is None.
"""
- if value is None:
- return ''
- else:
- return value.to_json()
+ if value is None:
+ return ''
+ else:
+ return value.to_json()
- def _from_base_type(self, value):
- """Converts our stored JSON string back to the desired type.
+ def _from_base_type(self, value):
+ """Converts our stored JSON string back to the desired type.
Args:
value: A value from the datastore to be converted to the desired type.
@@ -364,27 +366,27 @@ if ndb is not None:
A deserialized Credentials (or subclass) object, else None if the
value can't be parsed.
"""
- if not value:
- return None
- try:
- # Uses the from_json method of the implied class of value
- credentials = Credentials.new_from_json(value)
- except ValueError:
- credentials = None
- return credentials
+ if not value:
+ return None
+ try:
+ # Uses the from_json method of the implied class of value
+ credentials = Credentials.new_from_json(value)
+ except ValueError:
+ credentials = None
+ return credentials
class StorageByKeyName(Storage):
- """Store and retrieve a credential to and from the App Engine datastore.
+ """Store and retrieve a credential to and from the App Engine datastore.
This Storage helper presumes the Credentials have been stored as a
CredentialsProperty or CredentialsNDBProperty on a datastore model class, and
that entities are stored by key_name.
"""
- @util.positional(4)
- def __init__(self, model, key_name, property_name, cache=None, user=None):
- """Constructor for Storage.
+ @util.positional(4)
+ def __init__(self, model, key_name, property_name, cache=None, user=None):
+ """Constructor for Storage.
Args:
model: db.Model or ndb.Model, model class
@@ -397,34 +399,34 @@ class StorageByKeyName(Storage):
user: users.User object, optional. Can be used to grab user ID as a
key_name if no key name is specified.
"""
- if key_name is None:
- if user is None:
- raise ValueError('StorageByKeyName called with no key name or user.')
- key_name = user.user_id()
+ if key_name is None:
+ if user is None:
+ raise ValueError('StorageByKeyName called with no key name or user.')
+ key_name = user.user_id()
- self._model = model
- self._key_name = key_name
- self._property_name = property_name
- self._cache = cache
+ self._model = model
+ self._key_name = key_name
+ self._property_name = property_name
+ self._cache = cache
- def _is_ndb(self):
- """Determine whether the model of the instance is an NDB model.
+ def _is_ndb(self):
+ """Determine whether the model of the instance is an NDB model.
Returns:
Boolean indicating whether or not the model is an NDB or DB model.
"""
- # issubclass will fail if one of the arguments is not a class, only need
- # worry about new-style classes since ndb and db models are new-style
- if isinstance(self._model, type):
- if ndb is not None and issubclass(self._model, ndb.Model):
- return True
- elif issubclass(self._model, db.Model):
- return False
+ # issubclass will fail if one of the arguments is not a class, only need
+ # worry about new-style classes since ndb and db models are new-style
+ if isinstance(self._model, type):
+ if ndb is not None and issubclass(self._model, ndb.Model):
+ return True
+ elif issubclass(self._model, db.Model):
+ return False
- raise TypeError('Model class not an NDB or DB model: %s.' % (self._model,))
+ raise TypeError('Model class not an NDB or DB model: %s.' % (self._model, ))
- def _get_entity(self):
- """Retrieve entity from datastore.
+ def _get_entity(self):
+ """Retrieve entity from datastore.
Uses a different model method for db or ndb models.
@@ -432,80 +434,80 @@ class StorageByKeyName(Storage):
Instance of the model corresponding to the current storage object
and stored using the key name of the storage object.
"""
- if self._is_ndb():
- return self._model.get_by_id(self._key_name)
- else:
- return self._model.get_by_key_name(self._key_name)
+ if self._is_ndb():
+ return self._model.get_by_id(self._key_name)
+ else:
+ return self._model.get_by_key_name(self._key_name)
- def _delete_entity(self):
- """Delete entity from datastore.
+ def _delete_entity(self):
+ """Delete entity from datastore.
Attempts to delete using the key_name stored on the object, whether or not
the given key is in the datastore.
"""
- if self._is_ndb():
- ndb.Key(self._model, self._key_name).delete()
- else:
- entity_key = db.Key.from_path(self._model.kind(), self._key_name)
- db.delete(entity_key)
+ if self._is_ndb():
+ ndb.Key(self._model, self._key_name).delete()
+ else:
+ entity_key = db.Key.from_path(self._model.kind(), self._key_name)
+ db.delete(entity_key)
- @db.non_transactional(allow_existing=True)
+ @db.non_transactional(allow_existing=True)
def locked_get(self):
- """Retrieve Credential from datastore.
+ """Retrieve Credential from datastore.
Returns:
oauth2client.Credentials
"""
- credentials = None
- if self._cache:
- json = self._cache.get(self._key_name)
- if json:
- credentials = Credentials.new_from_json(json)
- if credentials is None:
- entity = self._get_entity()
- if entity is not None:
- credentials = getattr(entity, self._property_name)
+ credentials = None
if self._cache:
- self._cache.set(self._key_name, credentials.to_json())
+ json = self._cache.get(self._key_name)
+ if json:
+ credentials = Credentials.new_from_json(json)
+ if credentials is None:
+ entity = self._get_entity()
+ if entity is not None:
+ credentials = getattr(entity, self._property_name)
+ if self._cache:
+ self._cache.set(self._key_name, credentials.to_json())
- if credentials and hasattr(credentials, 'set_store'):
- credentials.set_store(self)
- return credentials
+ if credentials and hasattr(credentials, 'set_store'):
+ credentials.set_store(self)
+ return credentials
- @db.non_transactional(allow_existing=True)
- def locked_put(self, credentials):
- """Write a Credentials to the datastore.
+ @db.non_transactional(allow_existing=True)
+ def locked_put(self, credentials):
+ """Write a Credentials to the datastore.
Args:
credentials: Credentials, the credentials to store.
"""
- entity = self._model.get_or_insert(self._key_name)
- setattr(entity, self._property_name, credentials)
- entity.put()
- if self._cache:
- self._cache.set(self._key_name, credentials.to_json())
+ entity = self._model.get_or_insert(self._key_name)
+ setattr(entity, self._property_name, credentials)
+ entity.put()
+ if self._cache:
+ self._cache.set(self._key_name, credentials.to_json())
- @db.non_transactional(allow_existing=True)
+ @db.non_transactional(allow_existing=True)
def locked_delete(self):
- """Delete Credential from datastore."""
+ """Delete Credential from datastore."""
- if self._cache:
- self._cache.delete(self._key_name)
+ if self._cache:
+ self._cache.delete(self._key_name)
- self._delete_entity()
+ self._delete_entity()
class CredentialsModel(db.Model):
- """Storage for OAuth 2.0 Credentials
+ """Storage for OAuth 2.0 Credentials
Storage of the model is keyed by the user.user_id().
"""
- credentials = CredentialsProperty()
+ credentials = CredentialsProperty()
if ndb is not None:
- class CredentialsNDBModel(ndb.Model):
- """NDB Model for storage of OAuth 2.0 Credentials
+ class CredentialsNDBModel(ndb.Model):
+ """NDB Model for storage of OAuth 2.0 Credentials
Since this model uses the same kind as CredentialsModel and has a property
which can serialize and deserialize Credentials correctly, it can be used
@@ -515,16 +517,16 @@ if ndb is not None:
Storage of the model is keyed by the user.user_id().
"""
- credentials = CredentialsNDBProperty()
+ credentials = CredentialsNDBProperty()
- @classmethod
- def _get_kind(cls):
- """Return the kind name for this class."""
- return 'CredentialsModel'
+ @classmethod
+ def _get_kind(cls):
+ """Return the kind name for this class."""
+ return 'CredentialsModel'
def _build_state_value(request_handler, user):
- """Composes the value for the 'state' parameter.
+ """Composes the value for the 'state' parameter.
Packs the current request URI and an XSRF token into an opaque string that
can be passed to the authentication server via the 'state' parameter.
@@ -536,14 +538,14 @@ def _build_state_value(request_handler, user):
Returns:
The state value as a string.
"""
- uri = request_handler.request.url
- token = xsrfutil.generate_token(xsrf_secret_key(), user.user_id(),
+ uri = request_handler.request.url
+ token = xsrfutil.generate_token(xsrf_secret_key(), user.user_id(),
action_id=str(uri))
- return uri + ':' + token
+ return uri + ':' + token
def _parse_state_value(state, user):
- """Parse the value of the 'state' parameter.
+ """Parse the value of the 'state' parameter.
Parses the value and validates the XSRF token in the state parameter.
@@ -557,16 +559,16 @@ def _parse_state_value(state, user):
Returns:
The redirect URI.
"""
- uri, token = state.rsplit(':', 1)
- if not xsrfutil.validate_token(xsrf_secret_key(), token, user.user_id(),
+ uri, token = state.rsplit(':', 1)
+ if not xsrfutil.validate_token(xsrf_secret_key(), token, user.user_id(),
action_id=uri):
- raise InvalidXsrfTokenError()
+ raise InvalidXsrfTokenError()
- return uri
+ return uri
class OAuth2Decorator(object):
- """Utility for making OAuth 2.0 easier.
+ """Utility for making OAuth 2.0 easier.
Instantiate and then use with oauth_required or oauth_aware
as decorators on webapp.RequestHandler methods.
@@ -587,39 +589,38 @@ class OAuth2Decorator(object):
"""
- def set_credentials(self, credentials):
- self._tls.credentials = credentials
+ def set_credentials(self, credentials):
+ self._tls.credentials = credentials
- def get_credentials(self):
- """A thread local Credentials object.
+ def get_credentials(self):
+ """A thread local Credentials object.
Returns:
A client.Credentials object, or None if credentials hasn't been set in
this thread yet, which may happen when calling has_credentials inside
oauth_aware.
"""
- return getattr(self._tls, 'credentials', None)
+ return getattr(self._tls, 'credentials', None)
- credentials = property(get_credentials, set_credentials)
+ credentials = property(get_credentials, set_credentials)
- def set_flow(self, flow):
- self._tls.flow = flow
+ def set_flow(self, flow):
+ self._tls.flow = flow
- def get_flow(self):
- """A thread local Flow object.
+ def get_flow(self):
+ """A thread local Flow object.
Returns:
A credentials.Flow object, or None if the flow hasn't been set in this
thread yet, which happens in _create_flow() since Flows are created
lazily.
"""
- return getattr(self._tls, 'flow', None)
+ return getattr(self._tls, 'flow', None)
- flow = property(get_flow, set_flow)
+ flow = property(get_flow, set_flow)
-
- @util.positional(4)
- def __init__(self, client_id, client_secret, scope,
+ @util.positional(4)
+ def __init__(self, client_id, client_secret, scope,
auth_uri=GOOGLE_AUTH_URI,
token_uri=GOOGLE_TOKEN_URI,
revoke_uri=GOOGLE_REVOKE_URI,
@@ -632,7 +633,7 @@ class OAuth2Decorator(object):
_credentials_property_name='credentials',
**kwargs):
- """Constructor for OAuth2Decorator
+ """Constructor for OAuth2Decorator
Args:
client_id: string, client identifier.
@@ -670,32 +671,32 @@ class OAuth2Decorator(object):
the OAuth2WebServerFlow constructor.
"""
- self._tls = threading.local()
- self.flow = None
- self.credentials = None
- self._client_id = client_id
- self._client_secret = client_secret
- self._scope = util.scopes_to_string(scope)
- self._auth_uri = auth_uri
- self._token_uri = token_uri
- self._revoke_uri = revoke_uri
- self._user_agent = user_agent
- self._kwargs = kwargs
- self._message = message
- self._in_error = False
- self._callback_path = callback_path
- self._token_response_param = token_response_param
- self._storage_class = _storage_class
- self._credentials_class = _credentials_class
- self._credentials_property_name = _credentials_property_name
+ self._tls = threading.local()
+ self.flow = None
+ self.credentials = None
+ self._client_id = client_id
+ self._client_secret = client_secret
+ self._scope = util.scopes_to_string(scope)
+ self._auth_uri = auth_uri
+ self._token_uri = token_uri
+ self._revoke_uri = revoke_uri
+ self._user_agent = user_agent
+ self._kwargs = kwargs
+ self._message = message
+ self._in_error = False
+ self._callback_path = callback_path
+ self._token_response_param = token_response_param
+ self._storage_class = _storage_class
+ self._credentials_class = _credentials_class
+ self._credentials_property_name = _credentials_property_name
- def _display_error_message(self, request_handler):
- request_handler.response.out.write('
')
- request_handler.response.out.write(_safe_html(self._message))
- request_handler.response.out.write('')
+ def _display_error_message(self, request_handler):
+ request_handler.response.out.write('')
+ request_handler.response.out.write(_safe_html(self._message))
+ request_handler.response.out.write('')
- def oauth_required(self, method):
- """Decorator that starts the OAuth 2.0 dance.
+ def oauth_required(self, method):
+ """Decorator that starts the OAuth 2.0 dance.
Starts the OAuth dance for the logged in user if they haven't already
granted access for this application.
@@ -705,40 +706,40 @@ class OAuth2Decorator(object):
instance.
"""
- def check_oauth(request_handler, *args, **kwargs):
- if self._in_error:
- self._display_error_message(request_handler)
- return
+ def check_oauth(request_handler, *args, **kwargs):
+ if self._in_error:
+ self._display_error_message(request_handler)
+ return
- user = users.get_current_user()
- # Don't use @login_decorator as this could be used in a POST request.
- if not user:
- request_handler.redirect(users.create_login_url(
+ user = users.get_current_user()
+ # Don't use @login_decorator as this could be used in a POST request.
+ if not user:
+ request_handler.redirect(users.create_login_url(
request_handler.request.uri))
- return
+ return
- self._create_flow(request_handler)
+ self._create_flow(request_handler)
- # Store the request URI in 'state' so we can use it later
- self.flow.params['state'] = _build_state_value(request_handler, user)
- self.credentials = self._storage_class(
+ # Store the request URI in 'state' so we can use it later
+ self.flow.params['state'] = _build_state_value(request_handler, user)
+ self.credentials = self._storage_class(
self._credentials_class, None,
self._credentials_property_name, user=user).get()
- if not self.has_credentials():
- return request_handler.redirect(self.authorize_url())
- try:
- resp = method(request_handler, *args, **kwargs)
- except AccessTokenRefreshError:
- return request_handler.redirect(self.authorize_url())
- finally:
- self.credentials = None
- return resp
+ if not self.has_credentials():
+ return request_handler.redirect(self.authorize_url())
+ try:
+ resp = method(request_handler, *args, **kwargs)
+ except AccessTokenRefreshError:
+ return request_handler.redirect(self.authorize_url())
+ finally:
+ self.credentials = None
+ return resp
- return check_oauth
+ return check_oauth
- def _create_flow(self, request_handler):
- """Create the Flow object.
+ def _create_flow(self, request_handler):
+ """Create the Flow object.
The Flow is calculated lazily since we don't know where this app is
running until it receives a request, at which point redirect_uri can be
@@ -747,10 +748,10 @@ class OAuth2Decorator(object):
Args:
request_handler: webapp.RequestHandler, the request handler.
"""
- if self.flow is None:
- redirect_uri = request_handler.request.relative_url(
- self._callback_path) # Usually /oauth2callback
- self.flow = OAuth2WebServerFlow(self._client_id, self._client_secret,
+ if self.flow is None:
+ redirect_uri = request_handler.request.relative_url(
+ self._callback_path) # Usually /oauth2callback
+ self.flow = OAuth2WebServerFlow(self._client_id, self._client_secret,
self._scope, redirect_uri=redirect_uri,
user_agent=self._user_agent,
auth_uri=self._auth_uri,
@@ -758,8 +759,8 @@ class OAuth2Decorator(object):
revoke_uri=self._revoke_uri,
**self._kwargs)
- def oauth_aware(self, method):
- """Decorator that sets up for OAuth 2.0 dance, but doesn't do it.
+ def oauth_aware(self, method):
+ """Decorator that sets up for OAuth 2.0 dance, but doesn't do it.
Does all the setup for the OAuth dance, but doesn't initiate it.
This decorator is useful if you want to create a page that knows
@@ -772,51 +773,50 @@ class OAuth2Decorator(object):
instance.
"""
- def setup_oauth(request_handler, *args, **kwargs):
- if self._in_error:
- self._display_error_message(request_handler)
- return
+ def setup_oauth(request_handler, *args, **kwargs):
+ if self._in_error:
+ self._display_error_message(request_handler)
+ return
- user = users.get_current_user()
- # Don't use @login_decorator as this could be used in a POST request.
- if not user:
- request_handler.redirect(users.create_login_url(
+ user = users.get_current_user()
+ # Don't use @login_decorator as this could be used in a POST request.
+ if not user:
+ request_handler.redirect(users.create_login_url(
request_handler.request.uri))
- return
+ return
- self._create_flow(request_handler)
+ self._create_flow(request_handler)
- self.flow.params['state'] = _build_state_value(request_handler, user)
- self.credentials = self._storage_class(
+ self.flow.params['state'] = _build_state_value(request_handler, user)
+ self.credentials = self._storage_class(
self._credentials_class, None,
self._credentials_property_name, user=user).get()
- try:
- resp = method(request_handler, *args, **kwargs)
- finally:
- self.credentials = None
- return resp
- return setup_oauth
+ try:
+ resp = method(request_handler, *args, **kwargs)
+ finally:
+ self.credentials = None
+ return resp
+ return setup_oauth
-
- def has_credentials(self):
- """True if for the logged in user there are valid access Credentials.
+ def has_credentials(self):
+ """True if for the logged in user there are valid access Credentials.
Must only be called from with a webapp.RequestHandler subclassed method
that had been decorated with either @oauth_required or @oauth_aware.
"""
- return self.credentials is not None and not self.credentials.invalid
+ return self.credentials is not None and not self.credentials.invalid
- def authorize_url(self):
- """Returns the URL to start the OAuth dance.
+ def authorize_url(self):
+ """Returns the URL to start the OAuth dance.
Must only be called from with a webapp.RequestHandler subclassed method
that had been decorated with either @oauth_required or @oauth_aware.
"""
- url = self.flow.step1_get_authorize_url()
- return str(url)
+ url = self.flow.step1_get_authorize_url()
+ return str(url)
- def http(self, *args, **kwargs):
- """Returns an authorized http instance.
+ def http(self, *args, **kwargs):
+ """Returns an authorized http instance.
Must only be called from within an @oauth_required decorated method, or
from within an @oauth_aware decorated method where has_credentials()
@@ -826,11 +826,11 @@ class OAuth2Decorator(object):
*args: Positional arguments passed to httplib2.Http constructor.
**kwargs: Positional arguments passed to httplib2.Http constructor.
"""
- return self.credentials.authorize(httplib2.Http(*args, **kwargs))
+ return self.credentials.authorize(httplib2.Http(*args, **kwargs))
- @property
- def callback_path(self):
- """The absolute path where the callback will occur.
+ @property
+ def callback_path(self):
+ """The absolute path where the callback will occur.
Note this is the absolute path, not the absolute URI, that will be
calculated by the decorator at runtime. See callback_handler() for how this
@@ -839,11 +839,10 @@ class OAuth2Decorator(object):
Returns:
The callback path as a string.
"""
- return self._callback_path
+ return self._callback_path
-
- def callback_handler(self):
- """RequestHandler for the OAuth 2.0 redirect callback.
+ def callback_handler(self):
+ """RequestHandler for the OAuth 2.0 redirect callback.
Usage::
@@ -857,39 +856,39 @@ class OAuth2Decorator(object):
A webapp.RequestHandler that handles the redirect back from the
server during the OAuth 2.0 dance.
"""
- decorator = self
+ decorator = self
- class OAuth2Handler(webapp.RequestHandler):
- """Handler for the redirect_uri of the OAuth 2.0 dance."""
+ class OAuth2Handler(webapp.RequestHandler):
+ """Handler for the redirect_uri of the OAuth 2.0 dance."""
- @login_required
- def get(self):
- error = self.request.get('error')
- if error:
- errormsg = self.request.get('error_description', error)
- self.response.out.write(
+ @login_required
+ def get(self):
+ error = self.request.get('error')
+ if error:
+ errormsg = self.request.get('error_description', error)
+ self.response.out.write(
'The authorization request failed: %s' % _safe_html(errormsg))
- else:
- user = users.get_current_user()
- decorator._create_flow(self)
- credentials = decorator.flow.step2_exchange(self.request.params)
- decorator._storage_class(
+ else:
+ user = users.get_current_user()
+ decorator._create_flow(self)
+ credentials = decorator.flow.step2_exchange(self.request.params)
+ decorator._storage_class(
decorator._credentials_class, None,
decorator._credentials_property_name, user=user).put(credentials)
- redirect_uri = _parse_state_value(str(self.request.get('state')),
+ redirect_uri = _parse_state_value(str(self.request.get('state')),
user)
- if decorator._token_response_param and credentials.token_response:
- resp_json = json.dumps(credentials.token_response)
- redirect_uri = util._add_query_parameter(
+ if decorator._token_response_param and credentials.token_response:
+ resp_json = json.dumps(credentials.token_response)
+ redirect_uri = util._add_query_parameter(
redirect_uri, decorator._token_response_param, resp_json)
- self.redirect(redirect_uri)
+ self.redirect(redirect_uri)
- return OAuth2Handler
+ return OAuth2Handler
- def callback_application(self):
- """WSGI application for handling the OAuth 2.0 redirect callback.
+ def callback_application(self):
+ """WSGI application for handling the OAuth 2.0 redirect callback.
If you need finer grained control use `callback_handler` which returns just
the webapp.RequestHandler.
@@ -898,13 +897,13 @@ class OAuth2Decorator(object):
A webapp.WSGIApplication that handles the redirect back from the
server during the OAuth 2.0 dance.
"""
- return webapp.WSGIApplication([
+ return webapp.WSGIApplication([
(self.callback_path, self.callback_handler())
])
class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
- """An OAuth2Decorator that builds from a clientsecrets file.
+ """An OAuth2Decorator that builds from a clientsecrets file.
Uses a clientsecrets file as the source for all the information when
constructing an OAuth2Decorator.
@@ -924,9 +923,9 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
"""
- @util.positional(3)
- def __init__(self, filename, scope, message=None, cache=None, **kwargs):
- """Constructor
+ @util.positional(3)
+ def __init__(self, filename, scope, message=None, cache=None, **kwargs):
+ """Constructor
Args:
filename: string, File name of client secrets.
@@ -941,33 +940,33 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator):
**kwargs: dict, Keyword arguments are passed along as kwargs to
the OAuth2WebServerFlow constructor.
"""
- client_type, client_info = clientsecrets.loadfile(filename, cache=cache)
- if client_type not in [
+ client_type, client_info = clientsecrets.loadfile(filename, cache=cache)
+ if client_type not in [
clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED]:
- raise InvalidClientSecretsError(
+ raise InvalidClientSecretsError(
"OAuth2Decorator doesn't support this OAuth 2.0 flow.")
- constructor_kwargs = dict(kwargs)
- constructor_kwargs.update({
+ constructor_kwargs = dict(kwargs)
+ constructor_kwargs.update({
'auth_uri': client_info['auth_uri'],
'token_uri': client_info['token_uri'],
'message': message,
- })
- revoke_uri = client_info.get('revoke_uri')
- if revoke_uri is not None:
- constructor_kwargs['revoke_uri'] = revoke_uri
- super(OAuth2DecoratorFromClientSecrets, self).__init__(
+ })
+ revoke_uri = client_info.get('revoke_uri')
+ if revoke_uri is not None:
+ constructor_kwargs['revoke_uri'] = revoke_uri
+ super(OAuth2DecoratorFromClientSecrets, self).__init__(
client_info['client_id'], client_info['client_secret'],
scope, **constructor_kwargs)
- if message is not None:
- self._message = message
- else:
- self._message = 'Please configure your application for OAuth 2.0.'
+ if message is not None:
+ self._message = message
+ else:
+ self._message = 'Please configure your application for OAuth 2.0.'
@util.positional(2)
def oauth2decorator_from_clientsecrets(filename, scope,
message=None, cache=None):
- """Creates an OAuth2Decorator populated from a clientsecrets file.
+ """Creates an OAuth2Decorator populated from a clientsecrets file.
Args:
filename: string, File name of client secrets.
@@ -983,5 +982,5 @@ def oauth2decorator_from_clientsecrets(filename, scope,
Returns: An OAuth2Decorator
"""
- return OAuth2DecoratorFromClientSecrets(filename, scope,
+ return OAuth2DecoratorFromClientSecrets(filename, scope,
message=message, cache=cache)
diff --git a/oauth2client/client.py b/oauth2client/client.py
index 69d2db1..45e54fe 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -49,12 +49,12 @@ from oauth2client import util
HAS_OPENSSL = False
HAS_CRYPTO = False
try:
- from oauth2client import crypt
- HAS_CRYPTO = True
- if crypt.OpenSSLVerifier is not None:
- HAS_OPENSSL = True
+ from oauth2client import crypt
+ HAS_CRYPTO = True
+ if crypt.OpenSSLVerifier is not None:
+ HAS_OPENSSL = True
except ImportError:
- pass
+ pass
logger = logging.getLogger(__name__)
@@ -107,77 +107,78 @@ DEFAULT_ENV_NAME = 'UNKNOWN'
# If set to True _get_environment avoid GCE check (_detect_gce_environment)
NO_GCE_CHECK = os.environ.setdefault('NO_GCE_CHECK', 'False')
+
class SETTINGS(object):
- """Settings namespace for globally defined values."""
- env_name = None
+ """Settings namespace for globally defined values."""
+ env_name = None
class Error(Exception):
- """Base error for this module."""
+ """Base error for this module."""
class FlowExchangeError(Error):
- """Error trying to exchange an authorization grant for an access token."""
+ """Error trying to exchange an authorization grant for an access token."""
class AccessTokenRefreshError(Error):
- """Error trying to refresh an expired access token."""
+ """Error trying to refresh an expired access token."""
class TokenRevokeError(Error):
- """Error trying to revoke a token."""
+ """Error trying to revoke a token."""
class UnknownClientSecretsFlowError(Error):
- """The client secrets file called for an unknown type of OAuth 2.0 flow. """
+ """The client secrets file called for an unknown type of OAuth 2.0 flow. """
class AccessTokenCredentialsError(Error):
- """Having only the access_token means no refresh is possible."""
+ """Having only the access_token means no refresh is possible."""
class VerifyJwtTokenError(Error):
- """Could not retrieve certificates for validation."""
+ """Could not retrieve certificates for validation."""
class NonAsciiHeaderError(Error):
- """Header names and values must be ASCII strings."""
+ """Header names and values must be ASCII strings."""
class ApplicationDefaultCredentialsError(Error):
- """Error retrieving the Application Default Credentials."""
+ """Error retrieving the Application Default Credentials."""
class OAuth2DeviceCodeError(Error):
- """Error trying to retrieve a device code."""
+ """Error trying to retrieve a device code."""
class CryptoUnavailableError(Error, NotImplementedError):
- """Raised when a crypto library is required, but none is available."""
+ """Raised when a crypto library is required, but none is available."""
def _abstract():
- raise NotImplementedError('You need to override this function')
+ raise NotImplementedError('You need to override this function')
class MemoryCache(object):
- """httplib2 Cache implementation which only caches locally."""
+ """httplib2 Cache implementation which only caches locally."""
- def __init__(self):
- self.cache = {}
+ def __init__(self):
+ self.cache = {}
- def get(self, key):
- return self.cache.get(key)
+ def get(self, key):
+ return self.cache.get(key)
- def set(self, key, value):
- self.cache[key] = value
+ def set(self, key, value):
+ self.cache[key] = value
- def delete(self, key):
- self.cache.pop(key, None)
+ def delete(self, key):
+ self.cache.pop(key, None)
class Credentials(object):
- """Base class for all Credentials objects.
+ """Base class for all Credentials objects.
Subclasses must define an authorize() method that applies the credentials to
an HTTP transport.
@@ -186,11 +187,10 @@ class Credentials(object):
string as input and returns an instantiated Credentials object.
"""
- NON_SERIALIZED_MEMBERS = ['store']
+ NON_SERIALIZED_MEMBERS = ['store']
-
- def authorize(self, http):
- """Take an httplib2.Http instance (or equivalent) and authorizes it.
+ def authorize(self, http):
+ """Take an httplib2.Http instance (or equivalent) and authorizes it.
Authorizes it for the set of credentials, usually by replacing
http.request() with a method that adds in the appropriate headers and then
@@ -200,39 +200,36 @@ class Credentials(object):
http: httplib2.Http, an http object to be used to make the refresh
request.
"""
- _abstract()
+ _abstract()
-
- def refresh(self, http):
- """Forces a refresh of the access_token.
+ def refresh(self, http):
+ """Forces a refresh of the access_token.
Args:
http: httplib2.Http, an http object to be used to make the refresh
request.
"""
- _abstract()
+ _abstract()
-
- def revoke(self, http):
- """Revokes a refresh_token and makes the credentials void.
+ def revoke(self, http):
+ """Revokes a refresh_token and makes the credentials void.
Args:
http: httplib2.Http, an http object to be used to make the revoke
request.
"""
- _abstract()
+ _abstract()
-
- def apply(self, headers):
- """Add the authorization to the headers.
+ def apply(self, headers):
+ """Add the authorization to the headers.
Args:
headers: dict, the headers to add the Authorization header to.
"""
- _abstract()
+ _abstract()
- def _to_json(self, strip):
- """Utility function that creates JSON repr. of a Credentials object.
+ def _to_json(self, strip):
+ """Utility function that creates JSON repr. of a Credentials object.
Args:
strip: array, An array of names of members to not include in the JSON.
@@ -241,36 +238,36 @@ class Credentials(object):
string, a JSON representation of this instance, suitable to pass to
from_json().
"""
- t = type(self)
- d = copy.copy(self.__dict__)
- for member in strip:
- if member in d:
- del d[member]
- if (d.get('token_expiry') and
+ t = type(self)
+ d = copy.copy(self.__dict__)
+ for member in strip:
+ if member in d:
+ del d[member]
+ if (d.get('token_expiry') and
isinstance(d['token_expiry'], datetime.datetime)):
- d['token_expiry'] = d['token_expiry'].strftime(EXPIRY_FORMAT)
- # Add in information we will need later to reconsistitue this instance.
- d['_class'] = t.__name__
- d['_module'] = t.__module__
- for key, val in d.items():
- if isinstance(val, bytes):
- d[key] = val.decode('utf-8')
- if isinstance(val, set):
- d[key] = list(val)
- return json.dumps(d)
+ d['token_expiry'] = d['token_expiry'].strftime(EXPIRY_FORMAT)
+ # Add in information we will need later to reconsistitue this instance.
+ d['_class'] = t.__name__
+ d['_module'] = t.__module__
+ for key, val in d.items():
+ if isinstance(val, bytes):
+ d[key] = val.decode('utf-8')
+ if isinstance(val, set):
+ d[key] = list(val)
+ return json.dumps(d)
- def to_json(self):
- """Creating a JSON representation of an instance of Credentials.
+ def to_json(self):
+ """Creating a JSON representation of an instance of Credentials.
Returns:
string, a JSON representation of this instance, suitable to pass to
from_json().
"""
- return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
+ return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
- @classmethod
- def new_from_json(cls, s):
- """Utility class method to instantiate a Credentials subclass from JSON.
+ @classmethod
+ def new_from_json(cls, s):
+ """Utility class method to instantiate a Credentials subclass from JSON.
Expects the JSON string to have been produced by to_json().
@@ -281,25 +278,25 @@ class Credentials(object):
An instance of the subclass of Credentials that was serialized with
to_json().
"""
- json_string_as_unicode = _from_bytes(s)
- data = json.loads(json_string_as_unicode)
- # Find and call the right classmethod from_json() to restore the object.
- module_name = data['_module']
- try:
- module_obj = __import__(module_name)
- except ImportError:
- # In case there's an object from the old package structure, update it
- module_name = module_name.replace('.googleapiclient', '')
- module_obj = __import__(module_name)
+ json_string_as_unicode = _from_bytes(s)
+ data = json.loads(json_string_as_unicode)
+ # Find and call the right classmethod from_json() to restore the object.
+ module_name = data['_module']
+ try:
+ module_obj = __import__(module_name)
+ except ImportError:
+ # In case there's an object from the old package structure, update it
+ module_name = module_name.replace('.googleapiclient', '')
+ module_obj = __import__(module_name)
- module_obj = __import__(module_name, fromlist=module_name.split('.')[:-1])
- kls = getattr(module_obj, data['_class'])
- from_json = getattr(kls, 'from_json')
- return from_json(json_string_as_unicode)
+ module_obj = __import__(module_name, fromlist=module_name.split('.')[:-1])
+ kls = getattr(module_obj, data['_class'])
+ from_json = getattr(kls, 'from_json')
+ return from_json(json_string_as_unicode)
- @classmethod
- def from_json(cls, unused_data):
- """Instantiate a Credentials object from a JSON description of it.
+ @classmethod
+ def from_json(cls, unused_data):
+ """Instantiate a Credentials object from a JSON description of it.
The JSON should have been produced by calling .to_json() on the object.
@@ -309,94 +306,94 @@ class Credentials(object):
Returns:
An instance of a Credentials subclass.
"""
- return Credentials()
+ return Credentials()
class Flow(object):
- """Base class for all Flow objects."""
- pass
+ """Base class for all Flow objects."""
+ pass
class Storage(object):
- """Base class for all Storage objects.
+ """Base class for all Storage objects.
Store and retrieve a single credential. This class supports locking
such that multiple processes and threads can operate on a single
store.
"""
- def acquire_lock(self):
- """Acquires any lock necessary to access this Storage.
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
This lock is not reentrant.
"""
- pass
+ pass
- def release_lock(self):
- """Release the Storage lock.
+ def release_lock(self):
+ """Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
- pass
+ pass
- def locked_get(self):
- """Retrieve credential.
+ def locked_get(self):
+ """Retrieve credential.
The Storage lock must be held when this is called.
Returns:
oauth2client.client.Credentials
"""
- _abstract()
+ _abstract()
- def locked_put(self, credentials):
- """Write a credential.
+ def locked_put(self, credentials):
+ """Write a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
- _abstract()
+ _abstract()
- def locked_delete(self):
- """Delete a credential.
+ def locked_delete(self):
+ """Delete a credential.
The Storage lock must be held when this is called.
"""
- _abstract()
+ _abstract()
- def get(self):
- """Retrieve credential.
+ def get(self):
+ """Retrieve credential.
The Storage lock must *not* be held when this is called.
Returns:
oauth2client.client.Credentials
"""
- self.acquire_lock()
- try:
- return self.locked_get()
- finally:
- self.release_lock()
+ self.acquire_lock()
+ try:
+ return self.locked_get()
+ finally:
+ self.release_lock()
- def put(self, credentials):
- """Write a credential.
+ def put(self, credentials):
+ """Write a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
- self.acquire_lock()
- try:
- self.locked_put(credentials)
- finally:
- self.release_lock()
+ self.acquire_lock()
+ try:
+ self.locked_put(credentials)
+ finally:
+ self.release_lock()
- def delete(self):
- """Delete credential.
+ def delete(self):
+ """Delete credential.
Frees any resources associated with storing the credential.
The Storage lock must *not* be held when this is called.
@@ -404,15 +401,15 @@ class Storage(object):
Returns:
None
"""
- self.acquire_lock()
- try:
- return self.locked_delete()
- finally:
- self.release_lock()
+ self.acquire_lock()
+ try:
+ return self.locked_delete()
+ finally:
+ self.release_lock()
def clean_headers(headers):
- """Forces header keys and values to be strings, i.e not unicode.
+ """Forces header keys and values to be strings, i.e not unicode.
The httplib module just concats the header keys and values in a way that may
make the message header a unicode string, which, if it then tries to
@@ -424,21 +421,21 @@ def clean_headers(headers):
Returns:
The same dictionary but with all the keys converted to strings.
"""
- clean = {}
- try:
- for k, v in six.iteritems(headers):
- if not isinstance(k, six.binary_type):
- k = str(k)
- if not isinstance(v, six.binary_type):
- v = str(v)
- clean[_to_bytes(k)] = _to_bytes(v)
- except UnicodeEncodeError:
- raise NonAsciiHeaderError(k, ': ', v)
- return clean
+ clean = {}
+ try:
+ for k, v in six.iteritems(headers):
+ if not isinstance(k, six.binary_type):
+ k = str(k)
+ if not isinstance(v, six.binary_type):
+ v = str(v)
+ clean[_to_bytes(k)] = _to_bytes(v)
+ except UnicodeEncodeError:
+ raise NonAsciiHeaderError(k, ': ', v)
+ return clean
def _update_query_params(uri, params):
- """Updates a URI with new query parameters.
+ """Updates a URI with new query parameters.
Args:
uri: string, A valid URI, with potential existing query parameters.
@@ -447,15 +444,15 @@ def _update_query_params(uri, params):
Returns:
The same URI but with the new query parameters added.
"""
- parts = urllib.parse.urlparse(uri)
- query_params = dict(urllib.parse.parse_qsl(parts.query))
- query_params.update(params)
- new_parts = parts._replace(query=urllib.parse.urlencode(query_params))
- return urllib.parse.urlunparse(new_parts)
+ parts = urllib.parse.urlparse(uri)
+ query_params = dict(urllib.parse.parse_qsl(parts.query))
+ query_params.update(params)
+ new_parts = parts._replace(query=urllib.parse.urlencode(query_params))
+ return urllib.parse.urlunparse(new_parts)
class OAuth2Credentials(Credentials):
- """Credentials object for OAuth 2.0.
+ """Credentials object for OAuth 2.0.
Credentials can be applied to an httplib2.Http object using the authorize()
method, which then adds the OAuth 2.0 access token to each request.
@@ -463,12 +460,12 @@ class OAuth2Credentials(Credentials):
OAuth2Credentials objects may be safely pickled and unpickled.
"""
- @util.positional(8)
- def __init__(self, access_token, client_id, client_secret, refresh_token,
+ @util.positional(8)
+ def __init__(self, access_token, client_id, client_secret, refresh_token,
token_expiry, token_uri, user_agent, revoke_uri=None,
id_token=None, token_response=None, scopes=None,
token_info_uri=None):
- """Create an instance of OAuth2Credentials.
+ """Create an instance of OAuth2Credentials.
This constructor is not usually called by the user, instead
OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow.
@@ -497,26 +494,26 @@ class OAuth2Credentials(Credentials):
This is needed to store the latest access_token if it
has expired and been refreshed.
"""
- self.access_token = access_token
- self.client_id = client_id
- self.client_secret = client_secret
- self.refresh_token = refresh_token
- self.store = None
- self.token_expiry = token_expiry
- self.token_uri = token_uri
- self.user_agent = user_agent
- self.revoke_uri = revoke_uri
- self.id_token = id_token
- self.token_response = token_response
- self.scopes = set(util.string_to_scopes(scopes or []))
- self.token_info_uri = token_info_uri
+ self.access_token = access_token
+ self.client_id = client_id
+ self.client_secret = client_secret
+ self.refresh_token = refresh_token
+ self.store = None
+ self.token_expiry = token_expiry
+ self.token_uri = token_uri
+ self.user_agent = user_agent
+ self.revoke_uri = revoke_uri
+ self.id_token = id_token
+ self.token_response = token_response
+ self.scopes = set(util.string_to_scopes(scopes or []))
+ self.token_info_uri = token_info_uri
- # True if the credentials have been revoked or expired and can't be
- # refreshed.
- self.invalid = False
+ # True if the credentials have been revoked or expired and can't be
+ # refreshed.
+ self.invalid = False
- def authorize(self, http):
- """Authorize an httplib2.Http instance with these credentials.
+ def authorize(self, http):
+ """Authorize an httplib2.Http instance with these credentials.
The modified http.request method will add authentication headers to each
request and will refresh access_tokens when a 401 is received on a
@@ -543,92 +540,92 @@ class OAuth2Credentials(Credentials):
version of 'request()'.
"""
- request_orig = http.request
+ request_orig = http.request
- # The closure that will replace 'httplib2.Http.request'.
- def new_request(uri, method='GET', body=None, headers=None,
+ # The closure that will replace 'httplib2.Http.request'.
+ def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
- if not self.access_token:
- logger.info('Attempting refresh to obtain initial access_token')
- self._refresh(request_orig)
+ if not self.access_token:
+ logger.info('Attempting refresh to obtain initial access_token')
+ self._refresh(request_orig)
- # Clone and modify the request headers to add the appropriate
- # Authorization header.
- if headers is None:
- headers = {}
- else:
- headers = dict(headers)
- self.apply(headers)
+ # Clone and modify the request headers to add the appropriate
+ # Authorization header.
+ if headers is None:
+ headers = {}
+ else:
+ headers = dict(headers)
+ self.apply(headers)
- if self.user_agent is not None:
- if 'user-agent' in headers:
- headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
- else:
- headers['user-agent'] = self.user_agent
+ if self.user_agent is not None:
+ if 'user-agent' in headers:
+ headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
+ else:
+ headers['user-agent'] = self.user_agent
- body_stream_position = None
+ body_stream_position = None
if all(getattr(body, stream_prop, None) for stream_prop in
('read', 'seek', 'tell')):
- body_stream_position = body.tell()
+ body_stream_position = body.tell()
- resp, content = request_orig(uri, method, body, clean_headers(headers),
+ resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
- # A stored token may expire between the time it is retrieved and the time
- # the request is made, so we may need to try twice.
- max_refresh_attempts = 2
- for refresh_attempt in range(max_refresh_attempts):
- if resp.status not in REFRESH_STATUS_CODES:
- break
- logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status,
+ # A stored token may expire between the time it is retrieved and the time
+ # the request is made, so we may need to try twice.
+ max_refresh_attempts = 2
+ for refresh_attempt in range(max_refresh_attempts):
+ if resp.status not in REFRESH_STATUS_CODES:
+ break
+ logger.info('Refreshing due to a %s (attempt %s/%s)', resp.status,
refresh_attempt + 1, max_refresh_attempts)
- self._refresh(request_orig)
- self.apply(headers)
- if body_stream_position is not None:
- body.seek(body_stream_position)
+ self._refresh(request_orig)
+ self.apply(headers)
+ if body_stream_position is not None:
+ body.seek(body_stream_position)
- resp, content = request_orig(uri, method, body, clean_headers(headers),
+ resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
- return (resp, content)
+ return (resp, content)
- # Replace the request method with our own closure.
- http.request = new_request
+ # Replace the request method with our own closure.
+ http.request = new_request
- # Set credentials as a property of the request method.
- setattr(http.request, 'credentials', self)
+ # Set credentials as a property of the request method.
+ setattr(http.request, 'credentials', self)
- return http
+ return http
- def refresh(self, http):
- """Forces a refresh of the access_token.
+ def refresh(self, http):
+ """Forces a refresh of the access_token.
Args:
http: httplib2.Http, an http object to be used to make the refresh
request.
"""
- self._refresh(http.request)
+ self._refresh(http.request)
- def revoke(self, http):
- """Revokes a refresh_token and makes the credentials void.
+ def revoke(self, http):
+ """Revokes a refresh_token and makes the credentials void.
Args:
http: httplib2.Http, an http object to be used to make the revoke
request.
"""
- self._revoke(http.request)
+ self._revoke(http.request)
- def apply(self, headers):
- """Add the authorization to the headers.
+ def apply(self, headers):
+ """Add the authorization to the headers.
Args:
headers: dict, the headers to add the Authorization header to.
"""
- headers['Authorization'] = 'Bearer ' + self.access_token
+ headers['Authorization'] = 'Bearer ' + self.access_token
- def has_scopes(self, scopes):
- """Verify that the credentials are authorized for the given scopes.
+ def has_scopes(self, scopes):
+ """Verify that the credentials are authorized for the given scopes.
Returns True if the credentials authorized scopes contain all of the scopes
given.
@@ -643,11 +640,11 @@ class OAuth2Credentials(Credentials):
both cases, you can use refresh_scopes() to obtain the canonical set of
scopes.
"""
- scopes = util.string_to_scopes(scopes)
- return set(scopes).issubset(self.scopes)
+ scopes = util.string_to_scopes(scopes)
+ return set(scopes).issubset(self.scopes)
- def retrieve_scopes(self, http):
- """Retrieves the canonical list of scopes for this access token from the
+ def retrieve_scopes(self, http):
+ """Retrieves the canonical list of scopes for this access token from the
OAuth2 provider.
Args:
@@ -657,15 +654,15 @@ class OAuth2Credentials(Credentials):
Returns:
A set of strings containing the canonical list of scopes.
"""
- self._retrieve_scopes(http.request)
- return self.scopes
+ self._retrieve_scopes(http.request)
+ return self.scopes
- def to_json(self):
- return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
+ def to_json(self):
+ return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
- @classmethod
- def from_json(cls, s):
- """Instantiate a Credentials object from a JSON description of it. The JSON
+ @classmethod
+ def from_json(cls, s):
+ """Instantiate a Credentials object from a JSON description of it. The JSON
should have been produced by calling .to_json() on the object.
Args:
@@ -674,16 +671,16 @@ class OAuth2Credentials(Credentials):
Returns:
An instance of a Credentials subclass.
"""
- s = _from_bytes(s)
- data = json.loads(s)
- if (data.get('token_expiry') and
+ s = _from_bytes(s)
+ data = json.loads(s)
+ if (data.get('token_expiry') and
not isinstance(data['token_expiry'], datetime.datetime)):
- try:
- data['token_expiry'] = datetime.datetime.strptime(
+ try:
+ data['token_expiry'] = datetime.datetime.strptime(
data['token_expiry'], EXPIRY_FORMAT)
- except ValueError:
- data['token_expiry'] = None
- retval = cls(
+ except ValueError:
+ data['token_expiry'] = None
+ retval = cls(
data['access_token'],
data['client_id'],
data['client_secret'],
@@ -699,40 +696,40 @@ class OAuth2Credentials(Credentials):
retval.invalid = data['invalid']
return retval
- @property
- def access_token_expired(self):
- """True if the credential is expired or invalid.
+ @property
+ def access_token_expired(self):
+ """True if the credential is expired or invalid.
If the token_expiry isn't set, we assume the token doesn't expire.
"""
- if self.invalid:
- return True
+ if self.invalid:
+ return True
- if not self.token_expiry:
- return False
+ if not self.token_expiry:
+ return False
- now = datetime.datetime.utcnow()
- if now >= self.token_expiry:
- logger.info('access_token is expired. Now: %s, token_expiry: %s',
+ now = datetime.datetime.utcnow()
+ if now >= self.token_expiry:
+ logger.info('access_token is expired. Now: %s, token_expiry: %s',
now, self.token_expiry)
- return True
- return False
+ return True
+ return False
- def get_access_token(self, http=None):
- """Return the access token and its expiration information.
+ def get_access_token(self, http=None):
+ """Return the access token and its expiration information.
If the token does not exist, get one.
If the token expired, refresh it.
"""
- if not self.access_token or self.access_token_expired:
- if not http:
- http = httplib2.Http()
- self.refresh(http)
- return AccessTokenInfo(access_token=self.access_token,
+ if not self.access_token or self.access_token_expired:
+ if not http:
+ http = httplib2.Http()
+ self.refresh(http)
+ return AccessTokenInfo(access_token=self.access_token,
expires_in=self._expires_in())
- def set_store(self, store):
- """Set the Storage for the credential.
+ def set_store(self, store):
+ """Set the Storage for the credential.
Args:
store: Storage, an implementation of Storage object.
@@ -741,10 +738,10 @@ class OAuth2Credentials(Credentials):
locking to check for updates before updating the
access_token.
"""
- self.store = store
+ self.store = store
- def _expires_in(self):
- """Return the number of seconds until this token expires.
+ def _expires_in(self):
+ """Return the number of seconds until this token expires.
If token_expiry is in the past, this method will return 0, meaning the
token has already expired.
@@ -752,54 +749,54 @@ class OAuth2Credentials(Credentials):
0 in such a case would not be fair: the token may still be valid;
we just don't know anything about it.
"""
- if self.token_expiry:
- now = datetime.datetime.utcnow()
- if self.token_expiry > now:
- time_delta = self.token_expiry - now
- # TODO(orestica): return time_delta.total_seconds()
- # once dropping support for Python 2.6
- return time_delta.days * 86400 + time_delta.seconds
- else:
- return 0
+ if self.token_expiry:
+ now = datetime.datetime.utcnow()
+ if self.token_expiry > now:
+ time_delta = self.token_expiry - now
+ # TODO(orestica): return time_delta.total_seconds()
+ # once dropping support for Python 2.6
+ return time_delta.days * 86400 + time_delta.seconds
+ else:
+ return 0
- def _updateFromCredential(self, other):
- """Update this Credential from another instance."""
- self.__dict__.update(other.__getstate__())
+ def _updateFromCredential(self, other):
+ """Update this Credential from another instance."""
+ self.__dict__.update(other.__getstate__())
- def __getstate__(self):
- """Trim the state down to something that can be pickled."""
- d = copy.copy(self.__dict__)
- del d['store']
- return d
+ def __getstate__(self):
+ """Trim the state down to something that can be pickled."""
+ d = copy.copy(self.__dict__)
+ del d['store']
+ return d
- def __setstate__(self, state):
- """Reconstitute the state of the object from being pickled."""
- self.__dict__.update(state)
- self.store = None
+ def __setstate__(self, state):
+ """Reconstitute the state of the object from being pickled."""
+ self.__dict__.update(state)
+ self.store = None
- def _generate_refresh_request_body(self):
- """Generate the body that will be used in the refresh request."""
- body = urllib.parse.urlencode({
+ def _generate_refresh_request_body(self):
+ """Generate the body that will be used in the refresh request."""
+ body = urllib.parse.urlencode({
'grant_type': 'refresh_token',
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self.refresh_token,
})
- return body
+ return body
- def _generate_refresh_request_headers(self):
- """Generate the headers that will be used in the refresh request."""
- headers = {
+ def _generate_refresh_request_headers(self):
+ """Generate the headers that will be used in the refresh request."""
+ headers = {
'content-type': 'application/x-www-form-urlencoded',
- }
+ }
- if self.user_agent is not None:
- headers['user-agent'] = self.user_agent
+ if self.user_agent is not None:
+ headers['user-agent'] = self.user_agent
- return headers
+ return headers
- def _refresh(self, http_request):
- """Refreshes the access_token.
+ def _refresh(self, http_request):
+ """Refreshes the access_token.
This method first checks by reading the Storage object if available.
If a refresh is still needed, it holds the Storage lock until the
@@ -812,25 +809,25 @@ class OAuth2Credentials(Credentials):
Raises:
AccessTokenRefreshError: When the refresh fails.
"""
- if not self.store:
- self._do_refresh_request(http_request)
- else:
- self.store.acquire_lock()
- try:
- new_cred = self.store.locked_get()
+ if not self.store:
+ self._do_refresh_request(http_request)
+ else:
+ self.store.acquire_lock()
+ try:
+ new_cred = self.store.locked_get()
- if (new_cred and not new_cred.invalid and
+ if (new_cred and not new_cred.invalid and
new_cred.access_token != self.access_token and
not new_cred.access_token_expired):
- logger.info('Updated access_token read from Storage')
- self._updateFromCredential(new_cred)
- else:
- self._do_refresh_request(http_request)
- finally:
- self.store.release_lock()
+ logger.info('Updated access_token read from Storage')
+ self._updateFromCredential(new_cred)
+ else:
+ self._do_refresh_request(http_request)
+ finally:
+ self.store.release_lock()
- def _do_refresh_request(self, http_request):
- """Refresh the access_token using the refresh_token.
+ def _do_refresh_request(self, http_request):
+ """Refresh the access_token using the refresh_token.
Args:
http_request: callable, a callable that matches the method signature of
@@ -839,57 +836,57 @@ class OAuth2Credentials(Credentials):
Raises:
AccessTokenRefreshError: When the refresh fails.
"""
- body = self._generate_refresh_request_body()
- headers = self._generate_refresh_request_headers()
+ body = self._generate_refresh_request_body()
+ headers = self._generate_refresh_request_headers()
- logger.info('Refreshing access_token')
- resp, content = http_request(
+ logger.info('Refreshing access_token')
+ resp, content = http_request(
self.token_uri, method='POST', body=body, headers=headers)
- content = _from_bytes(content)
- if resp.status == 200:
- d = json.loads(content)
- self.token_response = d
- self.access_token = d['access_token']
- self.refresh_token = d.get('refresh_token', self.refresh_token)
- if 'expires_in' in d:
- self.token_expiry = datetime.timedelta(
+ content = _from_bytes(content)
+ if resp.status == 200:
+ d = json.loads(content)
+ self.token_response = d
+ self.access_token = d['access_token']
+ self.refresh_token = d.get('refresh_token', self.refresh_token)
+ if 'expires_in' in d:
+ self.token_expiry = datetime.timedelta(
seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
- else:
- self.token_expiry = None
- # On temporary refresh errors, the user does not actually have to
- # re-authorize, so we unflag here.
- self.invalid = False
- if self.store:
- self.store.locked_put(self)
- else:
- # An {'error':...} response body means the token is expired or revoked,
- # so we flag the credentials as such.
- logger.info('Failed to retrieve access token: %s', content)
- error_msg = 'Invalid response %s.' % resp['status']
- try:
- d = json.loads(content)
- if 'error' in d:
- error_msg = d['error']
- if 'error_description' in d:
- error_msg += ': ' + d['error_description']
- self.invalid = True
- if self.store:
- self.store.locked_put(self)
- except (TypeError, ValueError):
- pass
- raise AccessTokenRefreshError(error_msg)
+ else:
+ self.token_expiry = None
+ # On temporary refresh errors, the user does not actually have to
+ # re-authorize, so we unflag here.
+ self.invalid = False
+ if self.store:
+ self.store.locked_put(self)
+ else:
+ # An {'error':...} response body means the token is expired or revoked,
+ # so we flag the credentials as such.
+ logger.info('Failed to retrieve access token: %s', content)
+ error_msg = 'Invalid response %s.' % resp['status']
+ try:
+ d = json.loads(content)
+ if 'error' in d:
+ error_msg = d['error']
+ if 'error_description' in d:
+ error_msg += ': ' + d['error_description']
+ self.invalid = True
+ if self.store:
+ self.store.locked_put(self)
+ except (TypeError, ValueError):
+ pass
+ raise AccessTokenRefreshError(error_msg)
- def _revoke(self, http_request):
- """Revokes this credential and deletes the stored copy (if it exists).
+ def _revoke(self, http_request):
+ """Revokes this credential and deletes the stored copy (if it exists).
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the revoke request.
"""
- self._do_revoke(http_request, self.refresh_token or self.access_token)
+ self._do_revoke(http_request, self.refresh_token or self.access_token)
- def _do_revoke(self, http_request, token):
- """Revokes this credential and deletes the stored copy (if it exists).
+ def _do_revoke(self, http_request, token):
+ """Revokes this credential and deletes the stored copy (if it exists).
Args:
http_request: callable, a callable that matches the method signature of
@@ -900,36 +897,36 @@ class OAuth2Credentials(Credentials):
Raises:
TokenRevokeError: If the revoke request does not return with a 200 OK.
"""
- logger.info('Revoking token')
- query_params = {'token': token}
- token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
- resp, content = http_request(token_revoke_uri)
- if resp.status == 200:
- self.invalid = True
- else:
- error_msg = 'Invalid response %s.' % resp.status
- try:
- d = json.loads(_from_bytes(content))
- if 'error' in d:
- error_msg = d['error']
- except (TypeError, ValueError):
- pass
- raise TokenRevokeError(error_msg)
+ logger.info('Revoking token')
+ query_params = {'token': token}
+ token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
+ resp, content = http_request(token_revoke_uri)
+ if resp.status == 200:
+ self.invalid = True
+ else:
+ error_msg = 'Invalid response %s.' % resp.status
+ try:
+ d = json.loads(_from_bytes(content))
+ if 'error' in d:
+ error_msg = d['error']
+ except (TypeError, ValueError):
+ pass
+ raise TokenRevokeError(error_msg)
- if self.store:
- self.store.delete()
+ if self.store:
+ self.store.delete()
- def _retrieve_scopes(self, http_request):
- """Retrieves the list of authorized scopes from the OAuth2 provider.
+ def _retrieve_scopes(self, http_request):
+ """Retrieves the list of authorized scopes from the OAuth2 provider.
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the revoke request.
"""
- self._do_retrieve_scopes(http_request, self.access_token)
+ self._do_retrieve_scopes(http_request, self.access_token)
- def _do_retrieve_scopes(self, http_request, token):
- """Retrieves the list of authorized scopes from the OAuth2 provider.
+ def _do_retrieve_scopes(self, http_request, token):
+ """Retrieves the list of authorized scopes from the OAuth2 provider.
Args:
http_request: callable, a callable that matches the method signature of
@@ -940,27 +937,27 @@ class OAuth2Credentials(Credentials):
Raises:
Error: When refresh fails, indicating the the access token is invalid.
"""
- logger.info('Refreshing scopes')
- query_params = {'access_token': token, 'fields': 'scope'}
- token_info_uri = _update_query_params(self.token_info_uri, query_params)
- resp, content = http_request(token_info_uri)
- content = _from_bytes(content)
- if resp.status == 200:
- d = json.loads(content)
- self.scopes = set(util.string_to_scopes(d.get('scope', '')))
- else:
- error_msg = 'Invalid response %s.' % (resp.status,)
- try:
- d = json.loads(content)
- if 'error_description' in d:
- error_msg = d['error_description']
- except (TypeError, ValueError):
- pass
- raise Error(error_msg)
+ logger.info('Refreshing scopes')
+ query_params = {'access_token': token, 'fields': 'scope'}
+ token_info_uri = _update_query_params(self.token_info_uri, query_params)
+ resp, content = http_request(token_info_uri)
+ content = _from_bytes(content)
+ if resp.status == 200:
+ d = json.loads(content)
+ self.scopes = set(util.string_to_scopes(d.get('scope', '')))
+ else:
+ error_msg = 'Invalid response %s.' % (resp.status, )
+ try:
+ d = json.loads(content)
+ if 'error_description' in d:
+ error_msg = d['error_description']
+ except (TypeError, ValueError):
+ pass
+ raise Error(error_msg)
class AccessTokenCredentials(OAuth2Credentials):
- """Credentials object for OAuth 2.0.
+ """Credentials object for OAuth 2.0.
Credentials can be applied to an httplib2.Http object using the
authorize() method, which then signs each request from that object
@@ -985,8 +982,8 @@ class AccessTokenCredentials(OAuth2Credentials):
revoked.
"""
- def __init__(self, access_token, user_agent, revoke_uri=None):
- """Create an instance of OAuth2Credentials
+ def __init__(self, access_token, user_agent, revoke_uri=None):
+ """Create an instance of OAuth2Credentials
This is one of the few types if Credentials that you should contrust,
Credentials objects are usually instantiated by a Flow.
@@ -997,7 +994,7 @@ class AccessTokenCredentials(OAuth2Credentials):
revoke_uri: string, URI for revoke endpoint. Defaults to None; a token
can't be revoked if this is None.
"""
- super(AccessTokenCredentials, self).__init__(
+ super(AccessTokenCredentials, self).__init__(
access_token,
None,
None,
@@ -1007,31 +1004,30 @@ class AccessTokenCredentials(OAuth2Credentials):
user_agent,
revoke_uri=revoke_uri)
-
- @classmethod
- def from_json(cls, s):
- data = json.loads(_from_bytes(s))
- retval = AccessTokenCredentials(
+ @classmethod
+ def from_json(cls, s):
+ data = json.loads(_from_bytes(s))
+ retval = AccessTokenCredentials(
data['access_token'],
data['user_agent'])
- return retval
+ return retval
- def _refresh(self, http_request):
- raise AccessTokenCredentialsError(
+ def _refresh(self, http_request):
+ raise AccessTokenCredentialsError(
'The access_token is expired or invalid and can\'t be refreshed.')
- def _revoke(self, http_request):
- """Revokes the access_token and deletes the store if available.
+ def _revoke(self, http_request):
+ """Revokes the access_token and deletes the store if available.
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the revoke request.
"""
- self._do_revoke(http_request, self.access_token)
+ self._do_revoke(http_request, self.access_token)
def _detect_gce_environment(urlopen=None):
- """Determine if the current environment is Compute Engine.
+ """Determine if the current environment is Compute Engine.
Args:
urlopen: Optional argument. Function used to open a connection to a URL.
@@ -1040,51 +1036,51 @@ def _detect_gce_environment(urlopen=None):
Boolean indicating whether or not the current environment is Google
Compute Engine.
"""
- urlopen = urlopen or urllib.request.urlopen
- # Note: the explicit `timeout` below is a workaround. The underlying
- # issue is that resolving an unknown host on some networks will take
- # 20-30 seconds; making this timeout short fixes the issue, but
- # could lead to false negatives in the event that we are on GCE, but
- # the metadata resolution was particularly slow. The latter case is
- # "unlikely".
- try:
- response = urlopen('http://169.254.169.254/', timeout=1)
- return response.info().get('Metadata-Flavor', '') == 'Google'
- except socket.timeout:
- logger.info('Timeout attempting to reach GCE metadata service.')
- return False
- except urllib.error.URLError as e:
- if isinstance(getattr(e, 'reason', None), socket.timeout):
- logger.info('Timeout attempting to reach GCE metadata service.')
- return False
+ urlopen = urlopen or urllib.request.urlopen
+ # Note: the explicit `timeout` below is a workaround. The underlying
+ # issue is that resolving an unknown host on some networks will take
+ # 20-30 seconds; making this timeout short fixes the issue, but
+ # could lead to false negatives in the event that we are on GCE, but
+ # the metadata resolution was particularly slow. The latter case is
+ # "unlikely".
+ try:
+ response = urlopen('http://169.254.169.254/', timeout=1)
+ return response.info().get('Metadata-Flavor', '') == 'Google'
+ except socket.timeout:
+ logger.info('Timeout attempting to reach GCE metadata service.')
+ return False
+ except urllib.error.URLError as e:
+ if isinstance(getattr(e, 'reason', None), socket.timeout):
+ logger.info('Timeout attempting to reach GCE metadata service.')
+ return False
def _in_gae_environment():
- """Detects if the code is running in the App Engine environment.
+ """Detects if the code is running in the App Engine environment.
Returns:
True if running in the GAE environment, False otherwise.
"""
- if SETTINGS.env_name is not None:
- return SETTINGS.env_name in ('GAE_PRODUCTION', 'GAE_LOCAL')
+ if SETTINGS.env_name is not None:
+ return SETTINGS.env_name in ('GAE_PRODUCTION', 'GAE_LOCAL')
- try:
- import google.appengine
- server_software = os.environ.get('SERVER_SOFTWARE', '')
- if server_software.startswith('Google App Engine/'):
- SETTINGS.env_name = 'GAE_PRODUCTION'
- return True
- elif server_software.startswith('Development/'):
- SETTINGS.env_name = 'GAE_LOCAL'
- return True
- except ImportError:
- pass
+ try:
+ import google.appengine
+ server_software = os.environ.get('SERVER_SOFTWARE', '')
+ if server_software.startswith('Google App Engine/'):
+ SETTINGS.env_name = 'GAE_PRODUCTION'
+ return True
+ elif server_software.startswith('Development/'):
+ SETTINGS.env_name = 'GAE_LOCAL'
+ return True
+ except ImportError:
+ pass
- return False
+ return False
def _in_gce_environment(urlopen=None):
- """Detect if the code is running in the Compute Engine environment.
+ """Detect if the code is running in the Compute Engine environment.
Args:
urlopen: Optional argument. Function used to open a connection to a URL.
@@ -1092,17 +1088,17 @@ def _in_gce_environment(urlopen=None):
Returns:
True if running in the GCE environment, False otherwise.
"""
- if SETTINGS.env_name is not None:
- return SETTINGS.env_name == 'GCE_PRODUCTION'
+ if SETTINGS.env_name is not None:
+ return SETTINGS.env_name == 'GCE_PRODUCTION'
- if NO_GCE_CHECK != 'True' and _detect_gce_environment(urlopen=urlopen):
- SETTINGS.env_name = 'GCE_PRODUCTION'
- return True
- return False
+ if NO_GCE_CHECK != 'True' and _detect_gce_environment(urlopen=urlopen):
+ SETTINGS.env_name = 'GCE_PRODUCTION'
+ return True
+ return False
class GoogleCredentials(OAuth2Credentials):
- """Application Default Credentials for use in calling Google APIs.
+ """Application Default Credentials for use in calling Google APIs.
The Application Default Credentials are being constructed as a function of
the environment where the code is being run.
@@ -1126,10 +1122,10 @@ class GoogleCredentials(OAuth2Credentials):
print(response)
"""
- def __init__(self, access_token, client_id, client_secret, refresh_token,
+ def __init__(self, access_token, client_id, client_secret, refresh_token,
token_expiry, token_uri, user_agent,
revoke_uri=GOOGLE_REVOKE_URI):
- """Create an instance of GoogleCredentials.
+ """Create an instance of GoogleCredentials.
This constructor is not usually called by the user, instead
GoogleCredentials objects are instantiated by
@@ -1147,38 +1143,38 @@ class GoogleCredentials(OAuth2Credentials):
revoke_uri: string, URI for revoke endpoint.
Defaults to GOOGLE_REVOKE_URI; a token can't be revoked if this is None.
"""
- super(GoogleCredentials, self).__init__(
+ super(GoogleCredentials, self).__init__(
access_token, client_id, client_secret, refresh_token, token_expiry,
token_uri, user_agent, revoke_uri=revoke_uri)
- def create_scoped_required(self):
- """Whether this Credentials object is scopeless.
+ def create_scoped_required(self):
+ """Whether this Credentials object is scopeless.
create_scoped(scopes) method needs to be called in order to create
a Credentials object for API calls.
"""
- return False
+ return False
- def create_scoped(self, scopes):
- """Create a Credentials object for the given scopes.
+ def create_scoped(self, scopes):
+ """Create a Credentials object for the given scopes.
The Credentials type is preserved.
"""
- return self
+ return self
- @property
- def serialization_data(self):
- """Get the fields and their values identifying the current credentials."""
- return {
+ @property
+ def serialization_data(self):
+ """Get the fields and their values identifying the current credentials."""
+ return {
'type': 'authorized_user',
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self.refresh_token
- }
+ }
- @staticmethod
- def _implicit_credentials_from_gae():
- """Attempts to get implicit credentials in Google App Engine env.
+ @staticmethod
+ def _implicit_credentials_from_gae():
+ """Attempts to get implicit credentials in Google App Engine env.
If the current environment is not detected as App Engine, returns None,
indicating no Google App Engine credentials can be detected from the
@@ -1187,14 +1183,14 @@ class GoogleCredentials(OAuth2Credentials):
Returns:
None, if not in GAE, else an appengine.AppAssertionCredentials object.
"""
- if not _in_gae_environment():
- return None
+ if not _in_gae_environment():
+ return None
- return _get_application_default_credential_GAE()
+ return _get_application_default_credential_GAE()
- @staticmethod
- def _implicit_credentials_from_gce():
- """Attempts to get implicit credentials in Google Compute Engine env.
+ @staticmethod
+ def _implicit_credentials_from_gce():
+ """Attempts to get implicit credentials in Google Compute Engine env.
If the current environment is not detected as Compute Engine, returns None,
indicating no Google Compute Engine credentials can be detected from the
@@ -1203,14 +1199,14 @@ class GoogleCredentials(OAuth2Credentials):
Returns:
None, if not in GCE, else a gce.AppAssertionCredentials object.
"""
- if not _in_gce_environment():
- return None
+ if not _in_gce_environment():
+ return None
- return _get_application_default_credential_GCE()
+ return _get_application_default_credential_GCE()
- @staticmethod
- def _implicit_credentials_from_files():
- """Attempts to get implicit credentials from local credential files.
+ @staticmethod
+ def _implicit_credentials_from_files():
+ """Attempts to get implicit credentials from local credential files.
First checks if the environment variable GOOGLE_APPLICATION_CREDENTIALS
is set with a filename and then falls back to a configuration file (the
@@ -1222,33 +1218,33 @@ class GoogleCredentials(OAuth2Credentials):
define, returns None, indicating no credentials from a file can
detected from the current environment.
"""
- credentials_filename = _get_environment_variable_file()
- if not credentials_filename:
- credentials_filename = _get_well_known_file()
- if os.path.isfile(credentials_filename):
- extra_help = (' (produced automatically when running'
+ credentials_filename = _get_environment_variable_file()
+ if not credentials_filename:
+ credentials_filename = _get_well_known_file()
+ if os.path.isfile(credentials_filename):
+ extra_help = (' (produced automatically when running'
' "gcloud auth login" command)')
- else:
- credentials_filename = None
- else:
- extra_help = (' (pointed to by ' + GOOGLE_APPLICATION_CREDENTIALS +
+ else:
+ credentials_filename = None
+ else:
+ extra_help = (' (pointed to by ' + GOOGLE_APPLICATION_CREDENTIALS +
' environment variable)')
- if not credentials_filename:
- return
+ if not credentials_filename:
+ return
- # If we can read the credentials from a file, we don't need to know what
- # environment we are in.
- SETTINGS.env_name = DEFAULT_ENV_NAME
+ # If we can read the credentials from a file, we don't need to know what
+ # environment we are in.
+ SETTINGS.env_name = DEFAULT_ENV_NAME
- try:
- return _get_application_default_credential_from_file(credentials_filename)
- except (ApplicationDefaultCredentialsError, ValueError) as error:
- _raise_exception_for_reading_json(credentials_filename, extra_help, error)
+ try:
+ return _get_application_default_credential_from_file(credentials_filename)
+ except (ApplicationDefaultCredentialsError, ValueError) as error:
+ _raise_exception_for_reading_json(credentials_filename, extra_help, error)
- @classmethod
+ @classmethod
def _get_implicit_credentials(cls):
- """Gets credentials implicitly from the environment.
+ """Gets credentials implicitly from the environment.
Checks environment in order of precedence:
- Google App Engine (production and testing)
@@ -1262,34 +1258,34 @@ class GoogleCredentials(OAuth2Credentials):
to be retrieved.
"""
- # Environ checks (in order).
- environ_checkers = [
+ # Environ checks (in order).
+ environ_checkers = [
cls._implicit_credentials_from_gae,
cls._implicit_credentials_from_files,
cls._implicit_credentials_from_gce,
- ]
+ ]
- for checker in environ_checkers:
- credentials = checker()
- if credentials is not None:
- return credentials
+ for checker in environ_checkers:
+ credentials = checker()
+ if credentials is not None:
+ return credentials
- # If no credentials, fail.
- raise ApplicationDefaultCredentialsError(ADC_HELP_MSG)
+ # If no credentials, fail.
+ raise ApplicationDefaultCredentialsError(ADC_HELP_MSG)
- @staticmethod
- def get_application_default():
- """Get the Application Default Credentials for the current environment.
+ @staticmethod
+ def get_application_default():
+ """Get the Application Default Credentials for the current environment.
Raises:
ApplicationDefaultCredentialsError: raised when the credentials fail
to be retrieved.
"""
- return GoogleCredentials._get_implicit_credentials()
+ return GoogleCredentials._get_implicit_credentials()
- @staticmethod
- def from_stream(credential_filename):
- """Create a Credentials object by reading the information from a given file.
+ @staticmethod
+ def from_stream(credential_filename):
+ """Create a Credentials object by reading the information from a given file.
It returns an object of type GoogleCredentials.
@@ -1302,38 +1298,38 @@ class GoogleCredentials(OAuth2Credentials):
to be retrieved.
"""
- if credential_filename and os.path.isfile(credential_filename):
- try:
- return _get_application_default_credential_from_file(
+ if credential_filename and os.path.isfile(credential_filename):
+ try:
+ return _get_application_default_credential_from_file(
credential_filename)
- except (ApplicationDefaultCredentialsError, ValueError) as error:
- extra_help = ' (provided as parameter to the from_stream() method)'
- _raise_exception_for_reading_json(credential_filename,
+ except (ApplicationDefaultCredentialsError, ValueError) as error:
+ extra_help = ' (provided as parameter to the from_stream() method)'
+ _raise_exception_for_reading_json(credential_filename,
extra_help,
error)
- else:
- raise ApplicationDefaultCredentialsError(
+ else:
+ raise ApplicationDefaultCredentialsError(
'The parameter passed to the from_stream() '
'method should point to a file.')
def _save_private_file(filename, json_contents):
- """Saves a file with read-write permissions on for the owner.
+ """Saves a file with read-write permissions on for the owner.
Args:
filename: String. Absolute path to file.
json_contents: JSON serializable object to be saved.
"""
- temp_filename = tempfile.mktemp()
- file_desc = os.open(temp_filename, os.O_WRONLY | os.O_CREAT, 0o600)
- with os.fdopen(file_desc, 'w') as file_handle:
- json.dump(json_contents, file_handle, sort_keys=True,
+ temp_filename = tempfile.mktemp()
+ file_desc = os.open(temp_filename, os.O_WRONLY | os.O_CREAT, 0o600)
+ with os.fdopen(file_desc, 'w') as file_handle:
+ json.dump(json_contents, file_handle, sort_keys=True,
indent=2, separators=(',', ': '))
- shutil.move(temp_filename, filename)
+ shutil.move(temp_filename, filename)
def save_to_well_known_file(credentials, well_known_file=None):
- """Save the provided GoogleCredentials to the well known file.
+ """Save the provided GoogleCredentials to the well known file.
Args:
credentials:
@@ -1343,88 +1339,88 @@ def save_to_well_known_file(credentials, well_known_file=None):
the name of the file where the credentials are to be saved;
this parameter is supposed to be used for testing only
"""
- # TODO(orestica): move this method to tools.py
- # once the argparse import gets fixed (it is not present in Python 2.6)
+ # TODO(orestica): move this method to tools.py
+ # once the argparse import gets fixed (it is not present in Python 2.6)
- if well_known_file is None:
- well_known_file = _get_well_known_file()
+ if well_known_file is None:
+ well_known_file = _get_well_known_file()
- config_dir = os.path.dirname(well_known_file)
- if not os.path.isdir(config_dir):
- raise OSError('Config directory does not exist: %s' % config_dir)
+ config_dir = os.path.dirname(well_known_file)
+ if not os.path.isdir(config_dir):
+ raise OSError('Config directory does not exist: %s' % config_dir)
- credentials_data = credentials.serialization_data
- _save_private_file(well_known_file, credentials_data)
+ credentials_data = credentials.serialization_data
+ _save_private_file(well_known_file, credentials_data)
def _get_environment_variable_file():
- application_default_credential_filename = (
+ application_default_credential_filename = (
os.environ.get(GOOGLE_APPLICATION_CREDENTIALS,
None))
- if application_default_credential_filename:
- if os.path.isfile(application_default_credential_filename):
- return application_default_credential_filename
- else:
- raise ApplicationDefaultCredentialsError(
+ if application_default_credential_filename:
+ if os.path.isfile(application_default_credential_filename):
+ return application_default_credential_filename
+ else:
+ raise ApplicationDefaultCredentialsError(
'File ' + application_default_credential_filename + ' (pointed by ' +
GOOGLE_APPLICATION_CREDENTIALS +
' environment variable) does not exist!')
def _get_well_known_file():
- """Get the well known file produced by command 'gcloud auth login'."""
- # TODO(orestica): Revisit this method once gcloud provides a better way
- # of pinpointing the exact location of the file.
+ """Get the well known file produced by command 'gcloud auth login'."""
+ # TODO(orestica): Revisit this method once gcloud provides a better way
+ # of pinpointing the exact location of the file.
- WELL_KNOWN_CREDENTIALS_FILE = 'application_default_credentials.json'
+ WELL_KNOWN_CREDENTIALS_FILE = 'application_default_credentials.json'
- default_config_dir = os.getenv(_CLOUDSDK_CONFIG_ENV_VAR)
- if default_config_dir is None:
- if os.name == 'nt':
- try:
- default_config_dir = os.path.join(os.environ['APPDATA'],
+ default_config_dir = os.getenv(_CLOUDSDK_CONFIG_ENV_VAR)
+ if default_config_dir is None:
+ if os.name == 'nt':
+ try:
+ default_config_dir = os.path.join(os.environ['APPDATA'],
_CLOUDSDK_CONFIG_DIRECTORY)
- except KeyError:
- # This should never happen unless someone is really messing with things.
- drive = os.environ.get('SystemDrive', 'C:')
- default_config_dir = os.path.join(drive, '\\',
+ except KeyError:
+ # This should never happen unless someone is really messing with things.
+ drive = os.environ.get('SystemDrive', 'C:')
+ default_config_dir = os.path.join(drive, '\\',
_CLOUDSDK_CONFIG_DIRECTORY)
- else:
- default_config_dir = os.path.join(os.path.expanduser('~'),
+ else:
+ default_config_dir = os.path.join(os.path.expanduser('~'),
'.config',
_CLOUDSDK_CONFIG_DIRECTORY)
- return os.path.join(default_config_dir, WELL_KNOWN_CREDENTIALS_FILE)
+ return os.path.join(default_config_dir, WELL_KNOWN_CREDENTIALS_FILE)
def _get_application_default_credential_from_file(filename):
- """Build the Application Default Credentials from file."""
+ """Build the Application Default Credentials from file."""
- from oauth2client import service_account
+ from oauth2client import service_account
- # read the credentials from the file
- with open(filename) as file_obj:
- client_credentials = json.load(file_obj)
+ # read the credentials from the file
+ with open(filename) as file_obj:
+ client_credentials = json.load(file_obj)
- credentials_type = client_credentials.get('type')
- if credentials_type == AUTHORIZED_USER:
- required_fields = set(['client_id', 'client_secret', 'refresh_token'])
- elif credentials_type == SERVICE_ACCOUNT:
- required_fields = set(['client_id', 'client_email', 'private_key_id',
+ credentials_type = client_credentials.get('type')
+ if credentials_type == AUTHORIZED_USER:
+ required_fields = set(['client_id', 'client_secret', 'refresh_token'])
+ elif credentials_type == SERVICE_ACCOUNT:
+ required_fields = set(['client_id', 'client_email', 'private_key_id',
'private_key'])
- else:
- raise ApplicationDefaultCredentialsError(
+ else:
+ raise ApplicationDefaultCredentialsError(
"'type' field should be defined (and have one of the '" +
AUTHORIZED_USER + "' or '" + SERVICE_ACCOUNT + "' values)")
- missing_fields = required_fields.difference(client_credentials.keys())
+ missing_fields = required_fields.difference(client_credentials.keys())
- if missing_fields:
- _raise_exception_for_missing_fields(missing_fields)
+ if missing_fields:
+ _raise_exception_for_missing_fields(missing_fields)
- if client_credentials['type'] == AUTHORIZED_USER:
- return GoogleCredentials(
+ if client_credentials['type'] == AUTHORIZED_USER:
+ return GoogleCredentials(
access_token=None,
client_id=client_credentials['client_id'],
client_secret=client_credentials['client_secret'],
@@ -1432,8 +1428,8 @@ def _get_application_default_credential_from_file(filename):
token_expiry=None,
token_uri=GOOGLE_TOKEN_URI,
user_agent='Python client library')
- else: # client_credentials['type'] == SERVICE_ACCOUNT
- return service_account._ServiceAccountCredentials(
+ else: # client_credentials['type'] == SERVICE_ACCOUNT
+ return service_account._ServiceAccountCredentials(
service_account_id=client_credentials['client_id'],
service_account_email=client_credentials['client_email'],
private_key_id=client_credentials['private_key_id'],
@@ -1442,32 +1438,32 @@ def _get_application_default_credential_from_file(filename):
def _raise_exception_for_missing_fields(missing_fields):
- raise ApplicationDefaultCredentialsError(
+ raise ApplicationDefaultCredentialsError(
'The following field(s) must be defined: ' + ', '.join(missing_fields))
def _raise_exception_for_reading_json(credential_file,
extra_help,
error):
- raise ApplicationDefaultCredentialsError(
- 'An error was encountered while reading json file: '+
+ raise ApplicationDefaultCredentialsError(
+ 'An error was encountered while reading json file: ' +
credential_file + extra_help + ': ' + str(error))
def _get_application_default_credential_GAE():
- from oauth2client.appengine import AppAssertionCredentials
+ from oauth2client.appengine import AppAssertionCredentials
- return AppAssertionCredentials([])
+ return AppAssertionCredentials([])
def _get_application_default_credential_GCE():
- from oauth2client.gce import AppAssertionCredentials
+ from oauth2client.gce import AppAssertionCredentials
- return AppAssertionCredentials([])
+ return AppAssertionCredentials([])
class AssertionCredentials(GoogleCredentials):
- """Abstract Credentials object used for OAuth 2.0 assertion grants.
+ """Abstract Credentials object used for OAuth 2.0 assertion grants.
This credential does not require a flow to instantiate because it
represents a two legged flow, and therefore has all of the required
@@ -1477,12 +1473,12 @@ class AssertionCredentials(GoogleCredentials):
AssertionCredentials objects may be safely pickled and unpickled.
"""
- @util.positional(2)
- def __init__(self, assertion_type, user_agent=None,
+ @util.positional(2)
+ def __init__(self, assertion_type, user_agent=None,
token_uri=GOOGLE_TOKEN_URI,
revoke_uri=GOOGLE_REVOKE_URI,
**unused_kwargs):
- """Constructor for AssertionFlowCredentials.
+ """Constructor for AssertionFlowCredentials.
Args:
assertion_type: string, assertion type that will be declared to the auth
@@ -1492,7 +1488,7 @@ class AssertionCredentials(GoogleCredentials):
defaults to Google's endpoints but any OAuth 2.0 provider can be used.
revoke_uri: string, URI for revoke endpoint.
"""
- super(AssertionCredentials, self).__init__(
+ super(AssertionCredentials, self).__init__(
None,
None,
None,
@@ -1501,47 +1497,47 @@ class AssertionCredentials(GoogleCredentials):
token_uri,
user_agent,
revoke_uri=revoke_uri)
- self.assertion_type = assertion_type
+ self.assertion_type = assertion_type
- def _generate_refresh_request_body(self):
- assertion = self._generate_assertion()
+ def _generate_refresh_request_body(self):
+ assertion = self._generate_assertion()
- body = urllib.parse.urlencode({
+ body = urllib.parse.urlencode({
'assertion': assertion,
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
})
- return body
+ return body
- def _generate_assertion(self):
- """Generate the assertion string that will be used in the access token
+ def _generate_assertion(self):
+ """Generate the assertion string that will be used in the access token
request.
"""
- _abstract()
+ _abstract()
- def _revoke(self, http_request):
- """Revokes the access_token and deletes the store if available.
+ def _revoke(self, http_request):
+ """Revokes the access_token and deletes the store if available.
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the revoke request.
"""
- self._do_revoke(http_request, self.access_token)
+ self._do_revoke(http_request, self.access_token)
def _RequireCryptoOrDie():
- """Ensure we have a crypto library, or throw CryptoUnavailableError.
+ """Ensure we have a crypto library, or throw CryptoUnavailableError.
The oauth2client.crypt module requires either PyCrypto or PyOpenSSL
to be available in order to function, but these are optional
dependencies.
"""
- if not HAS_CRYPTO:
- raise CryptoUnavailableError('No crypto library available')
+ if not HAS_CRYPTO:
+ raise CryptoUnavailableError('No crypto library available')
class SignedJwtAssertionCredentials(AssertionCredentials):
- """Credentials object used for OAuth 2.0 Signed JWT assertion grants.
+ """Credentials object used for OAuth 2.0 Signed JWT assertion grants.
This credential does not require a flow to instantiate because it
represents a two legged flow, and therefore has all of the required
@@ -1552,10 +1548,10 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
AppAssertionCredentials.
"""
- MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
+ MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
- @util.positional(4)
- def __init__(self,
+ @util.positional(4)
+ def __init__(self,
service_account_name,
private_key,
scope,
@@ -1564,7 +1560,7 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
token_uri=GOOGLE_TOKEN_URI,
revoke_uri=GOOGLE_REVOKE_URI,
**kwargs):
- """Constructor for SignedJwtAssertionCredentials.
+ """Constructor for SignedJwtAssertionCredentials.
Args:
service_account_name: string, id for account, usually an email address.
@@ -1583,27 +1579,27 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
Raises:
CryptoUnavailableError if no crypto library is available.
"""
- _RequireCryptoOrDie()
- super(SignedJwtAssertionCredentials, self).__init__(
+ _RequireCryptoOrDie()
+ super(SignedJwtAssertionCredentials, self).__init__(
None,
user_agent=user_agent,
token_uri=token_uri,
revoke_uri=revoke_uri,
)
- self.scope = util.scopes_to_string(scope)
+ self.scope = util.scopes_to_string(scope)
- # Keep base64 encoded so it can be stored in JSON.
- self.private_key = base64.b64encode(private_key)
- self.private_key = _to_bytes(self.private_key, encoding='utf-8')
- self.private_key_password = private_key_password
- self.service_account_name = service_account_name
- self.kwargs = kwargs
+ # Keep base64 encoded so it can be stored in JSON.
+ self.private_key = base64.b64encode(private_key)
+ self.private_key = _to_bytes(self.private_key, encoding='utf-8')
+ self.private_key_password = private_key_password
+ self.service_account_name = service_account_name
+ self.kwargs = kwargs
- @classmethod
- def from_json(cls, s):
- data = json.loads(_from_bytes(s))
- retval = SignedJwtAssertionCredentials(
+ @classmethod
+ def from_json(cls, s):
+ data = json.loads(_from_bytes(s))
+ retval = SignedJwtAssertionCredentials(
data['service_account_name'],
base64.b64decode(data['private_key']),
data['scope'],
@@ -1612,35 +1608,36 @@ class SignedJwtAssertionCredentials(AssertionCredentials):
token_uri=data['token_uri'],
**data['kwargs']
)
- retval.invalid = data['invalid']
- retval.access_token = data['access_token']
- return retval
+ retval.invalid = data['invalid']
+ retval.access_token = data['access_token']
+ return retval
- def _generate_assertion(self):
- """Generate the assertion that will be used in the request."""
- now = int(time.time())
- payload = {
+ def _generate_assertion(self):
+ """Generate the assertion that will be used in the request."""
+ now = int(time.time())
+ payload = {
'aud': self.token_uri,
'scope': self.scope,
'iat': now,
'exp': now + SignedJwtAssertionCredentials.MAX_TOKEN_LIFETIME_SECS,
'iss': self.service_account_name
- }
- payload.update(self.kwargs)
- logger.debug(str(payload))
+ }
+ payload.update(self.kwargs)
+ logger.debug(str(payload))
- private_key = base64.b64decode(self.private_key)
- return crypt.make_signed_jwt(crypt.Signer.from_string(
+ private_key = base64.b64decode(self.private_key)
+ return crypt.make_signed_jwt(crypt.Signer.from_string(
private_key, self.private_key_password), payload)
# Only used in verify_id_token(), which is always calling to the same URI
# for the certs.
_cached_http = httplib2.Http(MemoryCache())
+
@util.positional(2)
def verify_id_token(id_token, audience, http=None,
cert_uri=ID_TOKEN_VERIFICATION_CERTS):
- """Verifies a signed JWT id_token.
+ """Verifies a signed JWT id_token.
This function requires PyOpenSSL and because of that it does not work on
App Engine.
@@ -1660,20 +1657,20 @@ def verify_id_token(id_token, audience, http=None,
oauth2client.crypt.AppIdentityError: if the JWT fails to verify.
CryptoUnavailableError: if no crypto library is available.
"""
- _RequireCryptoOrDie()
- if http is None:
- http = _cached_http
+ _RequireCryptoOrDie()
+ if http is None:
+ http = _cached_http
- resp, content = http.request(cert_uri)
- if resp.status == 200:
- certs = json.loads(_from_bytes(content))
- return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
- else:
- raise VerifyJwtTokenError('Status code: %d' % resp.status)
+ resp, content = http.request(cert_uri)
+ if resp.status == 200:
+ certs = json.loads(_from_bytes(content))
+ return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
+ else:
+ raise VerifyJwtTokenError('Status code: %d' % resp.status)
def _extract_id_token(id_token):
- """Extract the JSON payload from a JWT.
+ """Extract the JSON payload from a JWT.
Does the extraction w/o checking the signature.
@@ -1683,20 +1680,20 @@ def _extract_id_token(id_token):
Returns:
object, The deserialized JSON payload.
"""
- if type(id_token) == bytes:
- segments = id_token.split(b'.')
- else:
- segments = id_token.split(u'.')
+ if type(id_token) == bytes:
+ segments = id_token.split(b'.')
+ else:
+ segments = id_token.split(u'.')
- if len(segments) != 3:
- raise VerifyJwtTokenError(
+ if len(segments) != 3:
+ raise VerifyJwtTokenError(
'Wrong number of segments in token: %s' % id_token)
- return json.loads(_from_bytes(_urlsafe_b64decode(segments[1])))
+ return json.loads(_from_bytes(_urlsafe_b64decode(segments[1])))
def _parse_exchange_token_response(content):
- """Parses response of an exchange token request.
+ """Parses response of an exchange token request.
Most providers return JSON but some (e.g. Facebook) return a
url-encoded string.
@@ -1708,20 +1705,20 @@ def _parse_exchange_token_response(content):
Content as a dictionary object. Note that the dict could be empty,
i.e. {}. That basically indicates a failure.
"""
- resp = {}
- content = _from_bytes(content)
- try:
- resp = json.loads(content)
- except Exception:
- # different JSON libs raise different exceptions,
- # so we just do a catch-all here
- resp = dict(urllib.parse.parse_qsl(content))
+ resp = {}
+ content = _from_bytes(content)
+ try:
+ resp = json.loads(content)
+ except Exception:
+ # different JSON libs raise different exceptions,
+ # so we just do a catch-all here
+ resp = dict(urllib.parse.parse_qsl(content))
- # some providers respond with 'expires', others with 'expires_in'
- if resp and 'expires' in resp:
- resp['expires_in'] = resp.pop('expires')
+ # some providers respond with 'expires', others with 'expires_in'
+ if resp and 'expires' in resp:
+ resp['expires_in'] = resp.pop('expires')
- return resp
+ return resp
@util.positional(4)
@@ -1732,7 +1729,7 @@ def credentials_from_code(client_id, client_secret, scope, code,
revoke_uri=GOOGLE_REVOKE_URI,
device_uri=GOOGLE_DEVICE_URI,
token_info_uri=GOOGLE_TOKEN_INFO_URI):
- """Exchanges an authorization code for an OAuth2Credentials object.
+ """Exchanges an authorization code for an OAuth2Credentials object.
Args:
client_id: string, client identifier.
@@ -1759,24 +1756,24 @@ def credentials_from_code(client_id, client_secret, scope, code,
FlowExchangeError if the authorization code cannot be exchanged for an
access token
"""
- flow = OAuth2WebServerFlow(client_id, client_secret, scope,
+ flow = OAuth2WebServerFlow(client_id, client_secret, scope,
redirect_uri=redirect_uri, user_agent=user_agent,
auth_uri=auth_uri, token_uri=token_uri,
revoke_uri=revoke_uri, device_uri=device_uri,
token_info_uri=token_info_uri)
- credentials = flow.step2_exchange(code, http=http)
- return credentials
+ credentials = flow.step2_exchange(code, http=http)
+ return credentials
@util.positional(3)
def credentials_from_clientsecrets_and_code(filename, scope, code,
- message = None,
+ message=None,
redirect_uri='postmessage',
http=None,
cache=None,
device_uri=None):
- """Returns OAuth2Credentials from a clientsecrets file and an auth code.
+ """Returns OAuth2Credentials from a clientsecrets file and an auth code.
Will create the right kind of Flow based on the contents of the clientsecrets
file or will raise InvalidClientSecretsError for unknown types of Flows.
@@ -1807,58 +1804,59 @@ def credentials_from_clientsecrets_and_code(filename, scope, code,
clientsecrets.InvalidClientSecretsError if the clientsecrets file is
invalid.
"""
- flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache,
+ flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache,
redirect_uri=redirect_uri,
device_uri=device_uri)
- credentials = flow.step2_exchange(code, http=http)
- return credentials
+ credentials = flow.step2_exchange(code, http=http)
+ return credentials
class DeviceFlowInfo(collections.namedtuple('DeviceFlowInfo', (
'device_code', 'user_code', 'interval', 'verification_url',
'user_code_expiry'))):
- """Intermediate information the OAuth2 for devices flow."""
+ """Intermediate information the OAuth2 for devices flow."""
- @classmethod
- def FromResponse(cls, response):
- """Create a DeviceFlowInfo from a server response.
+ @classmethod
+ def FromResponse(cls, response):
+ """Create a DeviceFlowInfo from a server response.
The response should be a dict containing entries as described here:
http://tools.ietf.org/html/draft-ietf-oauth-v2-05#section-3.7.1
"""
- # device_code, user_code, and verification_url are required.
- kwargs = {
+ # device_code, user_code, and verification_url are required.
+ kwargs = {
'device_code': response['device_code'],
'user_code': response['user_code'],
- }
- # The response may list the verification address as either
- # verification_url or verification_uri, so we check for both.
- verification_url = response.get(
+ }
+ # The response may list the verification address as either
+ # verification_url or verification_uri, so we check for both.
+ verification_url = response.get(
'verification_url', response.get('verification_uri'))
- if verification_url is None:
- raise OAuth2DeviceCodeError(
+ if verification_url is None:
+ raise OAuth2DeviceCodeError(
'No verification_url provided in server response')
- kwargs['verification_url'] = verification_url
- # expires_in and interval are optional.
- kwargs.update({
+ kwargs['verification_url'] = verification_url
+ # expires_in and interval are optional.
+ kwargs.update({
'interval': response.get('interval'),
'user_code_expiry': None,
- })
- if 'expires_in' in response:
- kwargs['user_code_expiry'] = datetime.datetime.now() + datetime.timedelta(
+ })
+ if 'expires_in' in response:
+ kwargs['user_code_expiry'] = datetime.datetime.now() + datetime.timedelta(
seconds=int(response['expires_in']))
- return cls(**kwargs)
+ return cls(**kwargs)
+
class OAuth2WebServerFlow(Flow):
- """Does the Web Server Flow for OAuth 2.0.
+ """Does the Web Server Flow for OAuth 2.0.
OAuth2WebServerFlow objects may be safely pickled and unpickled.
"""
- @util.positional(4)
- def __init__(self, client_id,
+ @util.positional(4)
+ def __init__(self, client_id,
client_secret=None,
scope=None,
redirect_uri=None,
@@ -1871,7 +1869,7 @@ class OAuth2WebServerFlow(Flow):
token_info_uri=GOOGLE_TOKEN_INFO_URI,
authorization_header=None,
**kwargs):
- """Constructor for OAuth2WebServerFlow.
+ """Constructor for OAuth2WebServerFlow.
The kwargs argument is used to set extra query parameters on the
auth_uri. For example, the access_type and approval_prompt
@@ -1903,31 +1901,31 @@ class OAuth2WebServerFlow(Flow):
**kwargs: dict, The keyword arguments are all optional and required
parameters for the OAuth calls.
"""
- # scope is a required argument, but to preserve backwards-compatibility
- # we don't want to rearrange the positional arguments
- if scope is None:
- raise TypeError("The value of scope must not be None")
- self.client_id = client_id
- self.client_secret = client_secret
- self.scope = util.scopes_to_string(scope)
- self.redirect_uri = redirect_uri
- self.login_hint = login_hint
- self.user_agent = user_agent
- self.auth_uri = auth_uri
- self.token_uri = token_uri
- self.revoke_uri = revoke_uri
- self.device_uri = device_uri
- self.token_info_uri = token_info_uri
- self.authorization_header = authorization_header
- self.params = {
+ # scope is a required argument, but to preserve backwards-compatibility
+ # we don't want to rearrange the positional arguments
+ if scope is None:
+ raise TypeError("The value of scope must not be None")
+ self.client_id = client_id
+ self.client_secret = client_secret
+ self.scope = util.scopes_to_string(scope)
+ self.redirect_uri = redirect_uri
+ self.login_hint = login_hint
+ self.user_agent = user_agent
+ self.auth_uri = auth_uri
+ self.token_uri = token_uri
+ self.revoke_uri = revoke_uri
+ self.device_uri = device_uri
+ self.token_info_uri = token_info_uri
+ self.authorization_header = authorization_header
+ self.params = {
'access_type': 'offline',
'response_type': 'code',
- }
- self.params.update(kwargs)
+ }
+ self.params.update(kwargs)
- @util.positional(1)
- def step1_get_authorize_url(self, redirect_uri=None, state=None):
- """Returns a URI to redirect to the provider.
+ @util.positional(1)
+ def step1_get_authorize_url(self, redirect_uri=None, state=None):
+ """Returns a URI to redirect to the provider.
Args:
redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for
@@ -1940,78 +1938,78 @@ class OAuth2WebServerFlow(Flow):
Returns:
A URI as a string to redirect the user to begin the authorization flow.
"""
- if redirect_uri is not None:
- logger.warning((
+ if redirect_uri is not None:
+ logger.warning((
'The redirect_uri parameter for '
'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. Please '
'move to passing the redirect_uri in via the constructor.'))
- self.redirect_uri = redirect_uri
+ self.redirect_uri = redirect_uri
- if self.redirect_uri is None:
- raise ValueError('The value of redirect_uri must not be None.')
+ if self.redirect_uri is None:
+ raise ValueError('The value of redirect_uri must not be None.')
- query_params = {
+ query_params = {
'client_id': self.client_id,
'redirect_uri': self.redirect_uri,
'scope': self.scope,
- }
- if state is not None:
- query_params['state'] = state
- if self.login_hint is not None:
- query_params['login_hint'] = self.login_hint
- query_params.update(self.params)
- return _update_query_params(self.auth_uri, query_params)
+ }
+ if state is not None:
+ query_params['state'] = state
+ if self.login_hint is not None:
+ query_params['login_hint'] = self.login_hint
+ query_params.update(self.params)
+ return _update_query_params(self.auth_uri, query_params)
- @util.positional(1)
- def step1_get_device_and_user_codes(self, http=None):
- """Returns a user code and the verification URL where to enter it
+ @util.positional(1)
+ def step1_get_device_and_user_codes(self, http=None):
+ """Returns a user code and the verification URL where to enter it
Returns:
A user code as a string for the user to authorize the application
An URL as a string where the user has to enter the code
"""
- if self.device_uri is None:
- raise ValueError('The value of device_uri must not be None.')
+ if self.device_uri is None:
+ raise ValueError('The value of device_uri must not be None.')
- body = urllib.parse.urlencode({
+ body = urllib.parse.urlencode({
'client_id': self.client_id,
'scope': self.scope,
- })
- headers = {
+ })
+ headers = {
'content-type': 'application/x-www-form-urlencoded',
- }
+ }
- if self.user_agent is not None:
- headers['user-agent'] = self.user_agent
+ if self.user_agent is not None:
+ headers['user-agent'] = self.user_agent
- if http is None:
- http = httplib2.Http()
+ if http is None:
+ http = httplib2.Http()
- resp, content = http.request(self.device_uri, method='POST', body=body,
+ resp, content = http.request(self.device_uri, method='POST', body=body,
headers=headers)
- content = _from_bytes(content)
- if resp.status == 200:
- try:
- flow_info = json.loads(content)
- except ValueError as e:
- raise OAuth2DeviceCodeError(
+ content = _from_bytes(content)
+ if resp.status == 200:
+ try:
+ flow_info = json.loads(content)
+ except ValueError as e:
+ raise OAuth2DeviceCodeError(
'Could not parse server response as JSON: "%s", error: "%s"' % (
content, e))
- return DeviceFlowInfo.FromResponse(flow_info)
- else:
- error_msg = 'Invalid response %s.' % resp.status
- try:
- d = json.loads(content)
- if 'error' in d:
- error_msg += ' Error: %s' % d['error']
- except ValueError:
- # Couldn't decode a JSON response, stick with the default message.
- pass
- raise OAuth2DeviceCodeError(error_msg)
+ return DeviceFlowInfo.FromResponse(flow_info)
+ else:
+ error_msg = 'Invalid response %s.' % resp.status
+ try:
+ d = json.loads(content)
+ if 'error' in d:
+ error_msg += ' Error: %s' % d['error']
+ except ValueError:
+ # Couldn't decode a JSON response, stick with the default message.
+ pass
+ raise OAuth2DeviceCodeError(error_msg)
- @util.positional(2)
+ @util.positional(2)
def step2_exchange(self, code=None, http=None, device_flow_info=None):
- """Exchanges a code for OAuth2Credentials.
+ """Exchanges a code for OAuth2Credentials.
Args:
@@ -2034,64 +2032,64 @@ class OAuth2WebServerFlow(Flow):
missing.
"""
- if code is None and device_flow_info is None:
- raise ValueError('No code or device_flow_info provided.')
- if code is not None and device_flow_info is not None:
- raise ValueError('Cannot provide both code and device_flow_info.')
+ if code is None and device_flow_info is None:
+ raise ValueError('No code or device_flow_info provided.')
+ if code is not None and device_flow_info is not None:
+ raise ValueError('Cannot provide both code and device_flow_info.')
- if code is None:
- code = device_flow_info.device_code
- elif not isinstance(code, six.string_types):
- if 'code' not in code:
- raise FlowExchangeError(code.get(
+ if code is None:
+ code = device_flow_info.device_code
+ elif not isinstance(code, six.string_types):
+ if 'code' not in code:
+ raise FlowExchangeError(code.get(
'error', 'No code was supplied in the query parameters.'))
- code = code['code']
+ code = code['code']
- post_data = {
+ post_data = {
'client_id': self.client_id,
'code': code,
'scope': self.scope,
- }
- if self.client_secret is not None:
- post_data['client_secret'] = self.client_secret
- if device_flow_info is not None:
- post_data['grant_type'] = 'http://oauth.net/grant_type/device/1.0'
- else:
- post_data['grant_type'] = 'authorization_code'
- post_data['redirect_uri'] = self.redirect_uri
- body = urllib.parse.urlencode(post_data)
- headers = {
+ }
+ if self.client_secret is not None:
+ post_data['client_secret'] = self.client_secret
+ if device_flow_info is not None:
+ post_data['grant_type'] = 'http://oauth.net/grant_type/device/1.0'
+ else:
+ post_data['grant_type'] = 'authorization_code'
+ post_data['redirect_uri'] = self.redirect_uri
+ body = urllib.parse.urlencode(post_data)
+ headers = {
'content-type': 'application/x-www-form-urlencoded',
- }
- if self.authorization_header is not None:
- headers['Authorization'] = self.authorization_header
- if self.user_agent is not None:
- headers['user-agent'] = self.user_agent
+ }
+ if self.authorization_header is not None:
+ headers['Authorization'] = self.authorization_header
+ if self.user_agent is not None:
+ headers['user-agent'] = self.user_agent
- if http is None:
- http = httplib2.Http()
+ if http is None:
+ http = httplib2.Http()
- resp, content = http.request(self.token_uri, method='POST', body=body,
+ resp, content = http.request(self.token_uri, method='POST', body=body,
headers=headers)
- d = _parse_exchange_token_response(content)
- if resp.status == 200 and 'access_token' in d:
- access_token = d['access_token']
- refresh_token = d.get('refresh_token', None)
- if not refresh_token:
- logger.info(
+ d = _parse_exchange_token_response(content)
+ if resp.status == 200 and 'access_token' in d:
+ access_token = d['access_token']
+ refresh_token = d.get('refresh_token', None)
+ if not refresh_token:
+ logger.info(
'Received token response with no refresh_token. Consider '
"reauthenticating with approval_prompt='force'.")
- token_expiry = None
- if 'expires_in' in d:
- token_expiry = datetime.datetime.utcnow() + datetime.timedelta(
+ token_expiry = None
+ if 'expires_in' in d:
+ token_expiry = datetime.datetime.utcnow() + datetime.timedelta(
seconds=int(d['expires_in']))
- extracted_id_token = None
- if 'id_token' in d:
- extracted_id_token = _extract_id_token(d['id_token'])
+ extracted_id_token = None
+ if 'id_token' in d:
+ extracted_id_token = _extract_id_token(d['id_token'])
- logger.info('Successfully retrieved access token')
- return OAuth2Credentials(access_token, self.client_id,
+ logger.info('Successfully retrieved access token')
+ return OAuth2Credentials(access_token, self.client_id,
self.client_secret, refresh_token, token_expiry,
self.token_uri, self.user_agent,
revoke_uri=self.revoke_uri,
@@ -2099,21 +2097,21 @@ class OAuth2WebServerFlow(Flow):
token_response=d,
scopes=self.scope,
token_info_uri=self.token_info_uri)
- else:
- logger.info('Failed to retrieve access token: %s', content)
- if 'error' in d:
- # you never know what those providers got to say
- error_msg = str(d['error']) + str(d.get('error_description', ''))
- else:
- error_msg = 'Invalid response: %s.' % str(resp.status)
- raise FlowExchangeError(error_msg)
+ else:
+ logger.info('Failed to retrieve access token: %s', content)
+ if 'error' in d:
+ # you never know what those providers got to say
+ error_msg = str(d['error']) + str(d.get('error_description', ''))
+ else:
+ error_msg = 'Invalid response: %s.' % str(resp.status)
+ raise FlowExchangeError(error_msg)
@util.positional(2)
def flow_from_clientsecrets(filename, scope, redirect_uri=None,
message=None, cache=None, login_hint=None,
device_uri=None):
- """Create a Flow from a clientsecrets file.
+ """Create a Flow from a clientsecrets file.
Will create the right kind of Flow based on the contents of the clientsecrets
file or will raise InvalidClientSecretsError for unknown types of Flows.
@@ -2144,29 +2142,29 @@ def flow_from_clientsecrets(filename, scope, redirect_uri=None,
clientsecrets.InvalidClientSecretsError if the clientsecrets file is
invalid.
"""
- try:
- client_type, client_info = clientsecrets.loadfile(filename, cache=cache)
- if client_type in (clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED):
- constructor_kwargs = {
+ try:
+ client_type, client_info = clientsecrets.loadfile(filename, cache=cache)
+ if client_type in (clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED):
+ constructor_kwargs = {
'redirect_uri': redirect_uri,
'auth_uri': client_info['auth_uri'],
'token_uri': client_info['token_uri'],
'login_hint': login_hint,
- }
- revoke_uri = client_info.get('revoke_uri')
- if revoke_uri is not None:
- constructor_kwargs['revoke_uri'] = revoke_uri
- if device_uri is not None:
- constructor_kwargs['device_uri'] = device_uri
- return OAuth2WebServerFlow(
+ }
+ revoke_uri = client_info.get('revoke_uri')
+ if revoke_uri is not None:
+ constructor_kwargs['revoke_uri'] = revoke_uri
+ if device_uri is not None:
+ constructor_kwargs['device_uri'] = device_uri
+ return OAuth2WebServerFlow(
client_info['client_id'], client_info['client_secret'],
scope, **constructor_kwargs)
- except clientsecrets.InvalidClientSecretsError:
- if message:
- sys.exit(message)
+ except clientsecrets.InvalidClientSecretsError:
+ if message:
+ sys.exit(message)
+ else:
+ raise
else:
- raise
- else:
- raise UnknownClientSecretsFlowError(
+ raise UnknownClientSecretsFlowError(
'This OAuth 2.0 flow is unsupported: %r' % client_type)
diff --git a/oauth2client/clientsecrets.py b/oauth2client/clientsecrets.py
index 08a1702..efa912a 100644
--- a/oauth2client/clientsecrets.py
+++ b/oauth2client/clientsecrets.py
@@ -23,7 +23,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import json
import six
-
# Properties that make a client_secrets.json file valid.
TYPE_WEB = 'web'
TYPE_INSTALLED = 'installed'
@@ -59,65 +58,65 @@ VALID_CLIENT = {
class Error(Exception):
- """Base error for this module."""
- pass
+ """Base error for this module."""
+ pass
class InvalidClientSecretsError(Error):
- """Format of ClientSecrets file is invalid."""
- pass
+ """Format of ClientSecrets file is invalid."""
+ pass
def _validate_clientsecrets(obj):
- _INVALID_FILE_FORMAT_MSG = (
+ _INVALID_FILE_FORMAT_MSG = (
'Invalid file format. See '
'https://developers.google.com/api-client-library/'
'python/guide/aaa_client_secrets')
- if obj is None:
- raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
- if len(obj) != 1:
- raise InvalidClientSecretsError(
+ if obj is None:
+ raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
+ if len(obj) != 1:
+ raise InvalidClientSecretsError(
_INVALID_FILE_FORMAT_MSG + ' '
'Expected a JSON object with a single property for a "web" or '
'"installed" application')
- client_type = tuple(obj)[0]
- if client_type not in VALID_CLIENT:
- raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type,))
- client_info = obj[client_type]
- for prop_name in VALID_CLIENT[client_type]['required']:
- if prop_name not in client_info:
- raise InvalidClientSecretsError(
+ client_type = tuple(obj)[0]
+ if client_type not in VALID_CLIENT:
+ raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type, ))
+ client_info = obj[client_type]
+ for prop_name in VALID_CLIENT[client_type]['required']:
+ if prop_name not in client_info:
+ raise InvalidClientSecretsError(
'Missing property "%s" in a client type of "%s".' % (prop_name,
client_type))
- for prop_name in VALID_CLIENT[client_type]['string']:
- if client_info[prop_name].startswith('[['):
- raise InvalidClientSecretsError(
+ for prop_name in VALID_CLIENT[client_type]['string']:
+ if client_info[prop_name].startswith('[['):
+ raise InvalidClientSecretsError(
'Property "%s" is not configured.' % prop_name)
- return client_type, client_info
+ return client_type, client_info
def load(fp):
- obj = json.load(fp)
- return _validate_clientsecrets(obj)
+ obj = json.load(fp)
+ return _validate_clientsecrets(obj)
def loads(s):
- obj = json.loads(s)
- return _validate_clientsecrets(obj)
+ obj = json.loads(s)
+ return _validate_clientsecrets(obj)
def _loadfile(filename):
- try:
- with open(filename, 'r') as fp:
- obj = json.load(fp)
- except IOError:
- raise InvalidClientSecretsError('File not found: "%s"' % filename)
- return _validate_clientsecrets(obj)
+ try:
+ with open(filename, 'r') as fp:
+ obj = json.load(fp)
+ except IOError:
+ raise InvalidClientSecretsError('File not found: "%s"' % filename)
+ return _validate_clientsecrets(obj)
def loadfile(filename, cache=None):
- """Loading of client_secrets JSON file, optionally backed by a cache.
+ """Loading of client_secrets JSON file, optionally backed by a cache.
Typical cache storage would be App Engine memcache service,
but you can pass in any other cache client that implements
@@ -149,15 +148,15 @@ def loadfile(filename, cache=None):
JSON contents is validated only during first load. Cache hits are not
validated.
"""
- _SECRET_NAMESPACE = 'oauth2client:secrets#ns'
+ _SECRET_NAMESPACE = 'oauth2client:secrets#ns'
- if not cache:
- return _loadfile(filename)
+ if not cache:
+ return _loadfile(filename)
- obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
- if obj is None:
- client_type, client_info = _loadfile(filename)
- obj = {client_type: client_info}
- cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
+ obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
+ if obj is None:
+ client_type, client_info = _loadfile(filename)
+ obj = {client_type: client_info}
+ cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
- return next(six.iteritems(obj))
+ return next(six.iteritems(obj))
diff --git a/oauth2client/crypt.py b/oauth2client/crypt.py
index 75ecfd1..d5e18f4 100644
--- a/oauth2client/crypt.py
+++ b/oauth2client/crypt.py
@@ -25,51 +25,51 @@ from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
from oauth2client._helpers import _urlsafe_b64encode
-
CLOCK_SKEW_SECS = 300 # 5 minutes in seconds
AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds
MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds
-
logger = logging.getLogger(__name__)
class AppIdentityError(Exception):
- pass
+ pass
try:
- from oauth2client._openssl_crypt import OpenSSLVerifier
- from oauth2client._openssl_crypt import OpenSSLSigner
- from oauth2client._openssl_crypt import pkcs12_key_as_pem
+ from oauth2client._openssl_crypt import OpenSSLVerifier
+ from oauth2client._openssl_crypt import OpenSSLSigner
+ from oauth2client._openssl_crypt import pkcs12_key_as_pem
except ImportError:
- OpenSSLVerifier = None
- OpenSSLSigner = None
- def pkcs12_key_as_pem(*args, **kwargs):
- raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
+ OpenSSLVerifier = None
+ OpenSSLSigner = None
+
+
+ def pkcs12_key_as_pem(*args, **kwargs):
+ raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
try:
- from oauth2client._pycrypto_crypt import PyCryptoVerifier
- from oauth2client._pycrypto_crypt import PyCryptoSigner
+ from oauth2client._pycrypto_crypt import PyCryptoVerifier
+ from oauth2client._pycrypto_crypt import PyCryptoSigner
except ImportError:
- PyCryptoVerifier = None
- PyCryptoSigner = None
+ PyCryptoVerifier = None
+ PyCryptoSigner = None
if OpenSSLSigner:
- Signer = OpenSSLSigner
- Verifier = OpenSSLVerifier
+ Signer = OpenSSLSigner
+ Verifier = OpenSSLVerifier
elif PyCryptoSigner:
- Signer = PyCryptoSigner
- Verifier = PyCryptoVerifier
+ Signer = PyCryptoSigner
+ Verifier = PyCryptoVerifier
else:
- raise ImportError('No encryption library found. Please install either '
+ raise ImportError('No encryption library found. Please install either '
'PyOpenSSL, or PyCrypto 2.6 or later')
def make_signed_jwt(signer, payload):
- """Make a signed JWT.
+ """Make a signed JWT.
See http://self-issued.info/docs/draft-jones-json-web-token.html.
@@ -80,24 +80,24 @@ def make_signed_jwt(signer, payload):
Returns:
string, The JWT for the payload.
"""
- header = {'typ': 'JWT', 'alg': 'RS256'}
+ header = {'typ': 'JWT', 'alg': 'RS256'}
- segments = [
+ segments = [
_urlsafe_b64encode(_json_encode(header)),
_urlsafe_b64encode(_json_encode(payload)),
- ]
- signing_input = b'.'.join(segments)
+ ]
+ signing_input = b'.'.join(segments)
- signature = signer.sign(signing_input)
- segments.append(_urlsafe_b64encode(signature))
+ signature = signer.sign(signing_input)
+ segments.append(_urlsafe_b64encode(signature))
- logger.debug(str(segments))
+ logger.debug(str(segments))
- return b'.'.join(segments)
+ return b'.'.join(segments)
def verify_signed_jwt_with_certs(jwt, certs, audience):
- """Verify a JWT against public certs.
+ """Verify a JWT against public certs.
See http://self-issued.info/docs/draft-jones-json-web-token.html.
@@ -113,61 +113,61 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
Raises:
AppIdentityError if any checks are failed.
"""
- jwt = _to_bytes(jwt)
- segments = jwt.split(b'.')
+ jwt = _to_bytes(jwt)
+ segments = jwt.split(b'.')
- if len(segments) != 3:
- raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
- signed = segments[0] + b'.' + segments[1]
+ if len(segments) != 3:
+ raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
+ signed = segments[0] + b'.' + segments[1]
- signature = _urlsafe_b64decode(segments[2])
+ signature = _urlsafe_b64decode(segments[2])
- # Parse token.
- json_body = _urlsafe_b64decode(segments[1])
- try:
- parsed = json.loads(_from_bytes(json_body))
- except:
- raise AppIdentityError('Can\'t parse token: %s' % json_body)
+ # Parse token.
+ json_body = _urlsafe_b64decode(segments[1])
+ try:
+ parsed = json.loads(_from_bytes(json_body))
+ except:
+ raise AppIdentityError('Can\'t parse token: %s' % json_body)
- # Check signature.
- verified = False
- for pem in certs.values():
- verifier = Verifier.from_string(pem, True)
- if verifier.verify(signed, signature):
- verified = True
- break
- if not verified:
- raise AppIdentityError('Invalid token signature: %s' % jwt)
+ # Check signature.
+ verified = False
+ for pem in certs.values():
+ verifier = Verifier.from_string(pem, True)
+ if verifier.verify(signed, signature):
+ verified = True
+ break
+ if not verified:
+ raise AppIdentityError('Invalid token signature: %s' % jwt)
- # Check creation timestamp.
- iat = parsed.get('iat')
- if iat is None:
- raise AppIdentityError('No iat field in token: %s' % json_body)
- earliest = iat - CLOCK_SKEW_SECS
+ # Check creation timestamp.
+ iat = parsed.get('iat')
+ if iat is None:
+ raise AppIdentityError('No iat field in token: %s' % json_body)
+ earliest = iat - CLOCK_SKEW_SECS
- # Check expiration timestamp.
- now = int(time.time())
- exp = parsed.get('exp')
- if exp is None:
- raise AppIdentityError('No exp field in token: %s' % json_body)
- if exp >= now + MAX_TOKEN_LIFETIME_SECS:
- raise AppIdentityError('exp field too far in future: %s' % json_body)
- latest = exp + CLOCK_SKEW_SECS
+ # Check expiration timestamp.
+ now = int(time.time())
+ exp = parsed.get('exp')
+ if exp is None:
+ raise AppIdentityError('No exp field in token: %s' % json_body)
+ if exp >= now + MAX_TOKEN_LIFETIME_SECS:
+ raise AppIdentityError('exp field too far in future: %s' % json_body)
+ latest = exp + CLOCK_SKEW_SECS
- if now < earliest:
- raise AppIdentityError('Token used too early, %d < %d: %s' %
+ if now < earliest:
+ raise AppIdentityError('Token used too early, %d < %d: %s' %
(now, earliest, json_body))
- if now > latest:
- raise AppIdentityError('Token used too late, %d > %d: %s' %
+ if now > latest:
+ raise AppIdentityError('Token used too late, %d > %d: %s' %
(now, latest, json_body))
- # Check audience.
- if audience is not None:
- aud = parsed.get('aud')
- if aud is None:
- raise AppIdentityError('No aud field in token: %s' % json_body)
- if aud != audience:
- raise AppIdentityError('Wrong recipient, %s != %s: %s' %
+ # Check audience.
+ if audience is not None:
+ aud = parsed.get('aud')
+ if aud is None:
+ raise AppIdentityError('No aud field in token: %s' % json_body)
+ if aud != audience:
+ raise AppIdentityError('Wrong recipient, %s != %s: %s' %
(aud, audience, json_body))
- return parsed
+ return parsed
diff --git a/oauth2client/devshell.py b/oauth2client/devshell.py
index 52eb260..8613dc3 100644
--- a/oauth2client/devshell.py
+++ b/oauth2client/devshell.py
@@ -20,22 +20,20 @@ import os
from oauth2client._helpers import _to_bytes
from oauth2client import client
-
DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT'
class Error(Exception):
- """Errors for this module."""
- pass
+ """Errors for this module."""
+ pass
class CommunicationError(Error):
- """Errors for communication with the Developer Shell server."""
+ """Errors for communication with the Developer Shell server."""
class NoDevshellServer(Error):
- """Error when no Developer Shell server can be contacted."""
-
+ """Error when no Developer Shell server can be contacted."""
# The request for credential information to the Developer Shell client socket is
# always an empty PBLite-formatted JSON object, so just define it as a constant.
@@ -43,7 +41,7 @@ CREDENTIAL_INFO_REQUEST_JSON = '[]'
class CredentialInfoResponse(object):
- """Credential information response from Developer Shell server.
+ """Credential information response from Developer Shell server.
The credential information response from Developer Shell socket is a
PBLite-formatted JSON array with fields encoded by their index in the array:
@@ -52,46 +50,46 @@ class CredentialInfoResponse(object):
* Index 2 - OAuth2 access token. None if there is no valid auth context.
"""
- def __init__(self, json_string):
- """Initialize the response data from JSON PBLite array."""
- pbl = json.loads(json_string)
- if not isinstance(pbl, list):
- raise ValueError('Not a list: ' + str(pbl))
- pbl_len = len(pbl)
- self.user_email = pbl[0] if pbl_len > 0 else None
- self.project_id = pbl[1] if pbl_len > 1 else None
- self.access_token = pbl[2] if pbl_len > 2 else None
+ def __init__(self, json_string):
+ """Initialize the response data from JSON PBLite array."""
+ pbl = json.loads(json_string)
+ if not isinstance(pbl, list):
+ raise ValueError('Not a list: ' + str(pbl))
+ pbl_len = len(pbl)
+ self.user_email = pbl[0] if pbl_len > 0 else None
+ self.project_id = pbl[1] if pbl_len > 1 else None
+ self.access_token = pbl[2] if pbl_len > 2 else None
def _SendRecv():
- """Communicate with the Developer Shell server socket."""
+ """Communicate with the Developer Shell server socket."""
- port = int(os.getenv(DEVSHELL_ENV, 0))
- if port == 0:
- raise NoDevshellServer()
+ port = int(os.getenv(DEVSHELL_ENV, 0))
+ if port == 0:
+ raise NoDevshellServer()
- import socket
+ import socket
- sock = socket.socket()
- sock.connect(('localhost', port))
+ sock = socket.socket()
+ sock.connect(('localhost', port))
- data = CREDENTIAL_INFO_REQUEST_JSON
- msg = '%s\n%s' % (len(data), data)
- sock.sendall(_to_bytes(msg, encoding='utf-8'))
+ data = CREDENTIAL_INFO_REQUEST_JSON
+ msg = '%s\n%s' % (len(data), data)
+ sock.sendall(_to_bytes(msg, encoding='utf-8'))
- header = sock.recv(6).decode()
- if '\n' not in header:
- raise CommunicationError('saw no newline in the first 6 bytes')
- len_str, json_str = header.split('\n', 1)
- to_read = int(len_str) - len(json_str)
- if to_read > 0:
- json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
+ header = sock.recv(6).decode()
+ if '\n' not in header:
+ raise CommunicationError('saw no newline in the first 6 bytes')
+ len_str, json_str = header.split('\n', 1)
+ to_read = int(len_str) - len(json_str)
+ if to_read > 0:
+ json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
- return CredentialInfoResponse(json_str)
+ return CredentialInfoResponse(json_str)
class DevshellCredentials(client.GoogleCredentials):
- """Credentials object for Google Developer Shell environment.
+ """Credentials object for Google Developer Shell environment.
This object will allow a Google Developer Shell session to identify its user
to Google and other OAuth 2.0 servers that can verify assertions. It can be
@@ -102,8 +100,8 @@ class DevshellCredentials(client.GoogleCredentials):
generate and refresh its own access tokens.
"""
- def __init__(self, user_agent=None):
- super(DevshellCredentials, self).__init__(
+ def __init__(self, user_agent=None):
+ super(DevshellCredentials, self).__init__(
None, # access_token, initialized below
None, # client_id
None, # client_secret
@@ -111,26 +109,26 @@ class DevshellCredentials(client.GoogleCredentials):
None, # token_expiry
None, # token_uri
user_agent)
- self._refresh(None)
+ self._refresh(None)
- def _refresh(self, http_request):
- self.devshell_response = _SendRecv()
- self.access_token = self.devshell_response.access_token
+ def _refresh(self, http_request):
+ self.devshell_response = _SendRecv()
+ self.access_token = self.devshell_response.access_token
- @property
- def user_email(self):
- return self.devshell_response.user_email
+ @property
+ def user_email(self):
+ return self.devshell_response.user_email
- @property
- def project_id(self):
- return self.devshell_response.project_id
+ @property
+ def project_id(self):
+ return self.devshell_response.project_id
- @classmethod
- def from_json(cls, json_data):
- raise NotImplementedError(
+ @classmethod
+ def from_json(cls, json_data):
+ raise NotImplementedError(
'Cannot load Developer Shell credentials from JSON.')
- @property
- def serialization_data(self):
- raise NotImplementedError(
+ @property
+ def serialization_data(self):
+ raise NotImplementedError(
'Cannot serialize Developer Shell credentials.')
diff --git a/oauth2client/django_orm.py b/oauth2client/django_orm.py
index 65c5d20..afc1e0a 100644
--- a/oauth2client/django_orm.py
+++ b/oauth2client/django_orm.py
@@ -27,58 +27,59 @@ import pickle
from django.db import models
from oauth2client.client import Storage as BaseStorage
+
class CredentialsField(models.Field):
- __metaclass__ = models.SubfieldBase
+ __metaclass__ = models.SubfieldBase
- def __init__(self, *args, **kwargs):
- if 'null' not in kwargs:
- kwargs['null'] = True
- super(CredentialsField, self).__init__(*args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ if 'null' not in kwargs:
+ kwargs['null'] = True
+ super(CredentialsField, self).__init__(*args, **kwargs)
- def get_internal_type(self):
- return "TextField"
+ def get_internal_type(self):
+ return "TextField"
- def to_python(self, value):
- if value is None:
- return None
- if isinstance(value, oauth2client.client.Credentials):
- return value
- return pickle.loads(base64.b64decode(value))
+ def to_python(self, value):
+ if value is None:
+ return None
+ if isinstance(value, oauth2client.client.Credentials):
+ return value
+ return pickle.loads(base64.b64decode(value))
- def get_db_prep_value(self, value, connection, prepared=False):
- if value is None:
- return None
- return base64.b64encode(pickle.dumps(value))
+ def get_db_prep_value(self, value, connection, prepared=False):
+ if value is None:
+ return None
+ return base64.b64encode(pickle.dumps(value))
class FlowField(models.Field):
- __metaclass__ = models.SubfieldBase
+ __metaclass__ = models.SubfieldBase
- def __init__(self, *args, **kwargs):
- if 'null' not in kwargs:
- kwargs['null'] = True
- super(FlowField, self).__init__(*args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ if 'null' not in kwargs:
+ kwargs['null'] = True
+ super(FlowField, self).__init__(*args, **kwargs)
- def get_internal_type(self):
- return "TextField"
+ def get_internal_type(self):
+ return "TextField"
- def to_python(self, value):
- if value is None:
- return None
- if isinstance(value, oauth2client.client.Flow):
- return value
- return pickle.loads(base64.b64decode(value))
+ def to_python(self, value):
+ if value is None:
+ return None
+ if isinstance(value, oauth2client.client.Flow):
+ return value
+ return pickle.loads(base64.b64decode(value))
- def get_db_prep_value(self, value, connection, prepared=False):
- if value is None:
- return None
- return base64.b64encode(pickle.dumps(value))
+ def get_db_prep_value(self, value, connection, prepared=False):
+ if value is None:
+ return None
+ return base64.b64encode(pickle.dumps(value))
class Storage(BaseStorage):
- """Store and retrieve a single credential to and from
+ """Store and retrieve a single credential to and from
the datastore.
This Storage helper presumes the Credentials
@@ -86,8 +87,8 @@ class Storage(BaseStorage):
on a db model class.
"""
- def __init__(self, model_class, key_name, key_value, property_name):
- """Constructor for Storage.
+ def __init__(self, model_class, key_name, key_value, property_name):
+ """Constructor for Storage.
Args:
model: db.Model, model class
@@ -95,47 +96,47 @@ class Storage(BaseStorage):
key_value: string, key value for the entity that has the credentials
property_name: string, name of the property that is an CredentialsProperty
"""
- self.model_class = model_class
- self.key_name = key_name
- self.key_value = key_value
- self.property_name = property_name
+ self.model_class = model_class
+ self.key_name = key_name
+ self.key_value = key_value
+ self.property_name = property_name
- def locked_get(self):
- """Retrieve Credential from datastore.
+ def locked_get(self):
+ """Retrieve Credential from datastore.
Returns:
oauth2client.Credentials
"""
- credential = None
+ credential = None
- query = {self.key_name: self.key_value}
- entities = self.model_class.objects.filter(**query)
- if len(entities) > 0:
- credential = getattr(entities[0], self.property_name)
- if credential and hasattr(credential, 'set_store'):
- credential.set_store(self)
- return credential
+ query = {self.key_name: self.key_value}
+ entities = self.model_class.objects.filter(**query)
+ if len(entities) > 0:
+ credential = getattr(entities[0], self.property_name)
+ if credential and hasattr(credential, 'set_store'):
+ credential.set_store(self)
+ return credential
- def locked_put(self, credentials, overwrite=False):
- """Write a Credentials to the datastore.
+ def locked_put(self, credentials, overwrite=False):
+ """Write a Credentials to the datastore.
Args:
credentials: Credentials, the credentials to store.
overwrite: Boolean, indicates whether you would like these credentials to
overwrite any existing stored credentials.
"""
- args = {self.key_name: self.key_value}
+ args = {self.key_name: self.key_value}
- if overwrite:
- entity, unused_is_new = self.model_class.objects.get_or_create(**args)
- else:
- entity = self.model_class(**args)
+ if overwrite:
+ entity, unused_is_new = self.model_class.objects.get_or_create(**args)
+ else:
+ entity = self.model_class(**args)
- setattr(entity, self.property_name, credentials)
- entity.save()
+ setattr(entity, self.property_name, credentials)
+ entity.save()
- def locked_delete(self):
- """Delete Credentials from the datastore."""
+ def locked_delete(self):
+ """Delete Credentials from the datastore."""
- query = {self.key_name: self.key_value}
- entities = self.model_class.objects.filter(**query).delete()
+ query = {self.key_name: self.key_value}
+ entities = self.model_class.objects.filter(**query).delete()
diff --git a/oauth2client/file.py b/oauth2client/file.py
index 9d0ae7f..1ed6cff 100644
--- a/oauth2client/file.py
+++ b/oauth2client/file.py
@@ -28,37 +28,37 @@ from oauth2client.client import Storage as BaseStorage
class CredentialsFileSymbolicLinkError(Exception):
- """Credentials files must not be symbolic links."""
+ """Credentials files must not be symbolic links."""
class Storage(BaseStorage):
- """Store and retrieve a single credential to and from a file."""
+ """Store and retrieve a single credential to and from a file."""
- def __init__(self, filename):
- self._filename = filename
- self._lock = threading.Lock()
+ def __init__(self, filename):
+ self._filename = filename
+ self._lock = threading.Lock()
- def _validate_file(self):
- if os.path.islink(self._filename):
- raise CredentialsFileSymbolicLinkError(
+ def _validate_file(self):
+ if os.path.islink(self._filename):
+ raise CredentialsFileSymbolicLinkError(
'File: %s is a symbolic link.' % self._filename)
- def acquire_lock(self):
- """Acquires any lock necessary to access this Storage.
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
This lock is not reentrant."""
- self._lock.acquire()
+ self._lock.acquire()
- def release_lock(self):
- """Release the Storage lock.
+ def release_lock(self):
+ """Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
- self._lock.release()
+ self._lock.release()
- def locked_get(self):
- """Retrieve Credential from file.
+ def locked_get(self):
+ """Retrieve Credential from file.
Returns:
oauth2client.client.Credentials
@@ -66,38 +66,38 @@ class Storage(BaseStorage):
Raises:
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
- credentials = None
- self._validate_file()
- try:
- f = open(self._filename, 'rb')
- content = f.read()
- f.close()
- except IOError:
- return credentials
+ credentials = None
+ self._validate_file()
+ try:
+ f = open(self._filename, 'rb')
+ content = f.read()
+ f.close()
+ except IOError:
+ return credentials
- try:
- credentials = Credentials.new_from_json(content)
- credentials.set_store(self)
- except ValueError:
- pass
+ try:
+ credentials = Credentials.new_from_json(content)
+ credentials.set_store(self)
+ except ValueError:
+ pass
- return credentials
+ return credentials
- def _create_file_if_needed(self):
- """Create an empty file if necessary.
+ def _create_file_if_needed(self):
+ """Create an empty file if necessary.
This method will not initialize the file. Instead it implements a
simple version of "touch" to ensure the file has been created.
"""
- if not os.path.exists(self._filename):
- old_umask = os.umask(0o177)
- try:
- open(self._filename, 'a+b').close()
- finally:
- os.umask(old_umask)
+ if not os.path.exists(self._filename):
+ old_umask = os.umask(0o177)
+ try:
+ open(self._filename, 'a+b').close()
+ finally:
+ os.umask(old_umask)
- def locked_put(self, credentials):
- """Write Credentials to file.
+ def locked_put(self, credentials):
+ """Write Credentials to file.
Args:
credentials: Credentials, the credentials to store.
@@ -106,17 +106,17 @@ class Storage(BaseStorage):
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
- self._create_file_if_needed()
- self._validate_file()
- f = open(self._filename, 'w')
- f.write(credentials.to_json())
- f.close()
+ self._create_file_if_needed()
+ self._validate_file()
+ f = open(self._filename, 'w')
+ f.write(credentials.to_json())
+ f.close()
- def locked_delete(self):
- """Delete Credentials file.
+ def locked_delete(self):
+ """Delete Credentials file.
Args:
credentials: Credentials, the credentials to store.
"""
- os.unlink(self._filename)
+ os.unlink(self._filename)
diff --git a/oauth2client/flask_util.py b/oauth2client/flask_util.py
index b3e7002..2771f47 100644
--- a/oauth2client/flask_util.py
+++ b/oauth2client/flask_util.py
@@ -189,8 +189,7 @@ from oauth2client.client import Storage
from oauth2client import clientsecrets
from oauth2client import util
-
-DEFAULT_SCOPES = ('email',)
+DEFAULT_SCOPES = ('email', )
class UserOAuth2(object):
@@ -461,6 +460,7 @@ class UserOAuth2(object):
be redirected to the authorization flow. Once complete, the user will
be redirected back to the original page.
"""
+
def curry_wrapper(wrapped_function):
@wraps(wrapped_function)
def required_wrapper(*args, **kwargs):
@@ -519,6 +519,7 @@ class FlaskSessionStorage(Storage):
credentials. We strongly recommend using a server-side session
implementation.
"""
+
def locked_get(self):
serialized = session.get('google_oauth2_credentials')
diff --git a/oauth2client/gce.py b/oauth2client/gce.py
index e4729d1..08bdecb 100644
--- a/oauth2client/gce.py
+++ b/oauth2client/gce.py
@@ -36,7 +36,7 @@ META = ('http://metadata.google.internal/0.1/meta-data/service-accounts/'
class AppAssertionCredentials(AssertionCredentials):
- """Credentials object for Compute Engine Assertion Grants
+ """Credentials object for Compute Engine Assertion Grants
This object will allow a Compute Engine instance to identify itself to
Google and other OAuth 2.0 servers that can verify assertions. It can be used
@@ -48,27 +48,27 @@ class AppAssertionCredentials(AssertionCredentials):
generate and refresh its own access tokens.
"""
- @util.positional(2)
- def __init__(self, scope, **kwargs):
- """Constructor for AppAssertionCredentials
+ @util.positional(2)
+ def __init__(self, scope, **kwargs):
+ """Constructor for AppAssertionCredentials
Args:
scope: string or iterable of strings, scope(s) of the credentials being
requested.
"""
- self.scope = util.scopes_to_string(scope)
- self.kwargs = kwargs
+ self.scope = util.scopes_to_string(scope)
+ self.kwargs = kwargs
- # Assertion type is no longer used, but still in the parent class signature.
- super(AppAssertionCredentials, self).__init__(None)
+ # Assertion type is no longer used, but still in the parent class signature.
+ super(AppAssertionCredentials, self).__init__(None)
- @classmethod
- def from_json(cls, json_data):
- data = json.loads(_from_bytes(json_data))
- return AppAssertionCredentials(data['scope'])
+ @classmethod
+ def from_json(cls, json_data):
+ data = json.loads(_from_bytes(json_data))
+ return AppAssertionCredentials(data['scope'])
- def _refresh(self, http_request):
- """Refreshes the access_token.
+ def _refresh(self, http_request):
+ """Refreshes the access_token.
Skip all the storage hoops and just refresh using the API.
@@ -79,29 +79,29 @@ class AppAssertionCredentials(AssertionCredentials):
Raises:
AccessTokenRefreshError: When the refresh fails.
"""
- query = '?scope=%s' % urllib.parse.quote(self.scope, '')
- uri = META.replace('{?scope}', query)
- response, content = http_request(uri)
- content = _from_bytes(content)
- if response.status == 200:
- try:
- d = json.loads(content)
- except Exception as e:
- raise AccessTokenRefreshError(str(e))
- self.access_token = d['accessToken']
- else:
- if response.status == 404:
- content += (' This can occur if a VM was created'
+ query = '?scope=%s' % urllib.parse.quote(self.scope, '')
+ uri = META.replace('{?scope}', query)
+ response, content = http_request(uri)
+ content = _from_bytes(content)
+ if response.status == 200:
+ try:
+ d = json.loads(content)
+ except Exception as e:
+ raise AccessTokenRefreshError(str(e))
+ self.access_token = d['accessToken']
+ else:
+ if response.status == 404:
+ content += (' This can occur if a VM was created'
' with no service account or scopes.')
- raise AccessTokenRefreshError(content)
+ raise AccessTokenRefreshError(content)
- @property
+ @property
def serialization_data(self):
- raise NotImplementedError(
+ raise NotImplementedError(
'Cannot serialize credentials for GCE service accounts.')
- def create_scoped_required(self):
- return not self.scope
+ def create_scoped_required(self):
+ return not self.scope
- def create_scoped(self, scopes):
- return AppAssertionCredentials(scopes, **self.kwargs)
+ def create_scoped(self, scopes):
+ return AppAssertionCredentials(scopes, **self.kwargs)
diff --git a/oauth2client/keyring_storage.py b/oauth2client/keyring_storage.py
index cda1d9a..d2b9a9a 100644
--- a/oauth2client/keyring_storage.py
+++ b/oauth2client/keyring_storage.py
@@ -28,7 +28,7 @@ from oauth2client.client import Storage as BaseStorage
class Storage(BaseStorage):
- """Store and retrieve a single credential to and from the keyring.
+ """Store and retrieve a single credential to and from the keyring.
To use this module you must have the keyring module installed. See
. This is an optional module and is not
@@ -48,63 +48,63 @@ class Storage(BaseStorage):
"""
- def __init__(self, service_name, user_name):
- """Constructor.
+ def __init__(self, service_name, user_name):
+ """Constructor.
Args:
service_name: string, The name of the service under which the credentials
are stored.
user_name: string, The name of the user to store credentials for.
"""
- self._service_name = service_name
- self._user_name = user_name
- self._lock = threading.Lock()
+ self._service_name = service_name
+ self._user_name = user_name
+ self._lock = threading.Lock()
- def acquire_lock(self):
- """Acquires any lock necessary to access this Storage.
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
This lock is not reentrant."""
- self._lock.acquire()
+ self._lock.acquire()
- def release_lock(self):
- """Release the Storage lock.
+ def release_lock(self):
+ """Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
- self._lock.release()
+ self._lock.release()
- def locked_get(self):
- """Retrieve Credential from file.
+ def locked_get(self):
+ """Retrieve Credential from file.
Returns:
oauth2client.client.Credentials
"""
- credentials = None
- content = keyring.get_password(self._service_name, self._user_name)
+ credentials = None
+ content = keyring.get_password(self._service_name, self._user_name)
- if content is not None:
- try:
- credentials = Credentials.new_from_json(content)
- credentials.set_store(self)
- except ValueError:
- pass
+ if content is not None:
+ try:
+ credentials = Credentials.new_from_json(content)
+ credentials.set_store(self)
+ except ValueError:
+ pass
- return credentials
+ return credentials
- def locked_put(self, credentials):
- """Write Credentials to file.
+ def locked_put(self, credentials):
+ """Write Credentials to file.
Args:
credentials: Credentials, the credentials to store.
"""
- keyring.set_password(self._service_name, self._user_name,
+ keyring.set_password(self._service_name, self._user_name,
credentials.to_json())
- def locked_delete(self):
- """Delete Credentials file.
+ def locked_delete(self):
+ """Delete Credentials file.
Args:
credentials: Credentials, the credentials to store.
"""
- keyring.set_password(self._service_name, self._user_name, '')
+ keyring.set_password(self._service_name, self._user_name, '')
diff --git a/oauth2client/locked_file.py b/oauth2client/locked_file.py
index af92398..cce3359 100644
--- a/oauth2client/locked_file.py
+++ b/oauth2client/locked_file.py
@@ -45,68 +45,69 @@ logger = logging.getLogger(__name__)
class CredentialsFileSymbolicLinkError(Exception):
- """Credentials files must not be symbolic links."""
+ """Credentials files must not be symbolic links."""
class AlreadyLockedException(Exception):
- """Trying to lock a file that has already been locked by the LockedFile."""
- pass
+ """Trying to lock a file that has already been locked by the LockedFile."""
+ pass
def validate_file(filename):
- if os.path.islink(filename):
- raise CredentialsFileSymbolicLinkError(
+ if os.path.islink(filename):
+ raise CredentialsFileSymbolicLinkError(
'File: %s is a symbolic link.' % filename)
-class _Opener(object):
- """Base class for different locking primitives."""
- def __init__(self, filename, mode, fallback_mode):
- """Create an Opener.
+class _Opener(object):
+ """Base class for different locking primitives."""
+
+ def __init__(self, filename, mode, fallback_mode):
+ """Create an Opener.
Args:
filename: string, The pathname of the file.
mode: string, The preferred mode to access the file with.
fallback_mode: string, The mode to use if locking fails.
"""
- self._locked = False
- self._filename = filename
- self._mode = mode
- self._fallback_mode = fallback_mode
- self._fh = None
- self._lock_fd = None
+ self._locked = False
+ self._filename = filename
+ self._mode = mode
+ self._fallback_mode = fallback_mode
+ self._fh = None
+ self._lock_fd = None
- def is_locked(self):
- """Was the file locked."""
- return self._locked
+ def is_locked(self):
+ """Was the file locked."""
+ return self._locked
- def file_handle(self):
- """The file handle to the file. Valid only after opened."""
- return self._fh
+ def file_handle(self):
+ """The file handle to the file. Valid only after opened."""
+ return self._fh
- def filename(self):
- """The filename that is being locked."""
- return self._filename
+ def filename(self):
+ """The filename that is being locked."""
+ return self._filename
- def open_and_lock(self, timeout, delay):
- """Open the file and lock it.
+ def open_and_lock(self, timeout, delay):
+ """Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
delay: float, How long to wait between retries.
"""
- pass
+ pass
- def unlock_and_close(self):
- """Unlock and close the file."""
- pass
+ def unlock_and_close(self):
+ """Unlock and close the file."""
+ pass
class _PosixOpener(_Opener):
- """Lock files using Posix advisory lock files."""
+ """Lock files using Posix advisory lock files."""
- def open_and_lock(self, timeout, delay):
- """Open the file and lock it.
+ def open_and_lock(self, timeout, delay):
+ """Open the file and lock it.
Tries to create a .lock file next to the file we're trying to open.
@@ -119,141 +120,67 @@ class _PosixOpener(_Opener):
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
- if self._locked:
- raise AlreadyLockedException('File %s is already locked' %
+ if self._locked:
+ raise AlreadyLockedException('File %s is already locked' %
self._filename)
- self._locked = False
+ self._locked = False
- validate_file(self._filename)
- try:
- self._fh = open(self._filename, self._mode)
- except IOError as e:
- # If we can't access with _mode, try _fallback_mode and don't lock.
- if e.errno == errno.EACCES:
- self._fh = open(self._filename, self._fallback_mode)
- return
+ validate_file(self._filename)
+ try:
+ self._fh = open(self._filename, self._mode)
+ except IOError as e:
+ # If we can't access with _mode, try _fallback_mode and don't lock.
+ if e.errno == errno.EACCES:
+ self._fh = open(self._filename, self._fallback_mode)
+ return
- lock_filename = self._posix_lockfile(self._filename)
+ lock_filename = self._posix_lockfile(self._filename)
start_time = time.time()
while True:
- try:
- self._lock_fd = os.open(lock_filename,
- os.O_CREAT|os.O_EXCL|os.O_RDWR)
- self._locked = True
- break
+ try:
+ self._lock_fd = os.open(lock_filename,
+ os.O_CREAT | os.O_EXCL | os.O_RDWR)
+ self._locked = True
+ break
- except OSError as e:
- if e.errno != errno.EEXIST:
- raise
- if (time.time() - start_time) >= timeout:
- logger.warn('Could not acquire lock %s in %s seconds',
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+ if (time.time() - start_time) >= timeout:
+ logger.warn('Could not acquire lock %s in %s seconds',
lock_filename, timeout)
- # Close the file and open in fallback_mode.
- if self._fh:
- self._fh.close()
- self._fh = open(self._filename, self._fallback_mode)
- return
- time.sleep(delay)
-
- def unlock_and_close(self):
- """Unlock a file by removing the .lock file, and close the handle."""
- if self._locked:
- lock_filename = self._posix_lockfile(self._filename)
- os.close(self._lock_fd)
- os.unlink(lock_filename)
- self._locked = False
- self._lock_fd = None
- if self._fh:
- self._fh.close()
-
- def _posix_lockfile(self, filename):
- """The name of the lock file to use for posix locking."""
- return '%s.lock' % filename
-
-
-try:
- import fcntl
-
- class _FcntlOpener(_Opener):
- """Open, lock, and unlock a file using fcntl.lockf."""
-
- def open_and_lock(self, timeout, delay):
- """Open the file and lock it.
-
- Args:
- timeout: float, How long to try to lock for.
- delay: float, How long to wait between retries
-
- Raises:
- AlreadyLockedException: if the lock is already acquired.
- IOError: if the open fails.
- CredentialsFileSymbolicLinkError if the file is a symbolic link.
- """
- if self._locked:
- raise AlreadyLockedException('File %s is already locked' %
- self._filename)
- start_time = time.time()
-
- validate_file(self._filename)
- try:
- self._fh = open(self._filename, self._mode)
- except IOError as e:
- # If we can't access with _mode, try _fallback_mode and don't lock.
- if e.errno in (errno.EPERM, errno.EACCES):
- self._fh = open(self._filename, self._fallback_mode)
- return
-
- # We opened in _mode, try to lock the file.
- while True:
- try:
- fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
- self._locked = True
- return
- except IOError as e:
- # If not retrying, then just pass on the error.
- if timeout == 0:
- raise
- if e.errno != errno.EACCES:
- raise
- # We could not acquire the lock. Try again.
- if (time.time() - start_time) >= timeout:
- logger.warn('Could not lock %s in %s seconds',
- self._filename, timeout)
- if self._fh:
- self._fh.close()
- self._fh = open(self._filename, self._fallback_mode)
- return
- time.sleep(delay)
+ # Close the file and open in fallback_mode.
+ if self._fh:
+ self._fh.close()
+ self._fh = open(self._filename, self._fallback_mode)
+ return
+ time.sleep(delay)
def unlock_and_close(self):
- """Close and unlock the file using the fcntl.lockf primitive."""
- if self._locked:
- fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
- self._locked = False
- if self._fh:
- self._fh.close()
-except ImportError:
- _FcntlOpener = None
+ """Unlock a file by removing the .lock file, and close the handle."""
+ if self._locked:
+ lock_filename = self._posix_lockfile(self._filename)
+ os.close(self._lock_fd)
+ os.unlink(lock_filename)
+ self._locked = False
+ self._lock_fd = None
+ if self._fh:
+ self._fh.close()
+
+ def _posix_lockfile(self, filename):
+ """The name of the lock file to use for posix locking."""
+ return '%s.lock' % filename
try:
- import pywintypes
- import win32con
- import win32file
+ import fcntl
- class _Win32Opener(_Opener):
- """Open, lock, and unlock a file using windows primitives."""
- # Error #33:
- # 'The process cannot access the file because another process'
- FILE_IN_USE_ERROR = 33
+ class _FcntlOpener(_Opener):
+ """Open, lock, and unlock a file using fcntl.lockf."""
- # Error #158:
- # 'The segment is already unlocked.'
- FILE_ALREADY_UNLOCKED_ERROR = 158
-
- def open_and_lock(self, timeout, delay):
- """Open the file and lock it.
+ def open_and_lock(self, timeout, delay):
+ """Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
@@ -264,71 +191,147 @@ try:
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
- if self._locked:
- raise AlreadyLockedException('File %s is already locked' %
+ if self._locked:
+ raise AlreadyLockedException('File %s is already locked' %
self._filename)
- start_time = time.time()
+ start_time = time.time()
- validate_file(self._filename)
- try:
- self._fh = open(self._filename, self._mode)
- except IOError as e:
- # If we can't access with _mode, try _fallback_mode and don't lock.
- if e.errno == errno.EACCES:
- self._fh = open(self._filename, self._fallback_mode)
- return
+ validate_file(self._filename)
+ try:
+ self._fh = open(self._filename, self._mode)
+ except IOError as e:
+ # If we can't access with _mode, try _fallback_mode and don't lock.
+ if e.errno in (errno.EPERM, errno.EACCES):
+ self._fh = open(self._filename, self._fallback_mode)
+ return
- # We opened in _mode, try to lock the file.
- while True:
- try:
- hfile = win32file._get_osfhandle(self._fh.fileno())
- win32file.LockFileEx(
+ # We opened in _mode, try to lock the file.
+ while True:
+ try:
+ fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
+ self._locked = True
+ return
+ except IOError as e:
+ # If not retrying, then just pass on the error.
+ if timeout == 0:
+ raise
+ if e.errno != errno.EACCES:
+ raise
+ # We could not acquire the lock. Try again.
+ if (time.time() - start_time) >= timeout:
+ logger.warn('Could not lock %s in %s seconds',
+ self._filename, timeout)
+ if self._fh:
+ self._fh.close()
+ self._fh = open(self._filename, self._fallback_mode)
+ return
+ time.sleep(delay)
+
+ def unlock_and_close(self):
+ """Close and unlock the file using the fcntl.lockf primitive."""
+ if self._locked:
+ fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
+ self._locked = False
+ if self._fh:
+ self._fh.close()
+except ImportError:
+ _FcntlOpener = None
+
+
+try:
+ import pywintypes
+ import win32con
+ import win32file
+
+
+ class _Win32Opener(_Opener):
+ """Open, lock, and unlock a file using windows primitives."""
+
+ # Error #33:
+ # 'The process cannot access the file because another process'
+ FILE_IN_USE_ERROR = 33
+
+ # Error #158:
+ # 'The segment is already unlocked.'
+ FILE_ALREADY_UNLOCKED_ERROR = 158
+
+ def open_and_lock(self, timeout, delay):
+ """Open the file and lock it.
+
+ Args:
+ timeout: float, How long to try to lock for.
+ delay: float, How long to wait between retries
+
+ Raises:
+ AlreadyLockedException: if the lock is already acquired.
+ IOError: if the open fails.
+ CredentialsFileSymbolicLinkError if the file is a symbolic link.
+ """
+ if self._locked:
+ raise AlreadyLockedException('File %s is already locked' %
+ self._filename)
+ start_time = time.time()
+
+ validate_file(self._filename)
+ try:
+ self._fh = open(self._filename, self._mode)
+ except IOError as e:
+ # If we can't access with _mode, try _fallback_mode and don't lock.
+ if e.errno == errno.EACCES:
+ self._fh = open(self._filename, self._fallback_mode)
+ return
+
+ # We opened in _mode, try to lock the file.
+ while True:
+ try:
+ hfile = win32file._get_osfhandle(self._fh.fileno())
+ win32file.LockFileEx(
hfile,
- (win32con.LOCKFILE_FAIL_IMMEDIATELY|
+ (win32con.LOCKFILE_FAIL_IMMEDIATELY |
win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000,
pywintypes.OVERLAPPED())
- self._locked = True
- return
- except pywintypes.error as e:
- if timeout == 0:
- raise
+ self._locked = True
+ return
+ except pywintypes.error as e:
+ if timeout == 0:
+ raise
- # If the error is not that the file is already in use, raise.
- if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
- raise
+ # If the error is not that the file is already in use, raise.
+ if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
+ raise
- # We could not acquire the lock. Try again.
- if (time.time() - start_time) >= timeout:
- logger.warn('Could not lock %s in %s seconds' % (
+ # We could not acquire the lock. Try again.
+ if (time.time() - start_time) >= timeout:
+ logger.warn('Could not lock %s in %s seconds' % (
self._filename, timeout))
- if self._fh:
- self._fh.close()
- self._fh = open(self._filename, self._fallback_mode)
- return
- time.sleep(delay)
+ if self._fh:
+ self._fh.close()
+ self._fh = open(self._filename, self._fallback_mode)
+ return
+ time.sleep(delay)
- def unlock_and_close(self):
- """Close and unlock the file using the win32 primitive."""
- if self._locked:
- try:
- hfile = win32file._get_osfhandle(self._fh.fileno())
- win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
- except pywintypes.error as e:
- if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
- raise
- self._locked = False
+ def unlock_and_close(self):
+ """Close and unlock the file using the win32 primitive."""
+ if self._locked:
+ try:
+ hfile = win32file._get_osfhandle(self._fh.fileno())
+ win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
+ except pywintypes.error as e:
+ if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
+ raise
+ self._locked = False
if self._fh:
- self._fh.close()
+ self._fh.close()
except ImportError:
- _Win32Opener = None
+ _Win32Opener = None
class LockedFile(object):
- """Represent a file that has exclusive access."""
+ """Represent a file that has exclusive access."""
- @util.positional(4)
- def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
- """Construct a LockedFile.
+ @util.positional(4)
+ def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
+ """Construct a LockedFile.
Args:
filename: string, The path of the file to open.
@@ -336,32 +339,32 @@ class LockedFile(object):
fallback_mode: string, The mode to use if locking fails.
use_native_locking: bool, Whether or not fcntl/win32 locking is used.
"""
- opener = None
- if not opener and use_native_locking:
- if _Win32Opener:
- opener = _Win32Opener(filename, mode, fallback_mode)
- if _FcntlOpener:
- opener = _FcntlOpener(filename, mode, fallback_mode)
+ opener = None
+ if not opener and use_native_locking:
+ if _Win32Opener:
+ opener = _Win32Opener(filename, mode, fallback_mode)
+ if _FcntlOpener:
+ opener = _FcntlOpener(filename, mode, fallback_mode)
- if not opener:
- opener = _PosixOpener(filename, mode, fallback_mode)
+ if not opener:
+ opener = _PosixOpener(filename, mode, fallback_mode)
- self._opener = opener
+ self._opener = opener
- def filename(self):
- """Return the filename we were constructed with."""
- return self._opener._filename
+ def filename(self):
+ """Return the filename we were constructed with."""
+ return self._opener._filename
- def file_handle(self):
- """Return the file_handle to the opened file."""
- return self._opener.file_handle()
+ def file_handle(self):
+ """Return the file_handle to the opened file."""
+ return self._opener.file_handle()
- def is_locked(self):
- """Return whether we successfully locked the file."""
- return self._opener.is_locked()
+ def is_locked(self):
+ """Return whether we successfully locked the file."""
+ return self._opener.is_locked()
- def open_and_lock(self, timeout=0, delay=0.05):
- """Open the file, trying to lock it.
+ def open_and_lock(self, timeout=0, delay=0.05):
+ """Open the file, trying to lock it.
Args:
timeout: float, The number of seconds to try to acquire the lock.
@@ -371,8 +374,8 @@ class LockedFile(object):
AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails.
"""
- self._opener.open_and_lock(timeout, delay)
+ self._opener.open_and_lock(timeout, delay)
- def unlock_and_close(self):
- """Unlock and close a file."""
- self._opener.unlock_and_close()
+ def unlock_and_close(self):
+ """Unlock and close a file."""
+ self._opener.unlock_and_close()
diff --git a/oauth2client/multistore_file.py b/oauth2client/multistore_file.py
index f4ba4a7..ffb9115 100644
--- a/oauth2client/multistore_file.py
+++ b/oauth2client/multistore_file.py
@@ -65,17 +65,17 @@ _multistores_lock = threading.Lock()
class Error(Exception):
- """Base error for this module."""
+ """Base error for this module."""
class NewerCredentialStoreError(Error):
- """The credential store is a newer version than supported."""
+ """The credential store is a newer version than supported."""
@util.positional(4)
def get_credential_storage(filename, client_id, user_agent, scope,
warn_on_readonly=True):
- """Get a Storage instance for a credential.
+ """Get a Storage instance for a credential.
Args:
filename: The JSON file storing a set of credentials
@@ -88,17 +88,17 @@ def get_credential_storage(filename, client_id, user_agent, scope,
An object derived from client.Storage for getting/setting the
credential.
"""
- # Recreate the legacy key with these specific parameters
- key = {'clientId': client_id, 'userAgent': user_agent,
+ # Recreate the legacy key with these specific parameters
+ key = {'clientId': client_id, 'userAgent': user_agent,
'scope': util.scopes_to_string(scope)}
- return get_credential_storage_custom_key(
+ return get_credential_storage_custom_key(
filename, key, warn_on_readonly=warn_on_readonly)
@util.positional(2)
def get_credential_storage_custom_string_key(
filename, key_string, warn_on_readonly=True):
- """Get a Storage instance for a credential using a single string as a key.
+ """Get a Storage instance for a credential using a single string as a key.
Allows you to provide a string as a custom key that will be used for
credential storage and retrieval.
@@ -112,16 +112,16 @@ def get_credential_storage_custom_string_key(
An object derived from client.Storage for getting/setting the
credential.
"""
- # Create a key dictionary that can be used
- key_dict = {'key': key_string}
- return get_credential_storage_custom_key(
+ # Create a key dictionary that can be used
+ key_dict = {'key': key_string}
+ return get_credential_storage_custom_key(
filename, key_dict, warn_on_readonly=warn_on_readonly)
@util.positional(2)
def get_credential_storage_custom_key(
filename, key_dict, warn_on_readonly=True):
- """Get a Storage instance for a credential using a dictionary as a key.
+ """Get a Storage instance for a credential using a dictionary as a key.
Allows you to provide a dictionary as a custom key that will be used for
credential storage and retrieval.
@@ -137,14 +137,14 @@ def get_credential_storage_custom_key(
An object derived from client.Storage for getting/setting the
credential.
"""
- multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
- key = util.dict_to_tuple_key(key_dict)
- return multistore._get_storage(key)
+ multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
+ key = util.dict_to_tuple_key(key_dict)
+ return multistore._get_storage(key)
@util.positional(1)
def get_all_credential_keys(filename, warn_on_readonly=True):
- """Gets all the registered credential keys in the given Multistore.
+ """Gets all the registered credential keys in the given Multistore.
Args:
filename: The JSON file storing a set of credentials
@@ -155,17 +155,17 @@ def get_all_credential_keys(filename, warn_on_readonly=True):
dictionaries that can be passed into get_credential_storage_custom_key to
get the actual credentials.
"""
- multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
- multistore._lock()
- try:
- return multistore._get_all_credential_keys()
- finally:
- multistore._unlock()
+ multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
+ multistore._lock()
+ try:
+ return multistore._get_all_credential_keys()
+ finally:
+ multistore._unlock()
@util.positional(1)
def _get_multistore(filename, warn_on_readonly=True):
- """A helper method to initialize the multistore with proper locking.
+ """A helper method to initialize the multistore with proper locking.
Args:
filename: The JSON file storing a set of credentials
@@ -174,176 +174,176 @@ def _get_multistore(filename, warn_on_readonly=True):
Returns:
A multistore object
"""
- filename = os.path.expanduser(filename)
- _multistores_lock.acquire()
- try:
- multistore = _multistores.setdefault(
+ filename = os.path.expanduser(filename)
+ _multistores_lock.acquire()
+ try:
+ multistore = _multistores.setdefault(
filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly))
- finally:
- _multistores_lock.release()
- return multistore
+ finally:
+ _multistores_lock.release()
+ return multistore
class _MultiStore(object):
- """A file backed store for multiple credentials."""
+ """A file backed store for multiple credentials."""
- @util.positional(2)
- def __init__(self, filename, warn_on_readonly=True):
- """Initialize the class.
+ @util.positional(2)
+ def __init__(self, filename, warn_on_readonly=True):
+ """Initialize the class.
This will create the file if necessary.
"""
- self._file = LockedFile(filename, 'r+', 'r')
- self._thread_lock = threading.Lock()
- self._read_only = False
- self._warn_on_readonly = warn_on_readonly
+ self._file = LockedFile(filename, 'r+', 'r')
+ self._thread_lock = threading.Lock()
+ self._read_only = False
+ self._warn_on_readonly = warn_on_readonly
- self._create_file_if_needed()
+ self._create_file_if_needed()
- # Cache of deserialized store. This is only valid after the
- # _MultiStore is locked or _refresh_data_cache is called. This is
- # of the form of:
- #
- # ((key, value), (key, value)...) -> OAuth2Credential
- #
- # If this is None, then the store hasn't been read yet.
- self._data = None
+ # Cache of deserialized store. This is only valid after the
+ # _MultiStore is locked or _refresh_data_cache is called. This is
+ # of the form of:
+ #
+ # ((key, value), (key, value)...) -> OAuth2Credential
+ #
+ # If this is None, then the store hasn't been read yet.
+ self._data = None
- class _Storage(BaseStorage):
- """A Storage object that knows how to read/write a single credential."""
+ class _Storage(BaseStorage):
+ """A Storage object that knows how to read/write a single credential."""
- def __init__(self, multistore, key):
- self._multistore = multistore
- self._key = key
+ def __init__(self, multistore, key):
+ self._multistore = multistore
+ self._key = key
- def acquire_lock(self):
- """Acquires any lock necessary to access this Storage.
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
This lock is not reentrant.
"""
- self._multistore._lock()
+ self._multistore._lock()
- def release_lock(self):
- """Release the Storage lock.
+ def release_lock(self):
+ """Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
- self._multistore._unlock()
+ self._multistore._unlock()
- def locked_get(self):
- """Retrieve credential.
+ def locked_get(self):
+ """Retrieve credential.
The Storage lock must be held when this is called.
Returns:
oauth2client.client.Credentials
"""
- credential = self._multistore._get_credential(self._key)
- if credential:
- credential.set_store(self)
- return credential
+ credential = self._multistore._get_credential(self._key)
+ if credential:
+ credential.set_store(self)
+ return credential
- def locked_put(self, credentials):
- """Write a credential.
+ def locked_put(self, credentials):
+ """Write a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
- self._multistore._update_credential(self._key, credentials)
+ self._multistore._update_credential(self._key, credentials)
- def locked_delete(self):
- """Delete a credential.
+ def locked_delete(self):
+ """Delete a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
- self._multistore._delete_credential(self._key)
+ self._multistore._delete_credential(self._key)
- def _create_file_if_needed(self):
- """Create an empty file if necessary.
+ def _create_file_if_needed(self):
+ """Create an empty file if necessary.
This method will not initialize the file. Instead it implements a
simple version of "touch" to ensure the file has been created.
"""
- if not os.path.exists(self._file.filename()):
- old_umask = os.umask(0o177)
- try:
- open(self._file.filename(), 'a+b').close()
- finally:
- os.umask(old_umask)
+ if not os.path.exists(self._file.filename()):
+ old_umask = os.umask(0o177)
+ try:
+ open(self._file.filename(), 'a+b').close()
+ finally:
+ os.umask(old_umask)
- def _lock(self):
- """Lock the entire multistore."""
- self._thread_lock.acquire()
- try:
- self._file.open_and_lock()
- except IOError as e:
- if e.errno == errno.ENOSYS:
- logger.warn('File system does not support locking the credentials '
+ def _lock(self):
+ """Lock the entire multistore."""
+ self._thread_lock.acquire()
+ try:
+ self._file.open_and_lock()
+ except IOError as e:
+ if e.errno == errno.ENOSYS:
+ logger.warn('File system does not support locking the credentials '
'file.')
- elif e.errno == errno.ENOLCK:
- logger.warn('File system is out of resources for writing the '
+ elif e.errno == errno.ENOLCK:
+ logger.warn('File system is out of resources for writing the '
'credentials file (is your disk full?).')
- else:
- raise
- if not self._file.is_locked():
- self._read_only = True
- if self._warn_on_readonly:
- logger.warn('The credentials file (%s) is not writable. Opening in '
+ else:
+ raise
+ if not self._file.is_locked():
+ self._read_only = True
+ if self._warn_on_readonly:
+ logger.warn('The credentials file (%s) is not writable. Opening in '
'read-only mode. Any refreshed credentials will only be '
'valid for this run.', self._file.filename())
- if os.path.getsize(self._file.filename()) == 0:
- logger.debug('Initializing empty multistore file')
- # The multistore is empty so write out an empty file.
- self._data = {}
- self._write()
- elif not self._read_only or self._data is None:
- # Only refresh the data if we are read/write or we haven't
- # cached the data yet. If we are readonly, we assume is isn't
- # changing out from under us and that we only have to read it
- # once. This prevents us from whacking any new access keys that
- # we have cached in memory but were unable to write out.
- self._refresh_data_cache()
+ if os.path.getsize(self._file.filename()) == 0:
+ logger.debug('Initializing empty multistore file')
+ # The multistore is empty so write out an empty file.
+ self._data = {}
+ self._write()
+ elif not self._read_only or self._data is None:
+ # Only refresh the data if we are read/write or we haven't
+ # cached the data yet. If we are readonly, we assume is isn't
+ # changing out from under us and that we only have to read it
+ # once. This prevents us from whacking any new access keys that
+ # we have cached in memory but were unable to write out.
+ self._refresh_data_cache()
- def _unlock(self):
- """Release the lock on the multistore."""
- self._file.unlock_and_close()
- self._thread_lock.release()
+ def _unlock(self):
+ """Release the lock on the multistore."""
+ self._file.unlock_and_close()
+ self._thread_lock.release()
- def _locked_json_read(self):
- """Get the raw content of the multistore file.
+ def _locked_json_read(self):
+ """Get the raw content of the multistore file.
The multistore must be locked when this is called.
Returns:
The contents of the multistore decoded as JSON.
"""
- assert self._thread_lock.locked()
- self._file.file_handle().seek(0)
- return json.load(self._file.file_handle())
+ assert self._thread_lock.locked()
+ self._file.file_handle().seek(0)
+ return json.load(self._file.file_handle())
- def _locked_json_write(self, data):
- """Write a JSON serializable data structure to the multistore.
+ def _locked_json_write(self, data):
+ """Write a JSON serializable data structure to the multistore.
The multistore must be locked when this is called.
Args:
data: The data to be serialized and written.
"""
- assert self._thread_lock.locked()
- if self._read_only:
- return
- self._file.file_handle().seek(0)
- json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
- self._file.file_handle().truncate()
+ assert self._thread_lock.locked()
+ if self._read_only:
+ return
+ self._file.file_handle().seek(0)
+ json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
+ self._file.file_handle().truncate()
- def _refresh_data_cache(self):
- """Refresh the contents of the multistore.
+ def _refresh_data_cache(self):
+ """Refresh the contents of the multistore.
The multistore must be locked when this is called.
@@ -351,41 +351,41 @@ class _MultiStore(object):
NewerCredentialStoreError: Raised when a newer client has written the
store.
"""
- self._data = {}
- try:
- raw_data = self._locked_json_read()
- except Exception:
- logger.warn('Credential data store could not be loaded. '
+ self._data = {}
+ try:
+ raw_data = self._locked_json_read()
+ except Exception:
+ logger.warn('Credential data store could not be loaded. '
'Will ignore and overwrite.')
- return
+ return
- version = 0
- try:
- version = raw_data['file_version']
- except Exception:
- logger.warn('Missing version for credential data store. It may be '
+ version = 0
+ try:
+ version = raw_data['file_version']
+ except Exception:
+ logger.warn('Missing version for credential data store. It may be '
'corrupt or an old version. Overwriting.')
- if version > 1:
- raise NewerCredentialStoreError(
+ if version > 1:
+ raise NewerCredentialStoreError(
'Credential file has file_version of %d. '
'Only file_version of 1 is supported.' % version)
- credentials = []
- try:
- credentials = raw_data['data']
- except (TypeError, KeyError):
- pass
+ credentials = []
+ try:
+ credentials = raw_data['data']
+ except (TypeError, KeyError):
+ pass
- for cred_entry in credentials:
- try:
- (key, credential) = self._decode_credential_from_json(cred_entry)
- self._data[key] = credential
- except:
- # If something goes wrong loading a credential, just ignore it
- logger.info('Error decoding credential, skipping', exc_info=True)
+ for cred_entry in credentials:
+ try:
+ (key, credential) = self._decode_credential_from_json(cred_entry)
+ self._data[key] = credential
+ except:
+ # If something goes wrong loading a credential, just ignore it
+ logger.info('Error decoding credential, skipping', exc_info=True)
- def _decode_credential_from_json(self, cred_entry):
- """Load a credential from our JSON serialization.
+ def _decode_credential_from_json(self, cred_entry):
+ """Load a credential from our JSON serialization.
Args:
cred_entry: A dict entry from the data member of our format
@@ -394,36 +394,36 @@ class _MultiStore(object):
(key, cred) where the key is the key tuple and the cred is the
OAuth2Credential object.
"""
- raw_key = cred_entry['key']
- key = util.dict_to_tuple_key(raw_key)
- credential = None
- credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
- return (key, credential)
+ raw_key = cred_entry['key']
+ key = util.dict_to_tuple_key(raw_key)
+ credential = None
+ credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
+ return (key, credential)
- def _write(self):
- """Write the cached data back out.
+ def _write(self):
+ """Write the cached data back out.
The multistore must be locked.
"""
- raw_data = {'file_version': 1}
- raw_creds = []
- raw_data['data'] = raw_creds
- for (cred_key, cred) in self._data.items():
- raw_key = dict(cred_key)
- raw_cred = json.loads(cred.to_json())
- raw_creds.append({'key': raw_key, 'credential': raw_cred})
- self._locked_json_write(raw_data)
+ raw_data = {'file_version': 1}
+ raw_creds = []
+ raw_data['data'] = raw_creds
+ for (cred_key, cred) in self._data.items():
+ raw_key = dict(cred_key)
+ raw_cred = json.loads(cred.to_json())
+ raw_creds.append({'key': raw_key, 'credential': raw_cred})
+ self._locked_json_write(raw_data)
- def _get_all_credential_keys(self):
- """Gets all the registered credential keys in the multistore.
+ def _get_all_credential_keys(self):
+ """Gets all the registered credential keys in the multistore.
Returns:
A list of dictionaries corresponding to all the keys currently registered
"""
- return [dict(key) for key in self._data.keys()]
+ return [dict(key) for key in self._data.keys()]
- def _get_credential(self, key):
- """Get a credential from the multistore.
+ def _get_credential(self, key):
+ """Get a credential from the multistore.
The multistore must be locked.
@@ -433,10 +433,10 @@ class _MultiStore(object):
Returns:
The credential specified or None if not present
"""
- return self._data.get(key, None)
+ return self._data.get(key, None)
- def _update_credential(self, key, cred):
- """Update a credential and write the multistore.
+ def _update_credential(self, key, cred):
+ """Update a credential and write the multistore.
This must be called when the multistore is locked.
@@ -444,25 +444,25 @@ class _MultiStore(object):
key: The key used to retrieve the credential
cred: The OAuth2Credential to update/set
"""
- self._data[key] = cred
- self._write()
+ self._data[key] = cred
+ self._write()
- def _delete_credential(self, key):
- """Delete a credential and write the multistore.
+ def _delete_credential(self, key):
+ """Delete a credential and write the multistore.
This must be called when the multistore is locked.
Args:
key: The key used to retrieve the credential
"""
- try:
- del self._data[key]
- except KeyError:
- pass
- self._write()
+ try:
+ del self._data[key]
+ except KeyError:
+ pass
+ self._write()
- def _get_storage(self, key):
- """Get a Storage object to get/set a credential.
+ def _get_storage(self, key):
+ """Get a Storage object to get/set a credential.
This Storage is a 'view' into the multistore.
@@ -472,4 +472,4 @@ class _MultiStore(object):
Returns:
A Storage object that can be used to get/set this cred
"""
- return self._Storage(self, key)
+ return self._Storage(self, key)
diff --git a/oauth2client/old_run.py b/oauth2client/old_run.py
index 51db69b..2faf068 100644
--- a/oauth2client/old_run.py
+++ b/oauth2client/old_run.py
@@ -30,7 +30,6 @@ from oauth2client import util
from oauth2client.tools import ClientRedirectHandler
from oauth2client.tools import ClientRedirectServer
-
FLAGS = gflags.FLAGS
gflags.DEFINE_boolean('auth_local_webserver', True,
@@ -48,7 +47,7 @@ gflags.DEFINE_multi_int('auth_host_port', [8080, 8090],
@util.positional(2)
def run(flow, storage, http=None):
- """Core code for a command-line application.
+ """Core code for a command-line application.
The ``run()`` function is called from your application and runs
through all the steps to obtain credentials. It takes a ``Flow``
@@ -86,76 +85,76 @@ def run(flow, storage, http=None):
Returns:
Credentials, the obtained credential.
"""
- logging.warning('This function, oauth2client.tools.run(), and the use of '
+ logging.warning('This function, oauth2client.tools.run(), and the use of '
'the gflags library are deprecated and will be removed in a future '
'version of the library.')
- if FLAGS.auth_local_webserver:
- success = False
- port_number = 0
- for port in FLAGS.auth_host_port:
- port_number = port
- try:
- httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
+ if FLAGS.auth_local_webserver:
+ success = False
+ port_number = 0
+ for port in FLAGS.auth_host_port:
+ port_number = port
+ try:
+ httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
ClientRedirectHandler)
- except socket.error as e:
- pass
- else:
- success = True
- break
- FLAGS.auth_local_webserver = success
+ except socket.error as e:
+ pass
+ else:
+ success = True
+ break
+ FLAGS.auth_local_webserver = success
if not success:
- print('Failed to start a local webserver listening on either port 8080')
- print('or port 9090. Please check your firewall settings and locally')
- print('running programs that may be blocking or using those ports.')
- print()
- print('Falling back to --noauth_local_webserver and continuing with')
- print('authorization.')
- print()
+ print('Failed to start a local webserver listening on either port 8080')
+ print('or port 9090. Please check your firewall settings and locally')
+ print('running programs that may be blocking or using those ports.')
+ print()
+ print('Falling back to --noauth_local_webserver and continuing with')
+ print('authorization.')
+ print()
- if FLAGS.auth_local_webserver:
- oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
- else:
- oauth_callback = client.OOB_CALLBACK_URN
- flow.redirect_uri = oauth_callback
- authorize_url = flow.step1_get_authorize_url()
-
- if FLAGS.auth_local_webserver:
- webbrowser.open(authorize_url, new=1, autoraise=True)
- print('Your browser has been opened to visit:')
- print()
- print(' ' + authorize_url)
- print()
- print('If your browser is on a different machine then exit and re-run')
- print('this application with the command-line parameter ')
- print()
- print(' --noauth_local_webserver')
- print()
- else:
- print('Go to the following link in your browser:')
- print()
- print(' ' + authorize_url)
- print()
-
- code = None
- if FLAGS.auth_local_webserver:
- httpd.handle_request()
- if 'error' in httpd.query_params:
- sys.exit('Authentication request was rejected.')
- if 'code' in httpd.query_params:
- code = httpd.query_params['code']
+ if FLAGS.auth_local_webserver:
+ oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
else:
- print('Failed to find "code" in the query parameters of the redirect.')
- sys.exit('Try running with --noauth_local_webserver.')
- else:
- code = input('Enter verification code: ').strip()
+ oauth_callback = client.OOB_CALLBACK_URN
+ flow.redirect_uri = oauth_callback
+ authorize_url = flow.step1_get_authorize_url()
- try:
- credential = flow.step2_exchange(code, http=http)
- except client.FlowExchangeError as e:
- sys.exit('Authentication has failed: %s' % e)
+ if FLAGS.auth_local_webserver:
+ webbrowser.open(authorize_url, new=1, autoraise=True)
+ print('Your browser has been opened to visit:')
+ print()
+ print(' ' + authorize_url)
+ print()
+ print('If your browser is on a different machine then exit and re-run')
+ print('this application with the command-line parameter ')
+ print()
+ print(' --noauth_local_webserver')
+ print()
+ else:
+ print('Go to the following link in your browser:')
+ print()
+ print(' ' + authorize_url)
+ print()
- storage.put(credential)
- credential.set_store(storage)
- print('Authentication successful.')
+ code = None
+ if FLAGS.auth_local_webserver:
+ httpd.handle_request()
+ if 'error' in httpd.query_params:
+ sys.exit('Authentication request was rejected.')
+ if 'code' in httpd.query_params:
+ code = httpd.query_params['code']
+ else:
+ print('Failed to find "code" in the query parameters of the redirect.')
+ sys.exit('Try running with --noauth_local_webserver.')
+ else:
+ code = input('Enter verification code: ').strip()
- return credential
+ try:
+ credential = flow.step2_exchange(code, http=http)
+ except client.FlowExchangeError as e:
+ sys.exit('Authentication has failed: %s' % e)
+
+ storage.put(credential)
+ credential.set_store(storage)
+ print('Authentication successful.')
+
+ return credential
diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py
index 0119be0..321ebf0 100644
--- a/oauth2client/service_account.py
+++ b/oauth2client/service_account.py
@@ -34,83 +34,83 @@ from oauth2client.client import AssertionCredentials
class _ServiceAccountCredentials(AssertionCredentials):
- """Class representing a service account (signed JWT) credential."""
+ """Class representing a service account (signed JWT) credential."""
- MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
+ MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
- def __init__(self, service_account_id, service_account_email, private_key_id,
+ def __init__(self, service_account_id, service_account_email, private_key_id,
private_key_pkcs8_text, scopes, user_agent=None,
token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI,
**kwargs):
- super(_ServiceAccountCredentials, self).__init__(
+ super(_ServiceAccountCredentials, self).__init__(
None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri)
- self._service_account_id = service_account_id
- self._service_account_email = service_account_email
- self._private_key_id = private_key_id
- self._private_key = _get_private_key(private_key_pkcs8_text)
- self._private_key_pkcs8_text = private_key_pkcs8_text
- self._scopes = util.scopes_to_string(scopes)
- self._user_agent = user_agent
- self._token_uri = token_uri
- self._revoke_uri = revoke_uri
- self._kwargs = kwargs
+ self._service_account_id = service_account_id
+ self._service_account_email = service_account_email
+ self._private_key_id = private_key_id
+ self._private_key = _get_private_key(private_key_pkcs8_text)
+ self._private_key_pkcs8_text = private_key_pkcs8_text
+ self._scopes = util.scopes_to_string(scopes)
+ self._user_agent = user_agent
+ self._token_uri = token_uri
+ self._revoke_uri = revoke_uri
+ self._kwargs = kwargs
- def _generate_assertion(self):
- """Generate the assertion that will be used in the request."""
+ def _generate_assertion(self):
+ """Generate the assertion that will be used in the request."""
- header = {
+ header = {
'alg': 'RS256',
'typ': 'JWT',
'kid': self._private_key_id
- }
+ }
- now = int(time.time())
- payload = {
+ now = int(time.time())
+ payload = {
'aud': self._token_uri,
'scope': self._scopes,
'iat': now,
'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS,
'iss': self._service_account_email
- }
- payload.update(self._kwargs)
+ }
+ payload.update(self._kwargs)
- first_segment = _urlsafe_b64encode(_json_encode(header))
- second_segment = _urlsafe_b64encode(_json_encode(payload))
- assertion_input = first_segment + b'.' + second_segment
+ first_segment = _urlsafe_b64encode(_json_encode(header))
+ second_segment = _urlsafe_b64encode(_json_encode(payload))
+ assertion_input = first_segment + b'.' + second_segment
- # Sign the assertion.
- rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
- signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
+ # Sign the assertion.
+ rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
+ signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
- return assertion_input + b'.' + signature
+ return assertion_input + b'.' + signature
- def sign_blob(self, blob):
- # Ensure that it is bytes
- blob = _to_bytes(blob, encoding='utf-8')
- return (self._private_key_id,
+ def sign_blob(self, blob):
+ # Ensure that it is bytes
+ blob = _to_bytes(blob, encoding='utf-8')
+ return (self._private_key_id,
rsa.pkcs1.sign(blob, self._private_key, 'SHA-256'))
- @property
- def service_account_email(self):
- return self._service_account_email
+ @property
+ def service_account_email(self):
+ return self._service_account_email
- @property
- def serialization_data(self):
- return {
+ @property
+ def serialization_data(self):
+ return {
'type': 'service_account',
'client_id': self._service_account_id,
'client_email': self._service_account_email,
'private_key_id': self._private_key_id,
'private_key': self._private_key_pkcs8_text
- }
+ }
- def create_scoped_required(self):
- return not self._scopes
+ def create_scoped_required(self):
+ return not self._scopes
- def create_scoped(self, scopes):
- return _ServiceAccountCredentials(self._service_account_id,
+ def create_scoped(self, scopes):
+ return _ServiceAccountCredentials(self._service_account_id,
self._service_account_email,
self._private_key_id,
self._private_key_pkcs8_text,
@@ -122,10 +122,10 @@ class _ServiceAccountCredentials(AssertionCredentials):
def _get_private_key(private_key_pkcs8_text):
- """Get an RSA private key object from a pkcs8 representation."""
- private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
- der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
- asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
- return rsa.PrivateKey.load_pkcs1(
+ """Get an RSA private key object from a pkcs8 representation."""
+ private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
+ der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
+ asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
+ return rsa.PrivateKey.load_pkcs1(
asn1_private_key.getComponentByName('privateKey').asOctets(),
format='DER')
diff --git a/oauth2client/tools.py b/oauth2client/tools.py
index 2caa134..bd77020 100644
--- a/oauth2client/tools.py
+++ b/oauth2client/tools.py
@@ -35,7 +35,6 @@ from six.moves import input
from oauth2client import client
from oauth2client import util
-
_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0
To make this sample run you will need to populate the client_secrets.json file
@@ -47,22 +46,23 @@ with information from the APIs Console .
"""
+
def _CreateArgumentParser():
- try:
- import argparse
- except ImportError:
- return None
- parser = argparse.ArgumentParser(add_help=False)
- parser.add_argument('--auth_host_name', default='localhost',
+ try:
+ import argparse
+ except ImportError:
+ return None
+ parser = argparse.ArgumentParser(add_help=False)
+ parser.add_argument('--auth_host_name', default='localhost',
help='Hostname when running a local web server.')
- parser.add_argument('--noauth_local_webserver', action='store_true',
+ parser.add_argument('--noauth_local_webserver', action='store_true',
default=False, help='Do not run a local web server.')
- parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
+ parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
nargs='*', help='Port web server should listen on.')
- parser.add_argument('--logging_level', default='ERROR',
+ parser.add_argument('--logging_level', default='ERROR',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
help='Set the logging level of detail.')
- return parser
+ return parser
# argparser is an ArgumentParser that contains command-line options expected
# by tools.run(). Pass it in as part of the 'parents' argument to your own
@@ -71,45 +71,45 @@ argparser = _CreateArgumentParser()
class ClientRedirectServer(BaseHTTPServer.HTTPServer):
- """A server to handle OAuth 2.0 redirects back to localhost.
+ """A server to handle OAuth 2.0 redirects back to localhost.
Waits for a single request and parses the query parameters
into query_params and then stops serving.
"""
- query_params = {}
+ query_params = {}
class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
- """A handler for OAuth 2.0 redirects back to localhost.
+ """A handler for OAuth 2.0 redirects back to localhost.
Waits for a single request and parses the query parameters
into the servers query_params and then stops serving.
"""
- def do_GET(self):
- """Handle a GET request.
+ def do_GET(self):
+ """Handle a GET request.
Parses the query parameters and prints a message
if the flow has completed. Note that we can't detect
if an error occurred.
"""
- self.send_response(200)
- self.send_header("Content-type", "text/html")
- self.end_headers()
- query = self.path.split('?', 1)[-1]
- query = dict(urllib.parse.parse_qsl(query))
- self.server.query_params = query
- self.wfile.write(b"Authentication Status")
- self.wfile.write(b"The authentication flow has completed.
")
- self.wfile.write(b"")
+ self.send_response(200)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+ query = self.path.split('?', 1)[-1]
+ query = dict(urllib.parse.parse_qsl(query))
+ self.server.query_params = query
+ self.wfile.write(b"Authentication Status")
+ self.wfile.write(b"The authentication flow has completed.
")
+ self.wfile.write(b"")
- def log_message(self, format, *args):
- """Do not log messages to stdout while running as command line program."""
+ def log_message(self, format, *args):
+ """Do not log messages to stdout while running as command line program."""
@util.positional(3)
def run_flow(flow, storage, flags, http=None):
- """Core code for a command-line application.
+ """Core code for a command-line application.
The ``run()`` function is called from your application and runs
through all the steps to obtain credentials. It takes a ``Flow``
@@ -159,91 +159,91 @@ def run_flow(flow, storage, flags, http=None):
Returns:
Credentials, the obtained credential.
"""
- logging.getLogger().setLevel(getattr(logging, flags.logging_level))
- if not flags.noauth_local_webserver:
- success = False
- port_number = 0
- for port in flags.auth_host_port:
- port_number = port
- try:
- httpd = ClientRedirectServer((flags.auth_host_name, port),
+ logging.getLogger().setLevel(getattr(logging, flags.logging_level))
+ if not flags.noauth_local_webserver:
+ success = False
+ port_number = 0
+ for port in flags.auth_host_port:
+ port_number = port
+ try:
+ httpd = ClientRedirectServer((flags.auth_host_name, port),
ClientRedirectHandler)
- except socket.error:
- pass
- else:
- success = True
- break
- flags.noauth_local_webserver = not success
+ except socket.error:
+ pass
+ else:
+ success = True
+ break
+ flags.noauth_local_webserver = not success
if not success:
- print('Failed to start a local webserver listening on either port 8080')
- print('or port 9090. Please check your firewall settings and locally')
- print('running programs that may be blocking or using those ports.')
- print()
- print('Falling back to --noauth_local_webserver and continuing with')
- print('authorization.')
- print()
+ print('Failed to start a local webserver listening on either port 8080')
+ print('or port 9090. Please check your firewall settings and locally')
+ print('running programs that may be blocking or using those ports.')
+ print()
+ print('Falling back to --noauth_local_webserver and continuing with')
+ print('authorization.')
+ print()
- if not flags.noauth_local_webserver:
- oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
- else:
- oauth_callback = client.OOB_CALLBACK_URN
- flow.redirect_uri = oauth_callback
- authorize_url = flow.step1_get_authorize_url()
-
- if not flags.noauth_local_webserver:
- import webbrowser
- webbrowser.open(authorize_url, new=1, autoraise=True)
- print('Your browser has been opened to visit:')
- print()
- print(' ' + authorize_url)
- print()
- print('If your browser is on a different machine then exit and re-run this')
- print('application with the command-line parameter ')
- print()
- print(' --noauth_local_webserver')
- print()
- else:
- print('Go to the following link in your browser:')
- print()
- print(' ' + authorize_url)
- print()
-
- code = None
- if not flags.noauth_local_webserver:
- httpd.handle_request()
- if 'error' in httpd.query_params:
- sys.exit('Authentication request was rejected.')
- if 'code' in httpd.query_params:
- code = httpd.query_params['code']
+ if not flags.noauth_local_webserver:
+ oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
else:
- print('Failed to find "code" in the query parameters of the redirect.')
- sys.exit('Try running with --noauth_local_webserver.')
- else:
- code = input('Enter verification code: ').strip()
+ oauth_callback = client.OOB_CALLBACK_URN
+ flow.redirect_uri = oauth_callback
+ authorize_url = flow.step1_get_authorize_url()
- try:
- credential = flow.step2_exchange(code, http=http)
- except client.FlowExchangeError as e:
- sys.exit('Authentication has failed: %s' % e)
+ if not flags.noauth_local_webserver:
+ import webbrowser
+ webbrowser.open(authorize_url, new=1, autoraise=True)
+ print('Your browser has been opened to visit:')
+ print()
+ print(' ' + authorize_url)
+ print()
+ print('If your browser is on a different machine then exit and re-run this')
+ print('application with the command-line parameter ')
+ print()
+ print(' --noauth_local_webserver')
+ print()
+ else:
+ print('Go to the following link in your browser:')
+ print()
+ print(' ' + authorize_url)
+ print()
- storage.put(credential)
- credential.set_store(storage)
- print('Authentication successful.')
+ code = None
+ if not flags.noauth_local_webserver:
+ httpd.handle_request()
+ if 'error' in httpd.query_params:
+ sys.exit('Authentication request was rejected.')
+ if 'code' in httpd.query_params:
+ code = httpd.query_params['code']
+ else:
+ print('Failed to find "code" in the query parameters of the redirect.')
+ sys.exit('Try running with --noauth_local_webserver.')
+ else:
+ code = input('Enter verification code: ').strip()
- return credential
+ try:
+ credential = flow.step2_exchange(code, http=http)
+ except client.FlowExchangeError as e:
+ sys.exit('Authentication has failed: %s' % e)
+
+ storage.put(credential)
+ credential.set_store(storage)
+ print('Authentication successful.')
+
+ return credential
def message_if_missing(filename):
- """Helpful message to display if the CLIENT_SECRETS file is missing."""
+ """Helpful message to display if the CLIENT_SECRETS file is missing."""
- return _CLIENT_SECRETS_MESSAGE % filename
+ return _CLIENT_SECRETS_MESSAGE % filename
try:
- from oauth2client.old_run import run
- from oauth2client.old_run import FLAGS
+ from oauth2client.old_run import run
+ from oauth2client.old_run import FLAGS
except ImportError:
- def run(*args, **kwargs):
- raise NotImplementedError(
+ def run(*args, **kwargs):
+ raise NotImplementedError(
'The gflags library must be installed to use tools.run(). '
'Please install gflags or preferrably switch to using '
'tools.run_flow().')
diff --git a/oauth2client/util.py b/oauth2client/util.py
index 94c2523..75abc03 100644
--- a/oauth2client/util.py
+++ b/oauth2client/util.py
@@ -38,7 +38,6 @@ import types
import six
from six.moves import urllib
-
logger = logging.getLogger(__name__)
POSITIONAL_WARNING = 'WARNING'
@@ -49,8 +48,9 @@ POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION,
positional_parameters_enforcement = POSITIONAL_WARNING
+
def positional(max_positional_args):
- """A decorator to declare that only the first N arguments my be positional.
+ """A decorator to declare that only the first N arguments my be positional.
This decorator makes it easy to support Python 3 style keyword-only
parameters. For example, in Python 3 it is possible to write::
@@ -119,33 +119,34 @@ def positional(max_positional_args):
POSITIONAL_EXCEPTION.
"""
- def positional_decorator(wrapped):
- @functools.wraps(wrapped)
- def positional_wrapper(*args, **kwargs):
- if len(args) > max_positional_args:
- plural_s = ''
- if max_positional_args != 1:
- plural_s = 's'
- message = '%s() takes at most %d positional argument%s (%d given)' % (
- wrapped.__name__, max_positional_args, plural_s, len(args))
- if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
- raise TypeError(message)
- elif positional_parameters_enforcement == POSITIONAL_WARNING:
- logger.warning(message)
- else: # IGNORE
- pass
- return wrapped(*args, **kwargs)
- return positional_wrapper
- if isinstance(max_positional_args, six.integer_types):
- return positional_decorator
- else:
- args, _, _, defaults = inspect.getargspec(max_positional_args)
- return positional(len(args) - len(defaults))(max_positional_args)
+ def positional_decorator(wrapped):
+ @functools.wraps(wrapped)
+ def positional_wrapper(*args, **kwargs):
+ if len(args) > max_positional_args:
+ plural_s = ''
+ if max_positional_args != 1:
+ plural_s = 's'
+ message = '%s() takes at most %d positional argument%s (%d given)' % (
+ wrapped.__name__, max_positional_args, plural_s, len(args))
+ if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
+ raise TypeError(message)
+ elif positional_parameters_enforcement == POSITIONAL_WARNING:
+ logger.warning(message)
+ else: # IGNORE
+ pass
+ return wrapped(*args, **kwargs)
+ return positional_wrapper
+
+ if isinstance(max_positional_args, six.integer_types):
+ return positional_decorator
+ else:
+ args, _, _, defaults = inspect.getargspec(max_positional_args)
+ return positional(len(args) - len(defaults))(max_positional_args)
def scopes_to_string(scopes):
- """Converts scope value to a string.
+ """Converts scope value to a string.
If scopes is a string then it is simply passed through. If scopes is an
iterable then a string is returned that is all the individual scopes
@@ -157,14 +158,14 @@ def scopes_to_string(scopes):
Returns:
The scopes formatted as a single string.
"""
- if isinstance(scopes, six.string_types):
- return scopes
- else:
- return ' '.join(scopes)
+ if isinstance(scopes, six.string_types):
+ return scopes
+ else:
+ return ' '.join(scopes)
def string_to_scopes(scopes):
- """Converts stringifed scope value to a list.
+ """Converts stringifed scope value to a list.
If scopes is a list then it is simply passed through. If scopes is an
string then a list of each individual scope is returned.
@@ -175,16 +176,16 @@ def string_to_scopes(scopes):
Returns:
The scopes in a list.
"""
- if not scopes:
- return []
- if isinstance(scopes, six.string_types):
- return scopes.split(' ')
- else:
- return scopes
+ if not scopes:
+ return []
+ if isinstance(scopes, six.string_types):
+ return scopes.split(' ')
+ else:
+ return scopes
def dict_to_tuple_key(dictionary):
- """Converts a dictionary to a tuple that can be used as an immutable key.
+ """Converts a dictionary to a tuple that can be used as an immutable key.
The resulting key is always sorted so that logically equivalent dictionaries
always produce an identical tuple for a key.
@@ -195,11 +196,11 @@ def dict_to_tuple_key(dictionary):
Returns:
A tuple representing the dictionary in it's naturally sorted ordering.
"""
- return tuple(sorted(dictionary.items()))
+ return tuple(sorted(dictionary.items()))
def _add_query_parameter(url, name, value):
- """Adds a query parameter to a url.
+ """Adds a query parameter to a url.
Replaces the current value if it already exists in the URL.
@@ -211,11 +212,11 @@ def _add_query_parameter(url, name, value):
Returns:
Updated query parameter. Does not update the url if value is None.
"""
- if value is None:
- return url
- else:
- parsed = list(urllib.parse.urlparse(url))
- q = dict(urllib.parse.parse_qsl(parsed[4]))
- q[name] = value
- parsed[4] = urllib.parse.urlencode(q)
- return urllib.parse.urlunparse(parsed)
+ if value is None:
+ return url
+ else:
+ parsed = list(urllib.parse.urlparse(url))
+ q = dict(urllib.parse.parse_qsl(parsed[4]))
+ q[name] = value
+ parsed[4] = urllib.parse.urlencode(q)
+ return urllib.parse.urlunparse(parsed)
diff --git a/oauth2client/xsrfutil.py b/oauth2client/xsrfutil.py
index 5739dcf..9cd59d6 100644
--- a/oauth2client/xsrfutil.py
+++ b/oauth2client/xsrfutil.py
@@ -20,7 +20,6 @@ __authors__ = [
'"Joe Gregorio" ',
]
-
import base64
import hmac
import time
@@ -28,13 +27,11 @@ import time
import six
from oauth2client import util
-
# Delimiter character
DELIMITER = b':'
-
# 1 hour in seconds
-DEFAULT_TIMEOUT_SECS = 1*60*60
+DEFAULT_TIMEOUT_SECS = 1 * 60 * 60
def _force_bytes(s):
@@ -48,7 +45,7 @@ def _force_bytes(s):
@util.positional(2)
def generate_token(key, user_id, action_id="", when=None):
- """Generates a URL-safe token for the given user, action, time tuple.
+ """Generates a URL-safe token for the given user, action, time tuple.
Args:
key: secret key to use.
@@ -61,22 +58,22 @@ def generate_token(key, user_id, action_id="", when=None):
Returns:
A string XSRF protection token.
"""
- when = _force_bytes(when or int(time.time()))
- digester = hmac.new(_force_bytes(key))
- digester.update(_force_bytes(user_id))
- digester.update(DELIMITER)
- digester.update(_force_bytes(action_id))
- digester.update(DELIMITER)
- digester.update(when)
- digest = digester.digest()
+ when = _force_bytes(when or int(time.time()))
+ digester = hmac.new(_force_bytes(key))
+ digester.update(_force_bytes(user_id))
+ digester.update(DELIMITER)
+ digester.update(_force_bytes(action_id))
+ digester.update(DELIMITER)
+ digester.update(when)
+ digest = digester.digest()
- token = base64.urlsafe_b64encode(digest + DELIMITER + when)
- return token
+ token = base64.urlsafe_b64encode(digest + DELIMITER + when)
+ return token
@util.positional(3)
def validate_token(key, token, user_id, action_id="", current_time=None):
- """Validates that the given token authorizes the user for the action.
+ """Validates that the given token authorizes the user for the action.
Tokens are invalid if the time of issue is too old or if the token
does not match what generateToken outputs (i.e. the token was forged).
@@ -92,27 +89,27 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
A boolean - True if the user is authorized for the action, False
otherwise.
"""
- if not token:
- return False
- try:
- decoded = base64.urlsafe_b64decode(token)
- token_time = int(decoded.split(DELIMITER)[-1])
- except (TypeError, ValueError):
- return False
- if current_time is None:
- current_time = time.time()
- # If the token is too old it's not valid.
- if current_time - token_time > DEFAULT_TIMEOUT_SECS:
- return False
+ if not token:
+ return False
+ try:
+ decoded = base64.urlsafe_b64decode(token)
+ token_time = int(decoded.split(DELIMITER)[-1])
+ except (TypeError, ValueError):
+ return False
+ if current_time is None:
+ current_time = time.time()
+ # If the token is too old it's not valid.
+ if current_time - token_time > DEFAULT_TIMEOUT_SECS:
+ return False
- # The given token should match the generated one with the same time.
- expected_token = generate_token(key, user_id, action_id=action_id,
+ # The given token should match the generated one with the same time.
+ expected_token = generate_token(key, user_id, action_id=action_id,
when=token_time)
- if len(token) != len(expected_token):
- return False
+ if len(token) != len(expected_token):
+ return False
- # Perform constant time comparison to avoid timing attacks
- different = 0
- for x, y in zip(bytearray(token), bytearray(expected_token)):
- different |= x ^ y
- return not different
+ # Perform constant time comparison to avoid timing attacks
+ different = 0
+ for x, y in zip(bytearray(token), bytearray(expected_token)):
+ different |= x ^ y
+ return not different
diff --git a/tests/__init__.py b/tests/__init__.py
index 7913e6f..9a4adcb 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,6 +16,7 @@ __author__ = 'afshar@google.com (Ali Afshar)'
import oauth2client.util
+
def setup_package():
- """Run on testing package."""
- oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'
+ """Run on testing package."""
+ oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'
diff --git a/tests/http_mock.py b/tests/http_mock.py
index 4040d48..e8677fc 100644
--- a/tests/http_mock.py
+++ b/tests/http_mock.py
@@ -20,47 +20,45 @@ import httplib2
# TODO(craigcitro): Find a cleaner way to share this code with googleapiclient.
-
class HttpMock(object):
- """Mock of httplib2.Http"""
+ """Mock of httplib2.Http"""
- def __init__(self, filename=None, headers=None):
- """
+ def __init__(self, filename=None, headers=None):
+ """
Args:
filename: string, absolute filename to read response from
headers: dict, header to return with response
"""
- if headers is None:
- headers = {'status': '200 OK'}
- if filename:
- f = file(filename, 'r')
- self.data = f.read()
- f.close()
- else:
- self.data = None
- self.response_headers = headers
- self.headers = None
- self.uri = None
- self.method = None
- self.body = None
- self.headers = None
+ if headers is None:
+ headers = {'status': '200 OK'}
+ if filename:
+ f = file(filename, 'r')
+ self.data = f.read()
+ f.close()
+ else:
+ self.data = None
+ self.response_headers = headers
+ self.headers = None
+ self.uri = None
+ self.method = None
+ self.body = None
+ self.headers = None
-
- def request(self, uri,
+ def request(self, uri,
method='GET',
body=None,
headers=None,
redirections=1,
connection_type=None):
- self.uri = uri
- self.method = method
- self.body = body
- self.headers = headers
- return httplib2.Response(self.response_headers), self.data
+ self.uri = uri
+ self.method = method
+ self.body = body
+ self.headers = headers
+ return httplib2.Response(self.response_headers), self.data
class HttpMockSequence(object):
- """Mock of httplib2.Http
+ """Mock of httplib2.Http
Mocks a sequence of calls to request returning different responses for each
call. Create an instance initialized with the desired response headers
@@ -83,33 +81,33 @@ class HttpMockSequence(object):
'echo_request_uri' means return the request uri in the response body
"""
- def __init__(self, iterable):
- """
+ def __init__(self, iterable):
+ """
Args:
iterable: iterable, a sequence of pairs of (headers, body)
"""
- self._iterable = iterable
- self.follow_redirects = True
- self.requests = []
+ self._iterable = iterable
+ self.follow_redirects = True
+ self.requests = []
- def request(self, uri,
+ def request(self, uri,
method='GET',
body=None,
headers=None,
redirections=1,
connection_type=None):
- resp, content = self._iterable.pop(0)
- self.requests.append({'uri': uri, 'body': body, 'headers': headers})
- # Read any underlying stream before sending the request.
- body_stream_content = body.read() if getattr(body, 'read', None) else None
- if content == 'echo_request_headers':
- content = headers
- elif content == 'echo_request_headers_as_json':
- content = json.dumps(headers)
- elif content == 'echo_request_body':
- content = body if body_stream_content is None else body_stream_content
- elif content == 'echo_request_uri':
- content = uri
- elif not isinstance(content, bytes):
- raise TypeError('http content should be bytes: %r' % (content,))
- return httplib2.Response(resp), content
+ resp, content = self._iterable.pop(0)
+ self.requests.append({'uri': uri, 'body': body, 'headers': headers})
+ # Read any underlying stream before sending the request.
+ body_stream_content = body.read() if getattr(body, 'read', None) else None
+ if content == 'echo_request_headers':
+ content = headers
+ elif content == 'echo_request_headers_as_json':
+ content = json.dumps(headers)
+ elif content == 'echo_request_body':
+ content = body if body_stream_content is None else body_stream_content
+ elif content == 'echo_request_uri':
+ content = uri
+ elif not isinstance(content, bytes):
+ raise TypeError('http content should be bytes: %r' % (content, ))
+ return httplib2.Response(resp), content
diff --git a/tests/test__helpers.py b/tests/test__helpers.py
index f47c8b7..eb75d35 100644
--- a/tests/test__helpers.py
+++ b/tests/test__helpers.py
@@ -25,93 +25,93 @@ from oauth2client._helpers import _urlsafe_b64encode
class Test__parse_pem_key(unittest.TestCase):
- def test_valid_input(self):
- test_string = b'1234-----BEGIN FOO BAR BAZ'
- result = _parse_pem_key(test_string)
- self.assertEqual(result, test_string[4:])
+ def test_valid_input(self):
+ test_string = b'1234-----BEGIN FOO BAR BAZ'
+ result = _parse_pem_key(test_string)
+ self.assertEqual(result, test_string[4:])
- def test_bad_input(self):
- test_string = b'DOES NOT HAVE DASHES'
- result = _parse_pem_key(test_string)
- self.assertEqual(result, None)
+ def test_bad_input(self):
+ test_string = b'DOES NOT HAVE DASHES'
+ result = _parse_pem_key(test_string)
+ self.assertEqual(result, None)
class Test__json_encode(unittest.TestCase):
- def test_dictionary_input(self):
- # Use only a single key since dictionary hash order
- # is non-deterministic.
- data = {u'foo': 10}
- result = _json_encode(data)
- self.assertEqual(result, """{"foo":10}""")
+ def test_dictionary_input(self):
+ # Use only a single key since dictionary hash order
+ # is non-deterministic.
+ data = {u'foo': 10}
+ result = _json_encode(data)
+ self.assertEqual(result, """{"foo":10}""")
- def test_list_input(self):
- data = [42, 1337]
- result = _json_encode(data)
- self.assertEqual(result, """[42,1337]""")
+ def test_list_input(self):
+ data = [42, 1337]
+ result = _json_encode(data)
+ self.assertEqual(result, """[42,1337]""")
class Test__to_bytes(unittest.TestCase):
- def test_with_bytes(self):
- value = b'bytes-val'
- self.assertEqual(_to_bytes(value), value)
+ def test_with_bytes(self):
+ value = b'bytes-val'
+ self.assertEqual(_to_bytes(value), value)
- def test_with_unicode(self):
- value = u'string-val'
- encoded_value = b'string-val'
- self.assertEqual(_to_bytes(value), encoded_value)
+ def test_with_unicode(self):
+ value = u'string-val'
+ encoded_value = b'string-val'
+ self.assertEqual(_to_bytes(value), encoded_value)
- def test_with_nonstring_type(self):
- value = object()
- self.assertRaises(ValueError, _to_bytes, value)
+ def test_with_nonstring_type(self):
+ value = object()
+ self.assertRaises(ValueError, _to_bytes, value)
class Test__from_bytes(unittest.TestCase):
- def test_with_unicode(self):
- value = u'bytes-val'
- self.assertEqual(_from_bytes(value), value)
+ def test_with_unicode(self):
+ value = u'bytes-val'
+ self.assertEqual(_from_bytes(value), value)
- def test_with_bytes(self):
- value = b'string-val'
- decoded_value = u'string-val'
- self.assertEqual(_from_bytes(value), decoded_value)
+ def test_with_bytes(self):
+ value = b'string-val'
+ decoded_value = u'string-val'
+ self.assertEqual(_from_bytes(value), decoded_value)
- def test_with_nonstring_type(self):
- value = object()
- self.assertRaises(ValueError, _from_bytes, value)
+ def test_with_nonstring_type(self):
+ value = object()
+ self.assertRaises(ValueError, _from_bytes, value)
class Test__urlsafe_b64encode(unittest.TestCase):
- DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
+ DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
- def test_valid_input_bytes(self):
- test_string = b'deadbeef'
- result = _urlsafe_b64encode(test_string)
- self.assertEqual(result, self.DEADBEEF_ENCODED)
+ def test_valid_input_bytes(self):
+ test_string = b'deadbeef'
+ result = _urlsafe_b64encode(test_string)
+ self.assertEqual(result, self.DEADBEEF_ENCODED)
- def test_valid_input_unicode(self):
- test_string = u'deadbeef'
- result = _urlsafe_b64encode(test_string)
- self.assertEqual(result, self.DEADBEEF_ENCODED)
+ def test_valid_input_unicode(self):
+ test_string = u'deadbeef'
+ result = _urlsafe_b64encode(test_string)
+ self.assertEqual(result, self.DEADBEEF_ENCODED)
class Test__urlsafe_b64decode(unittest.TestCase):
- def test_valid_input_bytes(self):
- test_string = b'ZGVhZGJlZWY'
- result = _urlsafe_b64decode(test_string)
- self.assertEqual(result, b'deadbeef')
+ def test_valid_input_bytes(self):
+ test_string = b'ZGVhZGJlZWY'
+ result = _urlsafe_b64decode(test_string)
+ self.assertEqual(result, b'deadbeef')
- def test_valid_input_unicode(self):
- test_string = b'ZGVhZGJlZWY'
- result = _urlsafe_b64decode(test_string)
- self.assertEqual(result, b'deadbeef')
+ def test_valid_input_unicode(self):
+ test_string = b'ZGVhZGJlZWY'
+ result = _urlsafe_b64decode(test_string)
+ self.assertEqual(result, b'deadbeef')
- def test_bad_input(self):
- import binascii
- bad_string = b'+'
- self.assertRaises((TypeError, binascii.Error),
+ def test_bad_input(self):
+ import binascii
+ bad_string = b'+'
+ self.assertRaises((TypeError, binascii.Error),
_urlsafe_b64decode, bad_string)
diff --git a/tests/test__pycrypto_crypt.py b/tests/test__pycrypto_crypt.py
index 3097235..9895a62 100644
--- a/tests/test__pycrypto_crypt.py
+++ b/tests/test__pycrypto_crypt.py
@@ -22,42 +22,42 @@ from oauth2client.crypt import PyCryptoVerifier
class TestPyCryptoVerifier(unittest.TestCase):
- PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
+ PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
'data', 'publickey.pem')
- PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
+ PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
'data', 'privatekey.pem')
- def _load_public_key_bytes(self):
- with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
- return fh.read()
+ def _load_public_key_bytes(self):
+ with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
+ return fh.read()
- def _load_private_key_bytes(self):
- with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
- return fh.read()
+ def _load_private_key_bytes(self):
+ with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
+ return fh.read()
- def test_verify_success(self):
- to_sign = b'foo'
- signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
- actual_signature = signer.sign(to_sign)
+ def test_verify_success(self):
+ to_sign = b'foo'
+ signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
+ actual_signature = signer.sign(to_sign)
- verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
+ verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True)
- self.assertTrue(verifier.verify(to_sign, actual_signature))
+ self.assertTrue(verifier.verify(to_sign, actual_signature))
- def test_verify_failure(self):
- verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
+ def test_verify_failure(self):
+ verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True)
- bad_signature = b''
- self.assertFalse(verifier.verify(b'foo', bad_signature))
+ bad_signature = b''
+ self.assertFalse(verifier.verify(b'foo', bad_signature))
- def test_verify_bad_key(self):
- verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
+ def test_verify_bad_key(self):
+ verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True)
- bad_signature = b''
- self.assertFalse(verifier.verify(b'foo', bad_signature))
+ bad_signature = b''
+ self.assertFalse(verifier.verify(b'foo', bad_signature))
- def test_from_string_unicode_key(self):
- public_key = self._load_public_key_bytes()
- public_key = public_key.decode('utf-8')
- verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
- self.assertTrue(isinstance(verifier, PyCryptoVerifier))
+ def test_from_string_unicode_key(self):
+ public_key = self._load_public_key_bytes()
+ public_key = public_key.decode('utf-8')
+ verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
+ self.assertTrue(isinstance(verifier, PyCryptoVerifier))
diff --git a/tests/test_appengine.py b/tests/test_appengine.py
index 83199df..dfa6831 100644
--- a/tests/test_appengine.py
+++ b/tests/test_appengine.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Discovery document tests
Unit tests for objects created from discovery documents.
@@ -68,467 +67,463 @@ from oauth2client.client import flow_from_clientsecrets
from oauth2client.client import save_to_well_known_file
from webtest import TestApp
-
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
def datafile(filename):
- return os.path.join(DATA_DIR, filename)
+ return os.path.join(DATA_DIR, filename)
def load_and_cache(existing_file, fakename, cache_mock):
- client_type, client_info = _loadfile(datafile(existing_file))
- cache_mock.cache[fakename] = {client_type: client_info}
+ client_type, client_info = _loadfile(datafile(existing_file))
+ cache_mock.cache[fakename] = {client_type: client_info}
class CacheMock(object):
- def __init__(self):
- self.cache = {}
+ def __init__(self):
+ self.cache = {}
- def get(self, key, namespace=''):
- # ignoring namespace for easier testing
- return self.cache.get(key, None)
+ def get(self, key, namespace=''):
+ # ignoring namespace for easier testing
+ return self.cache.get(key, None)
- def set(self, key, value, namespace=''):
- # ignoring namespace for easier testing
- self.cache[key] = value
+ def set(self, key, value, namespace=''):
+ # ignoring namespace for easier testing
+ self.cache[key] = value
class UserMock(object):
- """Mock the app engine user service"""
+ """Mock the app engine user service"""
- def __call__(self):
- return self
+ def __call__(self):
+ return self
- def user_id(self):
- return 'foo_user'
+ def user_id(self):
+ return 'foo_user'
class UserNotLoggedInMock(object):
- """Mock the app engine user service"""
+ """Mock the app engine user service"""
- def __call__(self):
- return None
+ def __call__(self):
+ return None
class Http2Mock(object):
- """Mock httplib2.Http"""
- status = 200
- content = {
+ """Mock httplib2.Http"""
+ status = 200
+ content = {
'access_token': 'foo_access_token',
'refresh_token': 'foo_refresh_token',
'expires_in': 3600,
'extra': 'value',
}
- def request(self, token_uri, method, body, headers, *args, **kwargs):
- self.body = body
- self.headers = headers
- return (self, json.dumps(self.content))
+ def request(self, token_uri, method, body, headers, *args, **kwargs):
+ self.body = body
+ self.headers = headers
+ return (self, json.dumps(self.content))
class TestAppAssertionCredentials(unittest.TestCase):
- account_name = "service_account_name@appspot.com"
- signature = "signature"
+ account_name = "service_account_name@appspot.com"
+ signature = "signature"
+ class AppIdentityStubImpl(apiproxy_stub.APIProxyStub):
- class AppIdentityStubImpl(apiproxy_stub.APIProxyStub):
-
- def __init__(self):
- super(TestAppAssertionCredentials.AppIdentityStubImpl, self).__init__(
+ def __init__(self):
+ super(TestAppAssertionCredentials.AppIdentityStubImpl, self).__init__(
'app_identity_service')
- def _Dynamic_GetAccessToken(self, request, response):
- response.set_access_token('a_token_123')
- response.set_expiration_time(time.time() + 1800)
+ def _Dynamic_GetAccessToken(self, request, response):
+ response.set_access_token('a_token_123')
+ response.set_expiration_time(time.time() + 1800)
+ class ErroringAppIdentityStubImpl(apiproxy_stub.APIProxyStub):
- class ErroringAppIdentityStubImpl(apiproxy_stub.APIProxyStub):
-
- def __init__(self):
- super(TestAppAssertionCredentials.ErroringAppIdentityStubImpl, self).__init__(
+ def __init__(self):
+ super(TestAppAssertionCredentials.ErroringAppIdentityStubImpl, self).__init__(
'app_identity_service')
- def _Dynamic_GetAccessToken(self, request, response):
- raise app_identity.BackendDeadlineExceeded()
+ def _Dynamic_GetAccessToken(self, request, response):
+ raise app_identity.BackendDeadlineExceeded()
- def test_raise_correct_type_of_exception(self):
- app_identity_stub = self.ErroringAppIdentityStubImpl()
- apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
- apiproxy_stub_map.apiproxy.RegisterStub('app_identity_service',
+ def test_raise_correct_type_of_exception(self):
+ app_identity_stub = self.ErroringAppIdentityStubImpl()
+ apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
+ apiproxy_stub_map.apiproxy.RegisterStub('app_identity_service',
app_identity_stub)
- apiproxy_stub_map.apiproxy.RegisterStub(
+ apiproxy_stub_map.apiproxy.RegisterStub(
'memcache', memcache_stub.MemcacheServiceStub())
- scope = 'http://www.googleapis.com/scope'
- credentials = AppAssertionCredentials(scope)
- http = httplib2.Http()
- self.assertRaises(AccessTokenRefreshError, credentials.refresh, http)
+ scope = 'http://www.googleapis.com/scope'
+ credentials = AppAssertionCredentials(scope)
+ http = httplib2.Http()
+ self.assertRaises(AccessTokenRefreshError, credentials.refresh, http)
- def test_get_access_token_on_refresh(self):
- app_identity_stub = self.AppIdentityStubImpl()
- apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
- apiproxy_stub_map.apiproxy.RegisterStub("app_identity_service",
+ def test_get_access_token_on_refresh(self):
+ app_identity_stub = self.AppIdentityStubImpl()
+ apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
+ apiproxy_stub_map.apiproxy.RegisterStub("app_identity_service",
app_identity_stub)
- apiproxy_stub_map.apiproxy.RegisterStub(
+ apiproxy_stub_map.apiproxy.RegisterStub(
'memcache', memcache_stub.MemcacheServiceStub())
- scope = [
+ scope = [
"http://www.googleapis.com/scope",
"http://www.googleapis.com/scope2"]
- credentials = AppAssertionCredentials(scope)
- http = httplib2.Http()
- credentials.refresh(http)
- self.assertEqual('a_token_123', credentials.access_token)
+ credentials = AppAssertionCredentials(scope)
+ http = httplib2.Http()
+ credentials.refresh(http)
+ self.assertEqual('a_token_123', credentials.access_token)
- json = credentials.to_json()
- credentials = Credentials.new_from_json(json)
- self.assertEqual(
+ json = credentials.to_json()
+ credentials = Credentials.new_from_json(json)
+ self.assertEqual(
'http://www.googleapis.com/scope http://www.googleapis.com/scope2',
credentials.scope)
- scope = "http://www.googleapis.com/scope http://www.googleapis.com/scope2"
- credentials = AppAssertionCredentials(scope)
- http = httplib2.Http()
- credentials.refresh(http)
- self.assertEqual('a_token_123', credentials.access_token)
- self.assertEqual(
+ scope = "http://www.googleapis.com/scope http://www.googleapis.com/scope2"
+ credentials = AppAssertionCredentials(scope)
+ http = httplib2.Http()
+ credentials.refresh(http)
+ self.assertEqual('a_token_123', credentials.access_token)
+ self.assertEqual(
'http://www.googleapis.com/scope http://www.googleapis.com/scope2',
credentials.scope)
- def test_custom_service_account(self):
- scope = "http://www.googleapis.com/scope"
- account_id = "service_account_name_2@appspot.com"
+ def test_custom_service_account(self):
+ scope = "http://www.googleapis.com/scope"
+ account_id = "service_account_name_2@appspot.com"
- with mock.patch.object(app_identity, 'get_access_token',
+ with mock.patch.object(app_identity, 'get_access_token',
return_value=('a_token_456', None),
autospec=True) as get_access_token:
- credentials = AppAssertionCredentials(
+ credentials = AppAssertionCredentials(
scope, service_account_id=account_id)
- http = httplib2.Http()
- credentials.refresh(http)
+ http = httplib2.Http()
+ credentials.refresh(http)
- self.assertEqual('a_token_456', credentials.access_token)
- self.assertEqual(scope, credentials.scope)
- get_access_token.assert_called_once_with(
+ self.assertEqual('a_token_456', credentials.access_token)
+ self.assertEqual(scope, credentials.scope)
+ get_access_token.assert_called_once_with(
[scope], service_account_id=account_id)
- def test_create_scoped_required_without_scopes(self):
- credentials = AppAssertionCredentials([])
- self.assertTrue(credentials.create_scoped_required())
+ def test_create_scoped_required_without_scopes(self):
+ credentials = AppAssertionCredentials([])
+ self.assertTrue(credentials.create_scoped_required())
- def test_create_scoped_required_with_scopes(self):
- credentials = AppAssertionCredentials(['dummy_scope'])
- self.assertFalse(credentials.create_scoped_required())
+ def test_create_scoped_required_with_scopes(self):
+ credentials = AppAssertionCredentials(['dummy_scope'])
+ self.assertFalse(credentials.create_scoped_required())
- def test_create_scoped(self):
- credentials = AppAssertionCredentials([])
- new_credentials = credentials.create_scoped(['dummy_scope'])
- self.assertNotEqual(credentials, new_credentials)
- self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
- self.assertEqual('dummy_scope', new_credentials.scope)
+ def test_create_scoped(self):
+ credentials = AppAssertionCredentials([])
+ new_credentials = credentials.create_scoped(['dummy_scope'])
+ self.assertNotEqual(credentials, new_credentials)
+ self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
+ self.assertEqual('dummy_scope', new_credentials.scope)
- def test_get_access_token(self):
- app_identity_stub = self.AppIdentityStubImpl()
- apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
- apiproxy_stub_map.apiproxy.RegisterStub("app_identity_service",
+ def test_get_access_token(self):
+ app_identity_stub = self.AppIdentityStubImpl()
+ apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
+ apiproxy_stub_map.apiproxy.RegisterStub("app_identity_service",
app_identity_stub)
- apiproxy_stub_map.apiproxy.RegisterStub(
+ apiproxy_stub_map.apiproxy.RegisterStub(
'memcache', memcache_stub.MemcacheServiceStub())
- credentials = AppAssertionCredentials(['dummy_scope'])
- token = credentials.get_access_token()
- self.assertEqual('a_token_123', token.access_token)
- self.assertEqual(None, token.expires_in)
+ credentials = AppAssertionCredentials(['dummy_scope'])
+ token = credentials.get_access_token()
+ self.assertEqual('a_token_123', token.access_token)
+ self.assertEqual(None, token.expires_in)
- def test_save_to_well_known_file(self):
- os.environ[_CLOUDSDK_CONFIG_ENV_VAR] = tempfile.mkdtemp()
- credentials = AppAssertionCredentials([])
- self.assertRaises(NotImplementedError, save_to_well_known_file, credentials)
- del os.environ[_CLOUDSDK_CONFIG_ENV_VAR]
+ def test_save_to_well_known_file(self):
+ os.environ[_CLOUDSDK_CONFIG_ENV_VAR] = tempfile.mkdtemp()
+ credentials = AppAssertionCredentials([])
+ self.assertRaises(NotImplementedError, save_to_well_known_file, credentials)
+ del os.environ[_CLOUDSDK_CONFIG_ENV_VAR]
class TestFlowModel(db.Model):
- flow = FlowProperty()
+ flow = FlowProperty()
class FlowPropertyTest(unittest.TestCase):
- def setUp(self):
- self.testbed = testbed.Testbed()
- self.testbed.activate()
- self.testbed.init_datastore_v3_stub()
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
- def tearDown(self):
- self.testbed.deactivate()
+ def tearDown(self):
+ self.testbed.deactivate()
- def test_flow_get_put(self):
- instance = TestFlowModel(
+ def test_flow_get_put(self):
+ instance = TestFlowModel(
flow=flow_from_clientsecrets(datafile('client_secrets.json'), 'foo',
redirect_uri='oob'),
key_name='foo'
)
- instance.put()
- retrieved = TestFlowModel.get_by_key_name('foo')
+ instance.put()
+ retrieved = TestFlowModel.get_by_key_name('foo')
- self.assertEqual('foo_client_id', retrieved.flow.client_id)
+ self.assertEqual('foo_client_id', retrieved.flow.client_id)
class TestFlowNDBModel(ndb.Model):
- flow = FlowNDBProperty()
+ flow = FlowNDBProperty()
class FlowNDBPropertyTest(unittest.TestCase):
- def setUp(self):
- self.testbed = testbed.Testbed()
- self.testbed.activate()
- self.testbed.init_datastore_v3_stub()
- self.testbed.init_memcache_stub()
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
- def tearDown(self):
- self.testbed.deactivate()
+ def tearDown(self):
+ self.testbed.deactivate()
- def test_flow_get_put(self):
- instance = TestFlowNDBModel(
+ def test_flow_get_put(self):
+ instance = TestFlowNDBModel(
flow=flow_from_clientsecrets(datafile('client_secrets.json'), 'foo',
redirect_uri='oob'),
id='foo'
)
- instance.put()
- retrieved = TestFlowNDBModel.get_by_id('foo')
+ instance.put()
+ retrieved = TestFlowNDBModel.get_by_id('foo')
- self.assertEqual('foo_client_id', retrieved.flow.client_id)
+ self.assertEqual('foo_client_id', retrieved.flow.client_id)
def _http_request(*args, **kwargs):
- resp = httplib2.Response({'status': '200'})
- content = json.dumps({'access_token': 'bar'})
+ resp = httplib2.Response({'status': '200'})
+ content = json.dumps({'access_token': 'bar'})
- return resp, content
+ return resp, content
class StorageByKeyNameTest(unittest.TestCase):
- def setUp(self):
- self.testbed = testbed.Testbed()
- self.testbed.activate()
- self.testbed.init_datastore_v3_stub()
- self.testbed.init_memcache_stub()
- self.testbed.init_user_stub()
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_user_stub()
- access_token = 'foo'
- client_id = 'some_client_id'
- client_secret = 'cOuDdkfjxxnv+'
- refresh_token = '1/0/a.df219fjls0'
- token_expiry = datetime.datetime.utcnow()
- user_agent = 'refresh_checker/1.0'
- self.credentials = OAuth2Credentials(
+ access_token = 'foo'
+ client_id = 'some_client_id'
+ client_secret = 'cOuDdkfjxxnv+'
+ refresh_token = '1/0/a.df219fjls0'
+ token_expiry = datetime.datetime.utcnow()
+ user_agent = 'refresh_checker/1.0'
+ self.credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
user_agent)
- def tearDown(self):
- self.testbed.deactivate()
+ def tearDown(self):
+ self.testbed.deactivate()
- def test_get_and_put_simple(self):
- storage = StorageByKeyName(
+ def test_get_and_put_simple(self):
+ storage = StorageByKeyName(
CredentialsModel, 'foo', 'credentials')
- self.assertEqual(None, storage.get())
- self.credentials.set_store(storage)
+ self.assertEqual(None, storage.get())
+ self.credentials.set_store(storage)
- self.credentials._refresh(_http_request)
- credmodel = CredentialsModel.get_by_key_name('foo')
- self.assertEqual('bar', credmodel.credentials.access_token)
+ self.credentials._refresh(_http_request)
+ credmodel = CredentialsModel.get_by_key_name('foo')
+ self.assertEqual('bar', credmodel.credentials.access_token)
- def test_get_and_put_cached(self):
- storage = StorageByKeyName(
+ def test_get_and_put_cached(self):
+ storage = StorageByKeyName(
CredentialsModel, 'foo', 'credentials', cache=memcache)
- self.assertEqual(None, storage.get())
- self.credentials.set_store(storage)
+ self.assertEqual(None, storage.get())
+ self.credentials.set_store(storage)
- self.credentials._refresh(_http_request)
- credmodel = CredentialsModel.get_by_key_name('foo')
- self.assertEqual('bar', credmodel.credentials.access_token)
+ self.credentials._refresh(_http_request)
+ credmodel = CredentialsModel.get_by_key_name('foo')
+ self.assertEqual('bar', credmodel.credentials.access_token)
- # Now remove the item from the cache.
- memcache.delete('foo')
+ # Now remove the item from the cache.
+ memcache.delete('foo')
- # Check that getting refreshes the cache.
- credentials = storage.get()
- self.assertEqual('bar', credentials.access_token)
- self.assertNotEqual(None, memcache.get('foo'))
+ # Check that getting refreshes the cache.
+ credentials = storage.get()
+ self.assertEqual('bar', credentials.access_token)
+ self.assertNotEqual(None, memcache.get('foo'))
- # Deleting should clear the cache.
- storage.delete()
- credentials = storage.get()
- self.assertEqual(None, credentials)
- self.assertEqual(None, memcache.get('foo'))
+ # Deleting should clear the cache.
+ storage.delete()
+ credentials = storage.get()
+ self.assertEqual(None, credentials)
+ self.assertEqual(None, memcache.get('foo'))
- def test_get_and_put_set_store_on_cache_retrieval(self):
- storage = StorageByKeyName(
+ def test_get_and_put_set_store_on_cache_retrieval(self):
+ storage = StorageByKeyName(
CredentialsModel, 'foo', 'credentials', cache=memcache)
- self.assertEqual(None, storage.get())
- self.credentials.set_store(storage)
- storage.put(self.credentials)
- # Pre-bug 292 old_creds wouldn't have storage, and the _refresh wouldn't
- # be able to store the updated cred back into the storage.
- old_creds = storage.get()
- self.assertEqual(old_creds.access_token, 'foo')
- old_creds.invalid = True
- old_creds._refresh(_http_request)
- new_creds = storage.get()
- self.assertEqual(new_creds.access_token, 'bar')
+ self.assertEqual(None, storage.get())
+ self.credentials.set_store(storage)
+ storage.put(self.credentials)
+ # Pre-bug 292 old_creds wouldn't have storage, and the _refresh wouldn't
+ # be able to store the updated cred back into the storage.
+ old_creds = storage.get()
+ self.assertEqual(old_creds.access_token, 'foo')
+ old_creds.invalid = True
+ old_creds._refresh(_http_request)
+ new_creds = storage.get()
+ self.assertEqual(new_creds.access_token, 'bar')
- def test_get_and_put_ndb(self):
- # Start empty
- storage = StorageByKeyName(
+ def test_get_and_put_ndb(self):
+ # Start empty
+ storage = StorageByKeyName(
CredentialsNDBModel, 'foo', 'credentials')
- self.assertEqual(None, storage.get())
+ self.assertEqual(None, storage.get())
- # Refresh storage and retrieve without using storage
- self.credentials.set_store(storage)
- self.credentials._refresh(_http_request)
- credmodel = CredentialsNDBModel.get_by_id('foo')
- self.assertEqual('bar', credmodel.credentials.access_token)
- self.assertEqual(credmodel.credentials.to_json(),
+ # Refresh storage and retrieve without using storage
+ self.credentials.set_store(storage)
+ self.credentials._refresh(_http_request)
+ credmodel = CredentialsNDBModel.get_by_id('foo')
+ self.assertEqual('bar', credmodel.credentials.access_token)
+ self.assertEqual(credmodel.credentials.to_json(),
self.credentials.to_json())
- def test_delete_ndb(self):
- # Start empty
- storage = StorageByKeyName(
+ def test_delete_ndb(self):
+ # Start empty
+ storage = StorageByKeyName(
CredentialsNDBModel, 'foo', 'credentials')
- self.assertEqual(None, storage.get())
+ self.assertEqual(None, storage.get())
- # Add credentials to model with storage, and check equivalent w/o storage
- storage.put(self.credentials)
- credmodel = CredentialsNDBModel.get_by_id('foo')
- self.assertEqual(credmodel.credentials.to_json(),
+ # Add credentials to model with storage, and check equivalent w/o storage
+ storage.put(self.credentials)
+ credmodel = CredentialsNDBModel.get_by_id('foo')
+ self.assertEqual(credmodel.credentials.to_json(),
self.credentials.to_json())
- # Delete and make sure empty
- storage.delete()
- self.assertEqual(None, storage.get())
+ # Delete and make sure empty
+ storage.delete()
+ self.assertEqual(None, storage.get())
- def test_get_and_put_mixed_ndb_storage_db_get(self):
- # Start empty
- storage = StorageByKeyName(
+ def test_get_and_put_mixed_ndb_storage_db_get(self):
+ # Start empty
+ storage = StorageByKeyName(
CredentialsNDBModel, 'foo', 'credentials')
- self.assertEqual(None, storage.get())
+ self.assertEqual(None, storage.get())
- # Set NDB store and refresh to add to storage
- self.credentials.set_store(storage)
- self.credentials._refresh(_http_request)
+ # Set NDB store and refresh to add to storage
+ self.credentials.set_store(storage)
+ self.credentials._refresh(_http_request)
- # Retrieve same key from DB model to confirm mixing works
- credmodel = CredentialsModel.get_by_key_name('foo')
- self.assertEqual('bar', credmodel.credentials.access_token)
- self.assertEqual(self.credentials.to_json(),
+ # Retrieve same key from DB model to confirm mixing works
+ credmodel = CredentialsModel.get_by_key_name('foo')
+ self.assertEqual('bar', credmodel.credentials.access_token)
+ self.assertEqual(self.credentials.to_json(),
credmodel.credentials.to_json())
- def test_get_and_put_mixed_db_storage_ndb_get(self):
- # Start empty
- storage = StorageByKeyName(
+ def test_get_and_put_mixed_db_storage_ndb_get(self):
+ # Start empty
+ storage = StorageByKeyName(
CredentialsModel, 'foo', 'credentials')
- self.assertEqual(None, storage.get())
+ self.assertEqual(None, storage.get())
- # Set DB store and refresh to add to storage
- self.credentials.set_store(storage)
- self.credentials._refresh(_http_request)
+ # Set DB store and refresh to add to storage
+ self.credentials.set_store(storage)
+ self.credentials._refresh(_http_request)
- # Retrieve same key from NDB model to confirm mixing works
- credmodel = CredentialsNDBModel.get_by_id('foo')
- self.assertEqual('bar', credmodel.credentials.access_token)
- self.assertEqual(self.credentials.to_json(),
+ # Retrieve same key from NDB model to confirm mixing works
+ credmodel = CredentialsNDBModel.get_by_id('foo')
+ self.assertEqual('bar', credmodel.credentials.access_token)
+ self.assertEqual(self.credentials.to_json(),
credmodel.credentials.to_json())
- def test_delete_db_ndb_mixed(self):
- # Start empty
- storage_ndb = StorageByKeyName(
+ def test_delete_db_ndb_mixed(self):
+ # Start empty
+ storage_ndb = StorageByKeyName(
CredentialsNDBModel, 'foo', 'credentials')
- storage = StorageByKeyName(
+ storage = StorageByKeyName(
CredentialsModel, 'foo', 'credentials')
- # First DB, then NDB
- self.assertEqual(None, storage.get())
- storage.put(self.credentials)
- self.assertNotEqual(None, storage.get())
+ # First DB, then NDB
+ self.assertEqual(None, storage.get())
+ storage.put(self.credentials)
+ self.assertNotEqual(None, storage.get())
- storage_ndb.delete()
- self.assertEqual(None, storage.get())
+ storage_ndb.delete()
+ self.assertEqual(None, storage.get())
- # First NDB, then DB
- self.assertEqual(None, storage_ndb.get())
- storage_ndb.put(self.credentials)
+ # First NDB, then DB
+ self.assertEqual(None, storage_ndb.get())
+ storage_ndb.put(self.credentials)
- storage.delete()
- self.assertNotEqual(None, storage_ndb.get())
- # NDB uses memcache and an instance cache (Context)
- ndb.get_context().clear_cache()
- memcache.flush_all()
- self.assertEqual(None, storage_ndb.get())
+ storage.delete()
+ self.assertNotEqual(None, storage_ndb.get())
+ # NDB uses memcache and an instance cache (Context)
+ ndb.get_context().clear_cache()
+ memcache.flush_all()
+ self.assertEqual(None, storage_ndb.get())
class MockRequest(object):
- url = 'https://example.org'
+ url = 'https://example.org'
- def relative_url(self, rel):
- return self.url + rel
+ def relative_url(self, rel):
+ return self.url + rel
class MockRequestHandler(object):
- request = MockRequest()
+ request = MockRequest()
class DecoratorTests(unittest.TestCase):
- def setUp(self):
- self.testbed = testbed.Testbed()
- self.testbed.activate()
- self.testbed.init_datastore_v3_stub()
- self.testbed.init_memcache_stub()
- self.testbed.init_user_stub()
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_user_stub()
- decorator = OAuth2Decorator(client_id='foo_client_id',
+ decorator = OAuth2Decorator(client_id='foo_client_id',
client_secret='foo_client_secret',
scope=['foo_scope', 'bar_scope'],
user_agent='foo')
- self._finish_setup(decorator, user_mock=UserMock)
+ self._finish_setup(decorator, user_mock=UserMock)
- def _finish_setup(self, decorator, user_mock):
- self.decorator = decorator
- self.had_credentials = False
- self.found_credentials = None
- self.should_raise = False
- parent = self
+ def _finish_setup(self, decorator, user_mock):
+ self.decorator = decorator
+ self.had_credentials = False
+ self.found_credentials = None
+ self.should_raise = False
+ parent = self
- class TestRequiredHandler(webapp2.RequestHandler):
- @decorator.oauth_required
- def get(self):
- if decorator.has_credentials():
- parent.had_credentials = True
- parent.found_credentials = decorator.credentials
- if parent.should_raise:
- raise Exception('')
+ class TestRequiredHandler(webapp2.RequestHandler):
+ @decorator.oauth_required
+ def get(self):
+ if decorator.has_credentials():
+ parent.had_credentials = True
+ parent.found_credentials = decorator.credentials
+ if parent.should_raise:
+ raise Exception('')
- class TestAwareHandler(webapp2.RequestHandler):
- @decorator.oauth_aware
- def get(self, *args, **kwargs):
- self.response.out.write('Hello World!')
- assert(kwargs['year'] == '2012')
- assert(kwargs['month'] == '01')
- if decorator.has_credentials():
- parent.had_credentials = True
- parent.found_credentials = decorator.credentials
- if parent.should_raise:
- raise Exception('')
+ class TestAwareHandler(webapp2.RequestHandler):
+ @decorator.oauth_aware
+ def get(self, *args, **kwargs):
+ self.response.out.write('Hello World!')
+ assert(kwargs['year'] == '2012')
+ assert(kwargs['month'] == '01')
+ if decorator.has_credentials():
+ parent.had_credentials = True
+ parent.found_credentials = decorator.credentials
+ if parent.should_raise:
+ raise Exception('')
-
- application = webapp2.WSGIApplication([
+ application = webapp2.WSGIApplication([
('/oauth2callback', self.decorator.callback_handler()),
('/foo_path', TestRequiredHandler),
webapp2.Route(r'/bar_path//',
@@ -543,344 +538,342 @@ class DecoratorTests(unittest.TestCase):
self.httplib2_orig = httplib2.Http
httplib2.Http = Http2Mock
- def tearDown(self):
- self.testbed.deactivate()
- httplib2.Http = self.httplib2_orig
+ def tearDown(self):
+ self.testbed.deactivate()
+ httplib2.Http = self.httplib2_orig
- def test_required(self):
- # An initial request to an oauth_required decorated path should be a
- # redirect to start the OAuth dance.
- self.assertEqual(self.decorator.flow, None)
- self.assertEqual(self.decorator.credentials, None)
- response = self.app.get('http://localhost/foo_path')
- self.assertTrue(response.status.startswith('302'))
- q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
- self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
- self.assertEqual('foo_client_id', q['client_id'][0])
- self.assertEqual('foo_scope bar_scope', q['scope'][0])
- self.assertEqual('http://localhost/foo_path',
+ def test_required(self):
+ # An initial request to an oauth_required decorated path should be a
+ # redirect to start the OAuth dance.
+ self.assertEqual(self.decorator.flow, None)
+ self.assertEqual(self.decorator.credentials, None)
+ response = self.app.get('http://localhost/foo_path')
+ self.assertTrue(response.status.startswith('302'))
+ q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
+ self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
+ self.assertEqual('foo_client_id', q['client_id'][0])
+ self.assertEqual('foo_scope bar_scope', q['scope'][0])
+ self.assertEqual('http://localhost/foo_path',
q['state'][0].rsplit(':', 1)[0])
- self.assertEqual('code', q['response_type'][0])
- self.assertEqual(False, self.decorator.has_credentials())
+ self.assertEqual('code', q['response_type'][0])
+ self.assertEqual(False, self.decorator.has_credentials())
- with mock.patch.object(appengine, '_parse_state_value',
+ with mock.patch.object(appengine, '_parse_state_value',
return_value='foo_path',
autospec=True) as parse_state_value:
- # Now simulate the callback to /oauth2callback.
- response = self.app.get('/oauth2callback', {
+ # Now simulate the callback to /oauth2callback.
+ response = self.app.get('/oauth2callback', {
'code': 'foo_access_code',
'state': 'foo_path:xsrfkey123',
})
- parts = response.headers['Location'].split('?', 1)
- self.assertEqual('http://localhost/foo_path', parts[0])
- self.assertEqual(None, self.decorator.credentials)
- if self.decorator._token_response_param:
- response_query = urllib.parse.parse_qs(parts[1])
- response = response_query[self.decorator._token_response_param][0]
- self.assertEqual(Http2Mock.content,
+ parts = response.headers['Location'].split('?', 1)
+ self.assertEqual('http://localhost/foo_path', parts[0])
+ self.assertEqual(None, self.decorator.credentials)
+ if self.decorator._token_response_param:
+ response_query = urllib.parse.parse_qs(parts[1])
+ response = response_query[self.decorator._token_response_param][0]
+ self.assertEqual(Http2Mock.content,
json.loads(urllib.parse.unquote(response)))
- self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
- self.assertEqual(self.decorator.credentials,
+ self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
+ self.assertEqual(self.decorator.credentials,
self.decorator._tls.credentials)
- parse_state_value.assert_called_once_with(
+ parse_state_value.assert_called_once_with(
'foo_path:xsrfkey123', self.current_user)
- # Now requesting the decorated path should work.
- response = self.app.get('/foo_path')
- self.assertEqual('200 OK', response.status)
- self.assertEqual(True, self.had_credentials)
- self.assertEqual('foo_refresh_token',
+ # Now requesting the decorated path should work.
+ response = self.app.get('/foo_path')
+ self.assertEqual('200 OK', response.status)
+ self.assertEqual(True, self.had_credentials)
+ self.assertEqual('foo_refresh_token',
self.found_credentials.refresh_token)
- self.assertEqual('foo_access_token',
+ self.assertEqual('foo_access_token',
self.found_credentials.access_token)
- self.assertEqual(None, self.decorator.credentials)
+ self.assertEqual(None, self.decorator.credentials)
- # Raising an exception still clears the Credentials.
- self.should_raise = True
- self.assertRaises(Exception, self.app.get, '/foo_path')
- self.should_raise = False
- self.assertEqual(None, self.decorator.credentials)
+ # Raising an exception still clears the Credentials.
+ self.should_raise = True
+ self.assertRaises(Exception, self.app.get, '/foo_path')
+ self.should_raise = False
+ self.assertEqual(None, self.decorator.credentials)
- # Invalidate the stored Credentials.
- self.found_credentials.invalid = True
- self.found_credentials.store.put(self.found_credentials)
+ # Invalidate the stored Credentials.
+ self.found_credentials.invalid = True
+ self.found_credentials.store.put(self.found_credentials)
- # Invalid Credentials should start the OAuth dance again.
- response = self.app.get('/foo_path')
- self.assertTrue(response.status.startswith('302'))
- q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
- self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
+ # Invalid Credentials should start the OAuth dance again.
+ response = self.app.get('/foo_path')
+ self.assertTrue(response.status.startswith('302'))
+ q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
+ self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
- def test_storage_delete(self):
- # An initial request to an oauth_required decorated path should be a
- # redirect to start the OAuth dance.
- response = self.app.get('/foo_path')
- self.assertTrue(response.status.startswith('302'))
+ def test_storage_delete(self):
+ # An initial request to an oauth_required decorated path should be a
+ # redirect to start the OAuth dance.
+ response = self.app.get('/foo_path')
+ self.assertTrue(response.status.startswith('302'))
- with mock.patch.object(appengine, '_parse_state_value',
+ with mock.patch.object(appengine, '_parse_state_value',
return_value='foo_path',
autospec=True) as parse_state_value:
- # Now simulate the callback to /oauth2callback.
- response = self.app.get('/oauth2callback', {
+ # Now simulate the callback to /oauth2callback.
+ response = self.app.get('/oauth2callback', {
'code': 'foo_access_code',
'state': 'foo_path:xsrfkey123',
- })
- self.assertEqual('http://localhost/foo_path', response.headers['Location'])
- self.assertEqual(None, self.decorator.credentials)
+ })
+ self.assertEqual('http://localhost/foo_path', response.headers['Location'])
+ self.assertEqual(None, self.decorator.credentials)
- # Now requesting the decorated path should work.
- response = self.app.get('/foo_path')
+ # Now requesting the decorated path should work.
+ response = self.app.get('/foo_path')
- self.assertTrue(self.had_credentials)
+ self.assertTrue(self.had_credentials)
- # Credentials should be cleared after each call.
- self.assertEqual(None, self.decorator.credentials)
+ # Credentials should be cleared after each call.
+ self.assertEqual(None, self.decorator.credentials)
- # Invalidate the stored Credentials.
- self.found_credentials.store.delete()
+ # Invalidate the stored Credentials.
+ self.found_credentials.store.delete()
- # Invalid Credentials should start the OAuth dance again.
- response = self.app.get('/foo_path')
- self.assertTrue(response.status.startswith('302'))
+ # Invalid Credentials should start the OAuth dance again.
+ response = self.app.get('/foo_path')
+ self.assertTrue(response.status.startswith('302'))
- parse_state_value.assert_called_once_with(
+ parse_state_value.assert_called_once_with(
'foo_path:xsrfkey123', self.current_user)
- def test_aware(self):
- # An initial request to an oauth_aware decorated path should not redirect.
- response = self.app.get('http://localhost/bar_path/2012/01')
- self.assertEqual('Hello World!', response.body)
- self.assertEqual('200 OK', response.status)
- self.assertEqual(False, self.decorator.has_credentials())
- url = self.decorator.authorize_url()
- q = urllib.parse.parse_qs(url.split('?', 1)[1])
- self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
- self.assertEqual('foo_client_id', q['client_id'][0])
- self.assertEqual('foo_scope bar_scope', q['scope'][0])
- self.assertEqual('http://localhost/bar_path/2012/01',
+ def test_aware(self):
+ # An initial request to an oauth_aware decorated path should not redirect.
+ response = self.app.get('http://localhost/bar_path/2012/01')
+ self.assertEqual('Hello World!', response.body)
+ self.assertEqual('200 OK', response.status)
+ self.assertEqual(False, self.decorator.has_credentials())
+ url = self.decorator.authorize_url()
+ q = urllib.parse.parse_qs(url.split('?', 1)[1])
+ self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
+ self.assertEqual('foo_client_id', q['client_id'][0])
+ self.assertEqual('foo_scope bar_scope', q['scope'][0])
+ self.assertEqual('http://localhost/bar_path/2012/01',
q['state'][0].rsplit(':', 1)[0])
- self.assertEqual('code', q['response_type'][0])
+ self.assertEqual('code', q['response_type'][0])
- with mock.patch.object(appengine, '_parse_state_value',
+ with mock.patch.object(appengine, '_parse_state_value',
return_value='bar_path',
autospec=True) as parse_state_value:
- # Now simulate the callback to /oauth2callback.
- url = self.decorator.authorize_url()
- response = self.app.get('/oauth2callback', {
+ # Now simulate the callback to /oauth2callback.
+ url = self.decorator.authorize_url()
+ response = self.app.get('/oauth2callback', {
'code': 'foo_access_code',
'state': 'bar_path:xsrfkey456',
})
- self.assertEqual('http://localhost/bar_path', response.headers['Location'])
- self.assertEqual(False, self.decorator.has_credentials())
- parse_state_value.assert_called_once_with(
+ self.assertEqual('http://localhost/bar_path', response.headers['Location'])
+ self.assertEqual(False, self.decorator.has_credentials())
+ parse_state_value.assert_called_once_with(
'bar_path:xsrfkey456', self.current_user)
- # Now requesting the decorated path will have credentials.
- response = self.app.get('/bar_path/2012/01')
- self.assertEqual('200 OK', response.status)
- self.assertEqual('Hello World!', response.body)
- self.assertEqual(True, self.had_credentials)
- self.assertEqual('foo_refresh_token',
+ # Now requesting the decorated path will have credentials.
+ response = self.app.get('/bar_path/2012/01')
+ self.assertEqual('200 OK', response.status)
+ self.assertEqual('Hello World!', response.body)
+ self.assertEqual(True, self.had_credentials)
+ self.assertEqual('foo_refresh_token',
self.found_credentials.refresh_token)
- self.assertEqual('foo_access_token',
+ self.assertEqual('foo_access_token',
self.found_credentials.access_token)
- # Credentials should be cleared after each call.
- self.assertEqual(None, self.decorator.credentials)
+ # Credentials should be cleared after each call.
+ self.assertEqual(None, self.decorator.credentials)
- # Raising an exception still clears the Credentials.
- self.should_raise = True
- self.assertRaises(Exception, self.app.get, '/bar_path/2012/01')
- self.should_raise = False
- self.assertEqual(None, self.decorator.credentials)
+ # Raising an exception still clears the Credentials.
+ self.should_raise = True
+ self.assertRaises(Exception, self.app.get, '/bar_path/2012/01')
+ self.should_raise = False
+ self.assertEqual(None, self.decorator.credentials)
-
- def test_error_in_step2(self):
- # An initial request to an oauth_aware decorated path should not redirect.
- response = self.app.get('/bar_path/2012/01')
- url = self.decorator.authorize_url()
- response = self.app.get('/oauth2callback', {
+ def test_error_in_step2(self):
+ # An initial request to an oauth_aware decorated path should not redirect.
+ response = self.app.get('/bar_path/2012/01')
+ url = self.decorator.authorize_url()
+ response = self.app.get('/oauth2callback', {
'error': 'BadHappened\''
})
- self.assertEqual('200 OK', response.status)
- self.assertTrue('Bad<Stuff>Happened'' in response.body)
+ self.assertEqual('200 OK', response.status)
+ self.assertTrue('Bad<Stuff>Happened'' in response.body)
- def test_kwargs_are_passed_to_underlying_flow(self):
- decorator = OAuth2Decorator(client_id='foo_client_id',
+ def test_kwargs_are_passed_to_underlying_flow(self):
+ decorator = OAuth2Decorator(client_id='foo_client_id',
client_secret='foo_client_secret',
user_agent='foo_user_agent',
scope=['foo_scope', 'bar_scope'],
access_type='offline',
approval_prompt='force',
revoke_uri='dummy_revoke_uri')
- request_handler = MockRequestHandler()
- decorator._create_flow(request_handler)
+ request_handler = MockRequestHandler()
+ decorator._create_flow(request_handler)
- self.assertEqual('https://example.org/oauth2callback',
+ self.assertEqual('https://example.org/oauth2callback',
decorator.flow.redirect_uri)
- self.assertEqual('offline', decorator.flow.params['access_type'])
- self.assertEqual('force', decorator.flow.params['approval_prompt'])
- self.assertEqual('foo_user_agent', decorator.flow.user_agent)
- self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
- self.assertEqual(None, decorator.flow.params.get('user_agent', None))
- self.assertEqual(decorator.flow, decorator._tls.flow)
+ self.assertEqual('offline', decorator.flow.params['access_type'])
+ self.assertEqual('force', decorator.flow.params['approval_prompt'])
+ self.assertEqual('foo_user_agent', decorator.flow.user_agent)
+ self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
+ self.assertEqual(None, decorator.flow.params.get('user_agent', None))
+ self.assertEqual(decorator.flow, decorator._tls.flow)
- def test_token_response_param(self):
- self.decorator._token_response_param = 'foobar'
- self.test_required()
+ def test_token_response_param(self):
+ self.decorator._token_response_param = 'foobar'
+ self.test_required()
- def test_decorator_from_client_secrets(self):
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_client_secrets(self):
+ decorator = OAuth2DecoratorFromClientSecrets(
datafile('client_secrets.json'),
scope=['foo_scope', 'bar_scope'])
- self._finish_setup(decorator, user_mock=UserMock)
+ self._finish_setup(decorator, user_mock=UserMock)
- self.assertFalse(decorator._in_error)
- self.decorator = decorator
- self.test_required()
- http = self.decorator.http()
- self.assertEquals('foo_access_token', http.request.credentials.access_token)
+ self.assertFalse(decorator._in_error)
+ self.decorator = decorator
+ self.test_required()
+ http = self.decorator.http()
+ self.assertEquals('foo_access_token', http.request.credentials.access_token)
- # revoke_uri is not required
- self.assertEqual(self.decorator._revoke_uri,
+ # revoke_uri is not required
+ self.assertEqual(self.decorator._revoke_uri,
'https://accounts.google.com/o/oauth2/revoke')
- self.assertEqual(self.decorator._revoke_uri,
+ self.assertEqual(self.decorator._revoke_uri,
self.decorator.credentials.revoke_uri)
- def test_decorator_from_client_secrets_kwargs(self):
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_client_secrets_kwargs(self):
+ decorator = OAuth2DecoratorFromClientSecrets(
datafile('client_secrets.json'),
scope=['foo_scope', 'bar_scope'],
approval_prompt='force')
- self.assertTrue('approval_prompt' in decorator._kwargs)
+ self.assertTrue('approval_prompt' in decorator._kwargs)
-
- def test_decorator_from_cached_client_secrets(self):
- cache_mock = CacheMock()
- load_and_cache('client_secrets.json', 'secret', cache_mock)
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_cached_client_secrets(self):
+ cache_mock = CacheMock()
+ load_and_cache('client_secrets.json', 'secret', cache_mock)
+ decorator = OAuth2DecoratorFromClientSecrets(
# filename, scope, message=None, cache=None
'secret', '', cache=cache_mock)
- self.assertFalse(decorator._in_error)
+ self.assertFalse(decorator._in_error)
- def test_decorator_from_client_secrets_not_logged_in_required(self):
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_client_secrets_not_logged_in_required(self):
+ decorator = OAuth2DecoratorFromClientSecrets(
datafile('client_secrets.json'),
scope=['foo_scope', 'bar_scope'], message='NotLoggedInMessage')
- self.decorator = decorator
- self._finish_setup(decorator, user_mock=UserNotLoggedInMock)
+ self.decorator = decorator
+ self._finish_setup(decorator, user_mock=UserNotLoggedInMock)
- self.assertFalse(decorator._in_error)
+ self.assertFalse(decorator._in_error)
- # An initial request to an oauth_required decorated path should be a
- # redirect to login.
- response = self.app.get('/foo_path')
- self.assertTrue(response.status.startswith('302'))
- self.assertTrue('Login' in str(response))
+ # An initial request to an oauth_required decorated path should be a
+ # redirect to login.
+ response = self.app.get('/foo_path')
+ self.assertTrue(response.status.startswith('302'))
+ self.assertTrue('Login' in str(response))
- def test_decorator_from_client_secrets_not_logged_in_aware(self):
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_client_secrets_not_logged_in_aware(self):
+ decorator = OAuth2DecoratorFromClientSecrets(
datafile('client_secrets.json'),
scope=['foo_scope', 'bar_scope'], message='NotLoggedInMessage')
- self.decorator = decorator
- self._finish_setup(decorator, user_mock=UserNotLoggedInMock)
+ self.decorator = decorator
+ self._finish_setup(decorator, user_mock=UserNotLoggedInMock)
- # An initial request to an oauth_aware decorated path should be a
- # redirect to login.
- response = self.app.get('/bar_path/2012/03')
- self.assertTrue(response.status.startswith('302'))
- self.assertTrue('Login' in str(response))
+ # An initial request to an oauth_aware decorated path should be a
+ # redirect to login.
+ response = self.app.get('/bar_path/2012/03')
+ self.assertTrue(response.status.startswith('302'))
+ self.assertTrue('Login' in str(response))
- def test_decorator_from_unfilled_client_secrets_required(self):
- MESSAGE = 'File is missing'
- try:
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_unfilled_client_secrets_required(self):
+ MESSAGE = 'File is missing'
+ try:
+ decorator = OAuth2DecoratorFromClientSecrets(
datafile('unfilled_client_secrets.json'),
scope=['foo_scope', 'bar_scope'], message=MESSAGE)
- except InvalidClientSecretsError:
- pass
+ except InvalidClientSecretsError:
+ pass
- def test_decorator_from_unfilled_client_secrets_aware(self):
- MESSAGE = 'File is missing'
- try:
- decorator = OAuth2DecoratorFromClientSecrets(
+ def test_decorator_from_unfilled_client_secrets_aware(self):
+ MESSAGE = 'File is missing'
+ try:
+ decorator = OAuth2DecoratorFromClientSecrets(
datafile('unfilled_client_secrets.json'),
scope=['foo_scope', 'bar_scope'], message=MESSAGE)
- except InvalidClientSecretsError:
- pass
+ except InvalidClientSecretsError:
+ pass
class DecoratorXsrfSecretTests(unittest.TestCase):
- """Test xsrf_secret_key."""
+ """Test xsrf_secret_key."""
- def setUp(self):
- self.testbed = testbed.Testbed()
- self.testbed.activate()
- self.testbed.init_datastore_v3_stub()
- self.testbed.init_memcache_stub()
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
- def tearDown(self):
- self.testbed.deactivate()
+ def tearDown(self):
+ self.testbed.deactivate()
- def test_build_and_parse_state(self):
- secret = appengine.xsrf_secret_key()
+ def test_build_and_parse_state(self):
+ secret = appengine.xsrf_secret_key()
- # Secret shouldn't change from call to call.
- secret2 = appengine.xsrf_secret_key()
- self.assertEqual(secret, secret2)
+ # Secret shouldn't change from call to call.
+ secret2 = appengine.xsrf_secret_key()
+ self.assertEqual(secret, secret2)
- # Secret shouldn't change if memcache goes away.
- memcache.delete(appengine.XSRF_MEMCACHE_ID,
+ # Secret shouldn't change if memcache goes away.
+ memcache.delete(appengine.XSRF_MEMCACHE_ID,
namespace=appengine.OAUTH2CLIENT_NAMESPACE)
- secret3 = appengine.xsrf_secret_key()
- self.assertEqual(secret2, secret3)
+ secret3 = appengine.xsrf_secret_key()
+ self.assertEqual(secret2, secret3)
- # Secret should change if both memcache and the model goes away.
- memcache.delete(appengine.XSRF_MEMCACHE_ID,
+ # Secret should change if both memcache and the model goes away.
+ memcache.delete(appengine.XSRF_MEMCACHE_ID,
namespace=appengine.OAUTH2CLIENT_NAMESPACE)
- model = appengine.SiteXsrfSecretKey.get_or_insert('site')
- model.delete()
+ model = appengine.SiteXsrfSecretKey.get_or_insert('site')
+ model.delete()
- secret4 = appengine.xsrf_secret_key()
- self.assertNotEqual(secret3, secret4)
+ secret4 = appengine.xsrf_secret_key()
+ self.assertNotEqual(secret3, secret4)
- def test_ndb_insert_db_get(self):
- secret = appengine._generate_new_xsrf_secret_key()
- appengine.SiteXsrfSecretKeyNDB(id='site', secret=secret).put()
+ def test_ndb_insert_db_get(self):
+ secret = appengine._generate_new_xsrf_secret_key()
+ appengine.SiteXsrfSecretKeyNDB(id='site', secret=secret).put()
- site_key = appengine.SiteXsrfSecretKey.get_by_key_name('site')
- self.assertEqual(site_key.secret, secret)
+ site_key = appengine.SiteXsrfSecretKey.get_by_key_name('site')
+ self.assertEqual(site_key.secret, secret)
- def test_db_insert_ndb_get(self):
- secret = appengine._generate_new_xsrf_secret_key()
- appengine.SiteXsrfSecretKey(key_name='site', secret=secret).put()
+ def test_db_insert_ndb_get(self):
+ secret = appengine._generate_new_xsrf_secret_key()
+ appengine.SiteXsrfSecretKey(key_name='site', secret=secret).put()
- site_key = appengine.SiteXsrfSecretKeyNDB.get_by_id('site')
- self.assertEqual(site_key.secret, secret)
+ site_key = appengine.SiteXsrfSecretKeyNDB.get_by_id('site')
+ self.assertEqual(site_key.secret, secret)
class DecoratorXsrfProtectionTests(unittest.TestCase):
- """Test _build_state_value and _parse_state_value."""
+ """Test _build_state_value and _parse_state_value."""
- def setUp(self):
- self.testbed = testbed.Testbed()
- self.testbed.activate()
- self.testbed.init_datastore_v3_stub()
- self.testbed.init_memcache_stub()
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
- def tearDown(self):
- self.testbed.deactivate()
+ def tearDown(self):
+ self.testbed.deactivate()
- def test_build_and_parse_state(self):
- state = appengine._build_state_value(MockRequestHandler(), UserMock())
- self.assertEqual(
+ def test_build_and_parse_state(self):
+ state = appengine._build_state_value(MockRequestHandler(), UserMock())
+ self.assertEqual(
'https://example.org',
appengine._parse_state_value(state, UserMock()))
- self.assertRaises(appengine.InvalidXsrfTokenError,
+ self.assertRaises(appengine.InvalidXsrfTokenError,
appengine._parse_state_value, state[1:], UserMock())
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_clientsecrets.py b/tests/test_clientsecrets.py
index 6212b37..0fc5794 100644
--- a/tests/test_clientsecrets.py
+++ b/tests/test_clientsecrets.py
@@ -16,7 +16,6 @@
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
-
import os
import unittest
from io import StringIO
@@ -25,22 +24,22 @@ import httplib2
from oauth2client import clientsecrets
-
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
VALID_FILE = os.path.join(DATA_DIR, 'client_secrets.json')
INVALID_FILE = os.path.join(DATA_DIR, 'unfilled_client_secrets.json')
NONEXISTENT_FILE = os.path.join(__file__, '..', 'afilethatisntthere.json')
+
class OAuth2CredentialsTests(unittest.TestCase):
- def setUp(self):
- pass
+ def setUp(self):
+ pass
- def tearDown(self):
- pass
+ def tearDown(self):
+ pass
- def test_validate_error(self):
- ERRORS = [
+ def test_validate_error(self):
+ ERRORS = [
('{}', 'Invalid'),
('{"foo": {}}', 'Unknown'),
('{"web": {}}', 'Missing'),
@@ -56,95 +55,95 @@ class OAuth2CredentialsTests(unittest.TestCase):
}
""", 'Property'),
]
- for src, match in ERRORS:
- # Ensure that it is unicode
- try:
- src = src.decode('utf-8')
- except AttributeError:
- pass
- # Test load(s)
- try:
- clientsecrets.loads(src)
- self.fail(src + ' should not be a valid client_secrets file.')
- except clientsecrets.InvalidClientSecretsError as e:
- self.assertTrue(str(e).startswith(match))
+ for src, match in ERRORS:
+ # Ensure that it is unicode
+ try:
+ src = src.decode('utf-8')
+ except AttributeError:
+ pass
+ # Test load(s)
+ try:
+ clientsecrets.loads(src)
+ self.fail(src + ' should not be a valid client_secrets file.')
+ except clientsecrets.InvalidClientSecretsError as e:
+ self.assertTrue(str(e).startswith(match))
- # Test loads(fp)
- try:
- fp = StringIO(src)
- clientsecrets.load(fp)
- self.fail(src + ' should not be a valid client_secrets file.')
- except clientsecrets.InvalidClientSecretsError as e:
- self.assertTrue(str(e).startswith(match))
+ # Test loads(fp)
+ try:
+ fp = StringIO(src)
+ clientsecrets.load(fp)
+ self.fail(src + ' should not be a valid client_secrets file.')
+ except clientsecrets.InvalidClientSecretsError as e:
+ self.assertTrue(str(e).startswith(match))
- def test_load_by_filename(self):
- try:
- clientsecrets._loadfile(NONEXISTENT_FILE)
- self.fail('should fail to load a missing client_secrets file.')
- except clientsecrets.InvalidClientSecretsError as e:
- self.assertTrue(str(e).startswith('File'))
+ def test_load_by_filename(self):
+ try:
+ clientsecrets._loadfile(NONEXISTENT_FILE)
+ self.fail('should fail to load a missing client_secrets file.')
+ except clientsecrets.InvalidClientSecretsError as e:
+ self.assertTrue(str(e).startswith('File'))
class CachedClientsecretsTests(unittest.TestCase):
- class CacheMock(object):
- def __init__(self):
- self.cache = {}
- self.last_get_ns = None
- self.last_set_ns = None
+ class CacheMock(object):
+ def __init__(self):
+ self.cache = {}
+ self.last_get_ns = None
+ self.last_set_ns = None
- def get(self, key, namespace=''):
- # ignoring namespace for easier testing
- self.last_get_ns = namespace
- return self.cache.get(key, None)
+ def get(self, key, namespace=''):
+ # ignoring namespace for easier testing
+ self.last_get_ns = namespace
+ return self.cache.get(key, None)
- def set(self, key, value, namespace=''):
- # ignoring namespace for easier testing
- self.last_set_ns = namespace
- self.cache[key] = value
+ def set(self, key, value, namespace=''):
+ # ignoring namespace for easier testing
+ self.last_set_ns = namespace
+ self.cache[key] = value
- def setUp(self):
- self.cache_mock = self.CacheMock()
+ def setUp(self):
+ self.cache_mock = self.CacheMock()
- def test_cache_miss(self):
- client_type, client_info = clientsecrets.loadfile(
+ def test_cache_miss(self):
+ client_type, client_info = clientsecrets.loadfile(
VALID_FILE, cache=self.cache_mock)
- self.assertEqual('web', client_type)
- self.assertEqual('foo_client_secret', client_info['client_secret'])
+ self.assertEqual('web', client_type)
+ self.assertEqual('foo_client_secret', client_info['client_secret'])
- cached = self.cache_mock.cache[VALID_FILE]
- self.assertEqual({client_type: client_info}, cached)
+ cached = self.cache_mock.cache[VALID_FILE]
+ self.assertEqual({client_type: client_info}, cached)
- # make sure we're using non-empty namespace
- ns = self.cache_mock.last_set_ns
- self.assertTrue(bool(ns))
- # make sure they're equal
- self.assertEqual(ns, self.cache_mock.last_get_ns)
+ # make sure we're using non-empty namespace
+ ns = self.cache_mock.last_set_ns
+ self.assertTrue(bool(ns))
+ # make sure they're equal
+ self.assertEqual(ns, self.cache_mock.last_get_ns)
- def test_cache_hit(self):
- self.cache_mock.cache[NONEXISTENT_FILE] = { 'web': 'secret info' }
+ def test_cache_hit(self):
+ self.cache_mock.cache[NONEXISTENT_FILE] = {'web': 'secret info'}
- client_type, client_info = clientsecrets.loadfile(
+ client_type, client_info = clientsecrets.loadfile(
NONEXISTENT_FILE, cache=self.cache_mock)
- self.assertEqual('web', client_type)
- self.assertEqual('secret info', client_info)
- # make sure we didn't do any set() RPCs
- self.assertEqual(None, self.cache_mock.last_set_ns)
+ self.assertEqual('web', client_type)
+ self.assertEqual('secret info', client_info)
+ # make sure we didn't do any set() RPCs
+ self.assertEqual(None, self.cache_mock.last_set_ns)
- def test_validation(self):
- try:
- clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
- self.fail('Expected InvalidClientSecretsError to be raised '
+ def test_validation(self):
+ try:
+ clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
+ self.fail('Expected InvalidClientSecretsError to be raised '
'while loading %s' % INVALID_FILE)
- except clientsecrets.InvalidClientSecretsError:
- pass
+ except clientsecrets.InvalidClientSecretsError:
+ pass
- def test_without_cache(self):
- # this also ensures loadfile() is backward compatible
- client_type, client_info = clientsecrets.loadfile(VALID_FILE)
- self.assertEqual('web', client_type)
- self.assertEqual('foo_client_secret', client_info['client_secret'])
+ def test_without_cache(self):
+ # this also ensures loadfile() is backward compatible
+ client_type, client_info = clientsecrets.loadfile(VALID_FILE)
+ self.assertEqual('web', client_type)
+ self.assertEqual('foo_client_secret', client_info['client_secret'])
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_crypt.py b/tests/test_crypt.py
index 53b91c2..bd2c105 100644
--- a/tests/test_crypt.py
+++ b/tests/test_crypt.py
@@ -18,10 +18,10 @@ import sys
import unittest
try:
- reload
+ reload
except NameError:
- # For Python3 (though importlib should be used, silly 3.3).
- from imp import reload
+ # For Python3 (though importlib should be used, silly 3.3).
+ from imp import reload
from oauth2client import _helpers
from oauth2client.client import HAS_OPENSSL
@@ -30,44 +30,44 @@ from oauth2client import crypt
def datafile(filename):
- f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
- data = f.read()
- f.close()
- return data
+ f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
+ data = f.read()
+ f.close()
+ return data
class Test_pkcs12_key_as_pem(unittest.TestCase):
- def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
+ def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
private_key=None):
- private_key = private_key or datafile(private_key_file)
- return SignedJwtAssertionCredentials(
+ private_key = private_key or datafile(private_key_file)
+ return SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
- def _succeeds_helper(self, password=None):
- self.assertEqual(True, HAS_OPENSSL)
+ def _succeeds_helper(self, password=None):
+ self.assertEqual(True, HAS_OPENSSL)
- credentials = self._make_signed_jwt_creds()
- if password is None:
- password = credentials.private_key_password
- pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
- pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
- pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
- alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
- self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
+ credentials = self._make_signed_jwt_creds()
+ if password is None:
+ password = credentials.private_key_password
+ pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
+ pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
+ pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
+ alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
+ self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
- def test_succeeds(self):
- self._succeeds_helper()
+ def test_succeeds(self):
+ self._succeeds_helper()
- def test_succeeds_with_unicode_password(self):
- password = u'notasecret'
- self._succeeds_helper(password)
+ def test_succeeds_with_unicode_password(self):
+ password = u'notasecret'
+ self._succeeds_helper(password)
- def test_with_nonsense_key(self):
- from OpenSSL import crypto
- credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
- self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
+ def test_with_nonsense_key(self):
+ from OpenSSL import crypto
+ credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
+ self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
credentials.private_key, credentials.private_key_password)
diff --git a/tests/test_devshell.py b/tests/test_devshell.py
index b4623dd..2f052f9 100644
--- a/tests/test_devshell.py
+++ b/tests/test_devshell.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Tests for oauth2client.devshell."""
import os
@@ -30,110 +29,110 @@ from oauth2client.devshell import NoDevshellServer
class _AuthReferenceServer(threading.Thread):
- def __init__(self, response=None):
- super(_AuthReferenceServer, self).__init__(None)
- self.response = (response or
+ def __init__(self, response=None):
+ super(_AuthReferenceServer, self).__init__(None)
+ self.response = (response or
'["joe@example.com", "fooproj", "sometoken"]')
- def __enter__(self):
- self.start_server()
+ def __enter__(self):
+ self.start_server()
- def start_server(self):
- self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self._socket.bind(('localhost', 0))
- port = self._socket.getsockname()[1]
- os.environ[DEVSHELL_ENV] = str(port)
- self._socket.listen(0)
- self.daemon = True
- self.start()
- return self
+ def start_server(self):
+ self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._socket.bind(('localhost', 0))
+ port = self._socket.getsockname()[1]
+ os.environ[DEVSHELL_ENV] = str(port)
+ self._socket.listen(0)
+ self.daemon = True
+ self.start()
+ return self
- def __exit__(self, e_type, value, traceback):
- self.stop_server()
+ def __exit__(self, e_type, value, traceback):
+ self.stop_server()
- def stop_server(self):
- del os.environ[DEVSHELL_ENV]
- self._socket.close()
+ def stop_server(self):
+ del os.environ[DEVSHELL_ENV]
+ self._socket.close()
- def run(self):
- s = None
- try:
- # Do not set the timeout on the socket, leave it in the blocking mode as
- # setting the timeout seems to cause spurious EAGAIN errors on OSX.
- self._socket.settimeout(None)
+ def run(self):
+ s = None
+ try:
+ # Do not set the timeout on the socket, leave it in the blocking mode as
+ # setting the timeout seems to cause spurious EAGAIN errors on OSX.
+ self._socket.settimeout(None)
- s, unused_addr = self._socket.accept()
- resp_buffer = ''
- resp_1 = s.recv(6).decode()
- if '\n' not in resp_1:
- raise Exception('invalid request data')
- nstr, extra = resp_1.split('\n', 1)
- resp_buffer = extra
- n = int(nstr)
- to_read = n-len(extra)
- if to_read > 0:
- resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
- if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
- raise Exception('bad request')
- l = len(self.response)
- s.sendall(('%d\n%s' % (l, self.response)).encode())
- finally:
- if s:
- s.close()
+ s, unused_addr = self._socket.accept()
+ resp_buffer = ''
+ resp_1 = s.recv(6).decode()
+ if '\n' not in resp_1:
+ raise Exception('invalid request data')
+ nstr, extra = resp_1.split('\n', 1)
+ resp_buffer = extra
+ n = int(nstr)
+ to_read = n - len(extra)
+ if to_read > 0:
+ resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
+ if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
+ raise Exception('bad request')
+ l = len(self.response)
+ s.sendall(('%d\n%s' % (l, self.response)).encode())
+ finally:
+ if s:
+ s.close()
class DevshellCredentialsTests(unittest.TestCase):
- def test_signals_no_server(self):
- self.assertRaises(NoDevshellServer, DevshellCredentials)
+ def test_signals_no_server(self):
+ self.assertRaises(NoDevshellServer, DevshellCredentials)
- def test_request_response(self):
- with _AuthReferenceServer():
- response = _SendRecv()
- self.assertEqual(response.user_email, 'joe@example.com')
- self.assertEqual(response.project_id, 'fooproj')
- self.assertEqual(response.access_token, 'sometoken')
+ def test_request_response(self):
+ with _AuthReferenceServer():
+ response = _SendRecv()
+ self.assertEqual(response.user_email, 'joe@example.com')
+ self.assertEqual(response.project_id, 'fooproj')
+ self.assertEqual(response.access_token, 'sometoken')
- def test_no_refresh_token(self):
- with _AuthReferenceServer():
- creds = DevshellCredentials()
- self.assertEquals(None, creds.refresh_token)
+ def test_no_refresh_token(self):
+ with _AuthReferenceServer():
+ creds = DevshellCredentials()
+ self.assertEquals(None, creds.refresh_token)
- def test_reads_credentials(self):
- with _AuthReferenceServer():
- creds = DevshellCredentials()
- self.assertEqual('joe@example.com', creds.user_email)
- self.assertEqual('fooproj', creds.project_id)
- self.assertEqual('sometoken', creds.access_token)
+ def test_reads_credentials(self):
+ with _AuthReferenceServer():
+ creds = DevshellCredentials()
+ self.assertEqual('joe@example.com', creds.user_email)
+ self.assertEqual('fooproj', creds.project_id)
+ self.assertEqual('sometoken', creds.access_token)
- def test_handles_skipped_fields(self):
- with _AuthReferenceServer('["joe@example.com"]'):
- creds = DevshellCredentials()
- self.assertEqual('joe@example.com', creds.user_email)
- self.assertEqual(None, creds.project_id)
- self.assertEqual(None, creds.access_token)
+ def test_handles_skipped_fields(self):
+ with _AuthReferenceServer('["joe@example.com"]'):
+ creds = DevshellCredentials()
+ self.assertEqual('joe@example.com', creds.user_email)
+ self.assertEqual(None, creds.project_id)
+ self.assertEqual(None, creds.access_token)
- def test_handles_tiny_response(self):
- with _AuthReferenceServer('[]'):
- creds = DevshellCredentials()
- self.assertEqual(None, creds.user_email)
- self.assertEqual(None, creds.project_id)
- self.assertEqual(None, creds.access_token)
+ def test_handles_tiny_response(self):
+ with _AuthReferenceServer('[]'):
+ creds = DevshellCredentials()
+ self.assertEqual(None, creds.user_email)
+ self.assertEqual(None, creds.project_id)
+ self.assertEqual(None, creds.access_token)
- def test_handles_ignores_extra_fields(self):
- with _AuthReferenceServer(
+ def test_handles_ignores_extra_fields(self):
+ with _AuthReferenceServer(
'["joe@example.com", "fooproj", "sometoken", "extra"]'):
- creds = DevshellCredentials()
- self.assertEqual('joe@example.com', creds.user_email)
- self.assertEqual('fooproj', creds.project_id)
- self.assertEqual('sometoken', creds.access_token)
+ creds = DevshellCredentials()
+ self.assertEqual('joe@example.com', creds.user_email)
+ self.assertEqual('fooproj', creds.project_id)
+ self.assertEqual('sometoken', creds.access_token)
- def test_refuses_to_save_to_well_known_file(self):
- ORIGINAL_ISDIR = os.path.isdir
- try:
- os.path.isdir = lambda path: True
- with _AuthReferenceServer():
- creds = DevshellCredentials()
- self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
- finally:
- os.path.isdir = ORIGINAL_ISDIR
+ def test_refuses_to_save_to_well_known_file(self):
+ ORIGINAL_ISDIR = os.path.isdir
+ try:
+ os.path.isdir = lambda path: True
+ with _AuthReferenceServer():
+ creds = DevshellCredentials()
+ self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
+ finally:
+ os.path.isdir = ORIGINAL_ISDIR
diff --git a/tests/test_django_orm.py b/tests/test_django_orm.py
index 02c0b69..ebe4400 100644
--- a/tests/test_django_orm.py
+++ b/tests/test_django_orm.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Discovery document tests
Unit tests for objects created from discovery documents.
@@ -31,10 +30,10 @@ import unittest
# Ensure that if app engine is available, we use the correct django from it
try:
- from google.appengine.dist import use_library
- use_library('django', '1.5')
+ from google.appengine.dist import use_library
+ use_library('django', '1.5')
except ImportError:
- pass
+ pass
from oauth2client.client import Credentials
from oauth2client.client import Flow
@@ -51,39 +50,39 @@ from oauth2client.django_orm import FlowField
class TestCredentialsField(unittest.TestCase):
- def setUp(self):
- self.field = CredentialsField()
- self.credentials = Credentials()
- self.pickle = base64.b64encode(pickle.dumps(self.credentials))
+ def setUp(self):
+ self.field = CredentialsField()
+ self.credentials = Credentials()
+ self.pickle = base64.b64encode(pickle.dumps(self.credentials))
- def test_field_is_text(self):
- self.assertEquals(self.field.get_internal_type(), 'TextField')
+ def test_field_is_text(self):
+ self.assertEquals(self.field.get_internal_type(), 'TextField')
- def test_field_unpickled(self):
- self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
+ def test_field_unpickled(self):
+ self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
- def test_field_pickled(self):
- prep_value = self.field.get_db_prep_value(self.credentials,
+ def test_field_pickled(self):
+ prep_value = self.field.get_db_prep_value(self.credentials,
connection=None)
- self.assertEqual(prep_value, self.pickle)
+ self.assertEqual(prep_value, self.pickle)
class TestFlowField(unittest.TestCase):
- def setUp(self):
- self.field = FlowField()
- self.flow = Flow()
- self.pickle = base64.b64encode(pickle.dumps(self.flow))
+ def setUp(self):
+ self.field = FlowField()
+ self.flow = Flow()
+ self.pickle = base64.b64encode(pickle.dumps(self.flow))
- def test_field_is_text(self):
- self.assertEquals(self.field.get_internal_type(), 'TextField')
+ def test_field_is_text(self):
+ self.assertEquals(self.field.get_internal_type(), 'TextField')
- def test_field_unpickled(self):
- self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
+ def test_field_unpickled(self):
+ self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
- def test_field_pickled(self):
- prep_value = self.field.get_db_prep_value(self.flow, connection=None)
- self.assertEqual(prep_value, self.pickle)
+ def test_field_pickled(self):
+ prep_value = self.field.get_db_prep_value(self.flow, connection=None)
+ self.assertEqual(prep_value, self.pickle)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_file.py b/tests/test_file.py
index 89c3e02..c92c68f 100644
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Oauth2client.file tests
Unit tests for oauth2client.file
@@ -42,363 +41,362 @@ from oauth2client.client import AccessTokenCredentials
from oauth2client.client import OAuth2Credentials
from six.moves import http_client
try:
- # Python2
- from future_builtins import oct
+ # Python2
+ from future_builtins import oct
except:
- pass
-
+ pass
FILENAME = tempfile.mktemp('oauth2client_test.data')
class OAuth2ClientFileTests(unittest.TestCase):
- def tearDown(self):
- try:
- os.unlink(FILENAME)
- except OSError:
- pass
+ def tearDown(self):
+ try:
+ os.unlink(FILENAME)
+ except OSError:
+ pass
- def setUp(self):
- try:
- os.unlink(FILENAME)
- except OSError:
- pass
+ def setUp(self):
+ try:
+ os.unlink(FILENAME)
+ except OSError:
+ pass
- def create_test_credentials(self, client_id='some_client_id',
+ def create_test_credentials(self, client_id='some_client_id',
expiration=None):
- access_token = 'foo'
- client_secret = 'cOuDdkfjxxnv+'
- refresh_token = '1/0/a.df219fjls0'
- token_expiry = expiration or datetime.datetime.utcnow()
- token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
- user_agent = 'refresh_checker/1.0'
+ access_token = 'foo'
+ client_secret = 'cOuDdkfjxxnv+'
+ refresh_token = '1/0/a.df219fjls0'
+ token_expiry = expiration or datetime.datetime.utcnow()
+ token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
+ user_agent = 'refresh_checker/1.0'
- credentials = OAuth2Credentials(
+ credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent)
- return credentials
+ return credentials
- def test_non_existent_file_storage(self):
- s = file.Storage(FILENAME)
- credentials = s.get()
- self.assertEquals(None, credentials)
+ def test_non_existent_file_storage(self):
+ s = file.Storage(FILENAME)
+ credentials = s.get()
+ self.assertEquals(None, credentials)
- def test_no_sym_link_credentials(self):
- if hasattr(os, 'symlink'):
- SYMFILENAME = FILENAME + '.sym'
- os.symlink(FILENAME, SYMFILENAME)
- s = file.Storage(SYMFILENAME)
- try:
- s.get()
- self.fail('Should have raised an exception.')
- except file.CredentialsFileSymbolicLinkError:
- pass
- finally:
- os.unlink(SYMFILENAME)
+ def test_no_sym_link_credentials(self):
+ if hasattr(os, 'symlink'):
+ SYMFILENAME = FILENAME + '.sym'
+ os.symlink(FILENAME, SYMFILENAME)
+ s = file.Storage(SYMFILENAME)
+ try:
+ s.get()
+ self.fail('Should have raised an exception.')
+ except file.CredentialsFileSymbolicLinkError:
+ pass
+ finally:
+ os.unlink(SYMFILENAME)
- def test_pickle_and_json_interop(self):
- # Write a file with a pickled OAuth2Credentials.
- credentials = self.create_test_credentials()
+ def test_pickle_and_json_interop(self):
+ # Write a file with a pickled OAuth2Credentials.
+ credentials = self.create_test_credentials()
- f = open(FILENAME, 'wb')
- pickle.dump(credentials, f)
- f.close()
+ f = open(FILENAME, 'wb')
+ pickle.dump(credentials, f)
+ f.close()
- # Storage should be not be able to read that object, as the capability to
- # read and write credentials as pickled objects has been removed.
- s = file.Storage(FILENAME)
- read_credentials = s.get()
- self.assertEquals(None, read_credentials)
+ # Storage should be not be able to read that object, as the capability to
+ # read and write credentials as pickled objects has been removed.
+ s = file.Storage(FILENAME)
+ read_credentials = s.get()
+ self.assertEquals(None, read_credentials)
- # Now write it back out and confirm it has been rewritten as JSON
- s.put(credentials)
- with open(FILENAME) as f:
- data = json.load(f)
+ # Now write it back out and confirm it has been rewritten as JSON
+ s.put(credentials)
+ with open(FILENAME) as f:
+ data = json.load(f)
- self.assertEquals(data['access_token'], 'foo')
- self.assertEquals(data['_class'], 'OAuth2Credentials')
- self.assertEquals(data['_module'], OAuth2Credentials.__module__)
+ self.assertEquals(data['access_token'], 'foo')
+ self.assertEquals(data['_class'], 'OAuth2Credentials')
+ self.assertEquals(data['_module'], OAuth2Credentials.__module__)
- def test_token_refresh_store_expired(self):
- expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
- credentials = self.create_test_credentials(expiration=expiration)
+ def test_token_refresh_store_expired(self):
+ expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
+ credentials = self.create_test_credentials(expiration=expiration)
- s = file.Storage(FILENAME)
- s.put(credentials)
- credentials = s.get()
- new_cred = copy.copy(credentials)
- new_cred.access_token = 'bar'
- s.put(new_cred)
+ s = file.Storage(FILENAME)
+ s.put(credentials)
+ credentials = s.get()
+ new_cred = copy.copy(credentials)
+ new_cred.access_token = 'bar'
+ s.put(new_cred)
- access_token = '1/3w'
- token_response = {'access_token': access_token, 'expires_in': 3600}
- http = HttpMockSequence([
+ access_token = '1/3w'
+ token_response = {'access_token': access_token, 'expires_in': 3600}
+ http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
- ])
+ ])
- credentials._refresh(http.request)
- self.assertEquals(credentials.access_token, access_token)
+ credentials._refresh(http.request)
+ self.assertEquals(credentials.access_token, access_token)
- def test_token_refresh_store_expires_soon(self):
- # Tests the case where an access token that is valid when it is read from
- # the store expires before the original request succeeds.
- expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
- credentials = self.create_test_credentials(expiration=expiration)
+ def test_token_refresh_store_expires_soon(self):
+ # Tests the case where an access token that is valid when it is read from
+ # the store expires before the original request succeeds.
+ expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
+ credentials = self.create_test_credentials(expiration=expiration)
- s = file.Storage(FILENAME)
- s.put(credentials)
- credentials = s.get()
- new_cred = copy.copy(credentials)
- new_cred.access_token = 'bar'
- s.put(new_cred)
+ s = file.Storage(FILENAME)
+ s.put(credentials)
+ credentials = s.get()
+ new_cred = copy.copy(credentials)
+ new_cred.access_token = 'bar'
+ s.put(new_cred)
- access_token = '1/3w'
- token_response = {'access_token': access_token, 'expires_in': 3600}
- http = HttpMockSequence([
+ access_token = '1/3w'
+ token_response = {'access_token': access_token, 'expires_in': 3600}
+ http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)},
b'Valid response to original request')
- ])
+ ])
- credentials.authorize(http)
- http.request('https://example.com')
- self.assertEqual(credentials.access_token, access_token)
+ credentials.authorize(http)
+ http.request('https://example.com')
+ self.assertEqual(credentials.access_token, access_token)
- def test_token_refresh_good_store(self):
- expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
- credentials = self.create_test_credentials(expiration=expiration)
+ def test_token_refresh_good_store(self):
+ expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
+ credentials = self.create_test_credentials(expiration=expiration)
- s = file.Storage(FILENAME)
- s.put(credentials)
- credentials = s.get()
- new_cred = copy.copy(credentials)
- new_cred.access_token = 'bar'
- s.put(new_cred)
+ s = file.Storage(FILENAME)
+ s.put(credentials)
+ credentials = s.get()
+ new_cred = copy.copy(credentials)
+ new_cred.access_token = 'bar'
+ s.put(new_cred)
- credentials._refresh(lambda x: x)
- self.assertEquals(credentials.access_token, 'bar')
+ credentials._refresh(lambda x: x)
+ self.assertEquals(credentials.access_token, 'bar')
- def test_token_refresh_stream_body(self):
- expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
- credentials = self.create_test_credentials(expiration=expiration)
+ def test_token_refresh_stream_body(self):
+ expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
+ credentials = self.create_test_credentials(expiration=expiration)
- s = file.Storage(FILENAME)
- s.put(credentials)
- credentials = s.get()
- new_cred = copy.copy(credentials)
- new_cred.access_token = 'bar'
- s.put(new_cred)
+ s = file.Storage(FILENAME)
+ s.put(credentials)
+ credentials = s.get()
+ new_cred = copy.copy(credentials)
+ new_cred.access_token = 'bar'
+ s.put(new_cred)
- valid_access_token = '1/3w'
- token_response = {'access_token': valid_access_token, 'expires_in': 3600}
- http = HttpMockSequence([
+ valid_access_token = '1/3w'
+ token_response = {'access_token': valid_access_token, 'expires_in': 3600}
+ http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, 'echo_request_body')
- ])
+ ])
- body = six.StringIO('streaming body')
+ body = six.StringIO('streaming body')
- credentials.authorize(http)
- _, content = http.request('https://example.com', body=body)
- self.assertEqual(content, 'streaming body')
- self.assertEqual(credentials.access_token, valid_access_token)
+ credentials.authorize(http)
+ _, content = http.request('https://example.com', body=body)
+ self.assertEqual(content, 'streaming body')
+ self.assertEqual(credentials.access_token, valid_access_token)
- def test_credentials_delete(self):
- credentials = self.create_test_credentials()
+ def test_credentials_delete(self):
+ credentials = self.create_test_credentials()
- s = file.Storage(FILENAME)
- s.put(credentials)
- credentials = s.get()
- self.assertNotEquals(None, credentials)
- s.delete()
- credentials = s.get()
- self.assertEquals(None, credentials)
+ s = file.Storage(FILENAME)
+ s.put(credentials)
+ credentials = s.get()
+ self.assertNotEquals(None, credentials)
+ s.delete()
+ credentials = s.get()
+ self.assertEquals(None, credentials)
- def test_access_token_credentials(self):
- access_token = 'foo'
- user_agent = 'refresh_checker/1.0'
+ def test_access_token_credentials(self):
+ access_token = 'foo'
+ user_agent = 'refresh_checker/1.0'
- credentials = AccessTokenCredentials(access_token, user_agent)
+ credentials = AccessTokenCredentials(access_token, user_agent)
- s = file.Storage(FILENAME)
- credentials = s.put(credentials)
- credentials = s.get()
+ s = file.Storage(FILENAME)
+ credentials = s.put(credentials)
+ credentials = s.get()
- self.assertNotEquals(None, credentials)
- self.assertEquals('foo', credentials.access_token)
- mode = os.stat(FILENAME).st_mode
+ self.assertNotEquals(None, credentials)
+ self.assertEquals('foo', credentials.access_token)
+ mode = os.stat(FILENAME).st_mode
- if os.name == 'posix':
- self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
+ if os.name == 'posix':
+ self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
- def test_read_only_file_fail_lock(self):
- credentials = self.create_test_credentials()
+ def test_read_only_file_fail_lock(self):
+ credentials = self.create_test_credentials()
- open(FILENAME, 'a+b').close()
- os.chmod(FILENAME, 0o400)
+ open(FILENAME, 'a+b').close()
+ os.chmod(FILENAME, 0o400)
- store = multistore_file.get_credential_storage(
+ store = multistore_file.get_credential_storage(
FILENAME,
credentials.client_id,
credentials.user_agent,
['some-scope', 'some-other-scope'])
- store.put(credentials)
- if os.name == 'posix':
- self.assertTrue(store._multistore._read_only)
- os.chmod(FILENAME, 0o600)
+ store.put(credentials)
+ if os.name == 'posix':
+ self.assertTrue(store._multistore._read_only)
+ os.chmod(FILENAME, 0o600)
- def test_multistore_no_symbolic_link_files(self):
- if hasattr(os, 'symlink'):
- SYMFILENAME = FILENAME + 'sym'
- os.symlink(FILENAME, SYMFILENAME)
- store = multistore_file.get_credential_storage(
+ def test_multistore_no_symbolic_link_files(self):
+ if hasattr(os, 'symlink'):
+ SYMFILENAME = FILENAME + 'sym'
+ os.symlink(FILENAME, SYMFILENAME)
+ store = multistore_file.get_credential_storage(
SYMFILENAME,
'some_client_id',
'user-agent/1.0',
['some-scope', 'some-other-scope'])
- try:
- store.get()
- self.fail('Should have raised an exception.')
- except locked_file.CredentialsFileSymbolicLinkError:
- pass
- finally:
- os.unlink(SYMFILENAME)
+ try:
+ store.get()
+ self.fail('Should have raised an exception.')
+ except locked_file.CredentialsFileSymbolicLinkError:
+ pass
+ finally:
+ os.unlink(SYMFILENAME)
- def test_multistore_non_existent_file(self):
- store = multistore_file.get_credential_storage(
+ def test_multistore_non_existent_file(self):
+ store = multistore_file.get_credential_storage(
FILENAME,
'some_client_id',
'user-agent/1.0',
['some-scope', 'some-other-scope'])
- credentials = store.get()
- self.assertEquals(None, credentials)
+ credentials = store.get()
+ self.assertEquals(None, credentials)
- def test_multistore_file(self):
- credentials = self.create_test_credentials()
+ def test_multistore_file(self):
+ credentials = self.create_test_credentials()
- store = multistore_file.get_credential_storage(
+ store = multistore_file.get_credential_storage(
FILENAME,
credentials.client_id,
credentials.user_agent,
['some-scope', 'some-other-scope'])
- store.put(credentials)
- credentials = store.get()
+ store.put(credentials)
+ credentials = store.get()
- self.assertNotEquals(None, credentials)
- self.assertEquals('foo', credentials.access_token)
+ self.assertNotEquals(None, credentials)
+ self.assertEquals('foo', credentials.access_token)
- store.delete()
- credentials = store.get()
+ store.delete()
+ credentials = store.get()
- self.assertEquals(None, credentials)
+ self.assertEquals(None, credentials)
- if os.name == 'posix':
- self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
+ if os.name == 'posix':
+ self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
- def test_multistore_file_custom_key(self):
- credentials = self.create_test_credentials()
+ def test_multistore_file_custom_key(self):
+ credentials = self.create_test_credentials()
- custom_key = {'myapp': 'testing', 'clientid': 'some client'}
- store = multistore_file.get_credential_storage_custom_key(
+ custom_key = {'myapp': 'testing', 'clientid': 'some client'}
+ store = multistore_file.get_credential_storage_custom_key(
FILENAME, custom_key)
- store.put(credentials)
- stored_credentials = store.get()
+ store.put(credentials)
+ stored_credentials = store.get()
- self.assertNotEquals(None, stored_credentials)
- self.assertEqual(credentials.access_token, stored_credentials.access_token)
+ self.assertNotEquals(None, stored_credentials)
+ self.assertEqual(credentials.access_token, stored_credentials.access_token)
- store.delete()
- stored_credentials = store.get()
+ store.delete()
+ stored_credentials = store.get()
- self.assertEquals(None, stored_credentials)
+ self.assertEquals(None, stored_credentials)
- def test_multistore_file_custom_string_key(self):
- credentials = self.create_test_credentials()
+ def test_multistore_file_custom_string_key(self):
+ credentials = self.create_test_credentials()
- # store with string key
- store = multistore_file.get_credential_storage_custom_string_key(
+ # store with string key
+ store = multistore_file.get_credential_storage_custom_string_key(
FILENAME, 'mykey')
- store.put(credentials)
- stored_credentials = store.get()
+ store.put(credentials)
+ stored_credentials = store.get()
- self.assertNotEquals(None, stored_credentials)
- self.assertEqual(credentials.access_token, stored_credentials.access_token)
+ self.assertNotEquals(None, stored_credentials)
+ self.assertEqual(credentials.access_token, stored_credentials.access_token)
- # try retrieving with a dictionary
- store_dict = multistore_file.get_credential_storage_custom_string_key(
+ # try retrieving with a dictionary
+ store_dict = multistore_file.get_credential_storage_custom_string_key(
FILENAME, {'key': 'mykey'})
- stored_credentials = store.get()
- self.assertNotEquals(None, stored_credentials)
- self.assertEqual(credentials.access_token, stored_credentials.access_token)
+ stored_credentials = store.get()
+ self.assertNotEquals(None, stored_credentials)
+ self.assertEqual(credentials.access_token, stored_credentials.access_token)
- store.delete()
- stored_credentials = store.get()
+ store.delete()
+ stored_credentials = store.get()
- self.assertEquals(None, stored_credentials)
+ self.assertEquals(None, stored_credentials)
- def test_multistore_file_backwards_compatibility(self):
- credentials = self.create_test_credentials()
- scopes = ['scope1', 'scope2']
+ def test_multistore_file_backwards_compatibility(self):
+ credentials = self.create_test_credentials()
+ scopes = ['scope1', 'scope2']
- # store the credentials using the legacy key method
- store = multistore_file.get_credential_storage(
+ # store the credentials using the legacy key method
+ store = multistore_file.get_credential_storage(
FILENAME, 'client_id', 'user_agent', scopes)
- store.put(credentials)
+ store.put(credentials)
- # retrieve the credentials using a custom key that matches the legacy key
- key = {'clientId': 'client_id', 'userAgent': 'user_agent',
+ # retrieve the credentials using a custom key that matches the legacy key
+ key = {'clientId': 'client_id', 'userAgent': 'user_agent',
'scope': util.scopes_to_string(scopes)}
- store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
- stored_credentials = store.get()
+ store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
+ stored_credentials = store.get()
- self.assertEqual(credentials.access_token, stored_credentials.access_token)
+ self.assertEqual(credentials.access_token, stored_credentials.access_token)
- def test_multistore_file_get_all_keys(self):
- # start with no keys
- keys = multistore_file.get_all_credential_keys(FILENAME)
- self.assertEquals([], keys)
+ def test_multistore_file_get_all_keys(self):
+ # start with no keys
+ keys = multistore_file.get_all_credential_keys(FILENAME)
+ self.assertEquals([], keys)
- # store credentials
- credentials = self.create_test_credentials(client_id='client1')
- custom_key = {'myapp': 'testing', 'clientid': 'client1'}
- store1 = multistore_file.get_credential_storage_custom_key(
+ # store credentials
+ credentials = self.create_test_credentials(client_id='client1')
+ custom_key = {'myapp': 'testing', 'clientid': 'client1'}
+ store1 = multistore_file.get_credential_storage_custom_key(
FILENAME, custom_key)
- store1.put(credentials)
+ store1.put(credentials)
- keys = multistore_file.get_all_credential_keys(FILENAME)
- self.assertEquals([custom_key], keys)
+ keys = multistore_file.get_all_credential_keys(FILENAME)
+ self.assertEquals([custom_key], keys)
- # store more credentials
- credentials = self.create_test_credentials(client_id='client2')
- string_key = 'string_key'
- store2 = multistore_file.get_credential_storage_custom_string_key(
+ # store more credentials
+ credentials = self.create_test_credentials(client_id='client2')
+ string_key = 'string_key'
+ store2 = multistore_file.get_credential_storage_custom_string_key(
FILENAME, string_key)
- store2.put(credentials)
+ store2.put(credentials)
- keys = multistore_file.get_all_credential_keys(FILENAME)
- self.assertEquals(2, len(keys))
- self.assertTrue(custom_key in keys)
- self.assertTrue({'key': string_key} in keys)
+ keys = multistore_file.get_all_credential_keys(FILENAME)
+ self.assertEquals(2, len(keys))
+ self.assertTrue(custom_key in keys)
+ self.assertTrue({'key': string_key} in keys)
- # back to no keys
- store1.delete()
- store2.delete()
- keys = multistore_file.get_all_credential_keys(FILENAME)
- self.assertEquals([], keys)
+ # back to no keys
+ store1.delete()
+ store2.delete()
+ keys = multistore_file.get_all_credential_keys(FILENAME)
+ self.assertEquals([], keys)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_flask_util.py b/tests/test_flask_util.py
index 1be1f60..56fa374 100644
--- a/tests/test_flask_util.py
+++ b/tests/test_flask_util.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Unit tests for the Flask utilities"""
__author__ = 'jonwayne@google.com (Jon Wayne Parrott)'
@@ -35,6 +34,7 @@ from oauth2client.client import OAuth2Credentials
class Http2Mock(object):
"""Mock httplib2.Http for code exchange / refresh"""
+
def __init__(self, status=httplib.OK, **kwargs):
self.status = status
self.content = {
diff --git a/tests/test_gce.py b/tests/test_gce.py
index 523b613..5683122 100644
--- a/tests/test_gce.py
+++ b/tests/test_gce.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Tests for oauth2client.gce.
Unit tests for oauth2client.gce.
@@ -36,86 +35,86 @@ from oauth2client.gce import AppAssertionCredentials
class AssertionCredentialsTests(unittest.TestCase):
- def _refresh_success_helper(self, bytes_response=False):
- access_token = u'this-is-a-token'
- return_val = json.dumps({u'accessToken': access_token})
- if bytes_response:
- return_val = _to_bytes(return_val)
- http = mock.MagicMock()
- http.request = mock.MagicMock(
+ def _refresh_success_helper(self, bytes_response=False):
+ access_token = u'this-is-a-token'
+ return_val = json.dumps({u'accessToken': access_token})
+ if bytes_response:
+ return_val = _to_bytes(return_val)
+ http = mock.MagicMock()
+ http.request = mock.MagicMock(
return_value=(mock.Mock(status=200), return_val))
- scopes = ['http://example.com/a', 'http://example.com/b']
- credentials = AppAssertionCredentials(scope=scopes)
- self.assertEquals(None, credentials.access_token)
- credentials.refresh(http)
- self.assertEquals(access_token, credentials.access_token)
+ scopes = ['http://example.com/a', 'http://example.com/b']
+ credentials = AppAssertionCredentials(scope=scopes)
+ self.assertEquals(None, credentials.access_token)
+ credentials.refresh(http)
+ self.assertEquals(access_token, credentials.access_token)
- base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
+ base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
'service-accounts/default/acquire')
- escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
- request_uri = base_metadata_uri + '?scope=' + escaped_scopes
- http.request.assert_called_once_with(request_uri)
+ escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
+ request_uri = base_metadata_uri + '?scope=' + escaped_scopes
+ http.request.assert_called_once_with(request_uri)
- def test_refresh_success(self):
- self._refresh_success_helper(bytes_response=False)
+ def test_refresh_success(self):
+ self._refresh_success_helper(bytes_response=False)
- def test_refresh_success_bytes(self):
- self._refresh_success_helper(bytes_response=True)
+ def test_refresh_success_bytes(self):
+ self._refresh_success_helper(bytes_response=True)
- def test_fail_refresh(self):
- http = mock.MagicMock()
- http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
+ def test_fail_refresh(self):
+ http = mock.MagicMock()
+ http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
- c = AppAssertionCredentials(scope=['http://example.com/a',
+ c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b'])
- self.assertRaises(AccessTokenRefreshError, c.refresh, http)
+ self.assertRaises(AccessTokenRefreshError, c.refresh, http)
- def test_to_from_json(self):
- c = AppAssertionCredentials(scope=['http://example.com/a',
+ def test_to_from_json(self):
+ c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b'])
- json = c.to_json()
- c2 = Credentials.new_from_json(json)
+ json = c.to_json()
+ c2 = Credentials.new_from_json(json)
- self.assertEqual(c.access_token, c2.access_token)
+ self.assertEqual(c.access_token, c2.access_token)
- def test_create_scoped_required_without_scopes(self):
- credentials = AppAssertionCredentials([])
- self.assertTrue(credentials.create_scoped_required())
+ def test_create_scoped_required_without_scopes(self):
+ credentials = AppAssertionCredentials([])
+ self.assertTrue(credentials.create_scoped_required())
- def test_create_scoped_required_with_scopes(self):
- credentials = AppAssertionCredentials(['dummy_scope'])
- self.assertFalse(credentials.create_scoped_required())
+ def test_create_scoped_required_with_scopes(self):
+ credentials = AppAssertionCredentials(['dummy_scope'])
+ self.assertFalse(credentials.create_scoped_required())
- def test_create_scoped(self):
- credentials = AppAssertionCredentials([])
- new_credentials = credentials.create_scoped(['dummy_scope'])
- self.assertNotEqual(credentials, new_credentials)
- self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
- self.assertEqual('dummy_scope', new_credentials.scope)
+ def test_create_scoped(self):
+ credentials = AppAssertionCredentials([])
+ new_credentials = credentials.create_scoped(['dummy_scope'])
+ self.assertNotEqual(credentials, new_credentials)
+ self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
+ self.assertEqual('dummy_scope', new_credentials.scope)
- def test_get_access_token(self):
- http = mock.MagicMock()
- http.request = mock.MagicMock(
+ def test_get_access_token(self):
+ http = mock.MagicMock()
+ http.request = mock.MagicMock(
return_value=(mock.Mock(status=200),
'{"accessToken": "this-is-a-token"}'))
- credentials = AppAssertionCredentials(['dummy_scope'])
- token = credentials.get_access_token(http=http)
- self.assertEqual('this-is-a-token', token.access_token)
- self.assertEqual(None, token.expires_in)
+ credentials = AppAssertionCredentials(['dummy_scope'])
+ token = credentials.get_access_token(http=http)
+ self.assertEqual('this-is-a-token', token.access_token)
+ self.assertEqual(None, token.expires_in)
- http.request.assert_called_once_with(
+ http.request.assert_called_once_with(
'http://metadata.google.internal/0.1/meta-data/service-accounts/'
'default/acquire?scope=dummy_scope')
- def test_save_to_well_known_file(self):
- import os
- ORIGINAL_ISDIR = os.path.isdir
- try:
- os.path.isdir = lambda path: True
- credentials = AppAssertionCredentials([])
- self.assertRaises(NotImplementedError, save_to_well_known_file,
+ def test_save_to_well_known_file(self):
+ import os
+ ORIGINAL_ISDIR = os.path.isdir
+ try:
+ os.path.isdir = lambda path: True
+ credentials = AppAssertionCredentials([])
+ self.assertRaises(NotImplementedError, save_to_well_known_file,
credentials)
- finally:
- os.path.isdir = ORIGINAL_ISDIR
+ finally:
+ os.path.isdir = ORIGINAL_ISDIR
diff --git a/tests/test_import.py b/tests/test_import.py
index 9d9e071..14b971c 100644
--- a/tests/test_import.py
+++ b/tests/test_import.py
@@ -19,9 +19,9 @@ import unittest
class ImportTest(unittest.TestCase):
- def test_tools_import(self):
- import oauth2client.tools
+ def test_tools_import(self):
+ import oauth2client.tools
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index 5dd594f..a56f24e 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Oauth2client tests
Unit tests for oauth2client.
@@ -42,295 +41,295 @@ from oauth2client.file import Storage
def datafile(filename):
- f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
- data = f.read()
- f.close()
- return data
+ f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
+ data = f.read()
+ f.close()
+ return data
class CryptTests(unittest.TestCase):
- def setUp(self):
- self.format = 'p12'
- self.signer = crypt.OpenSSLSigner
- self.verifier = crypt.OpenSSLVerifier
+ def setUp(self):
+ self.format = 'p12'
+ self.signer = crypt.OpenSSLSigner
+ self.verifier = crypt.OpenSSLVerifier
- def test_sign_and_verify(self):
- self._check_sign_and_verify('privatekey.%s' % self.format)
+ def test_sign_and_verify(self):
+ self._check_sign_and_verify('privatekey.%s' % self.format)
- def test_sign_and_verify_from_converted_pkcs12(self):
- # Tests that following instructions to convert from PKCS12 to PEM works.
- if self.format == 'pem':
- self._check_sign_and_verify('pem_from_pkcs12.pem')
+ def test_sign_and_verify_from_converted_pkcs12(self):
+ # Tests that following instructions to convert from PKCS12 to PEM works.
+ if self.format == 'pem':
+ self._check_sign_and_verify('pem_from_pkcs12.pem')
- def _check_sign_and_verify(self, private_key_file):
- private_key = datafile(private_key_file)
- public_key = datafile('publickey.pem')
+ def _check_sign_and_verify(self, private_key_file):
+ private_key = datafile(private_key_file)
+ public_key = datafile('publickey.pem')
- # We pass in a non-bytes password to make sure all branches
- # are traversed in tests.
- signer = self.signer.from_string(private_key,
+ # We pass in a non-bytes password to make sure all branches
+ # are traversed in tests.
+ signer = self.signer.from_string(private_key,
password=u'notasecret')
- signature = signer.sign('foo')
+ signature = signer.sign('foo')
- verifier = self.verifier.from_string(public_key, True)
- self.assertTrue(verifier.verify(b'foo', signature))
+ verifier = self.verifier.from_string(public_key, True)
+ self.assertTrue(verifier.verify(b'foo', signature))
- self.assertFalse(verifier.verify(b'bar', signature))
- self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
- self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
+ self.assertFalse(verifier.verify(b'bar', signature))
+ self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
+ self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
- def _check_jwt_failure(self, jwt, expected_error):
- public_key = datafile('publickey.pem')
- certs = {'foo': public_key}
- audience = ('https://www.googleapis.com/auth/id?client_id='
+ def _check_jwt_failure(self, jwt, expected_error):
+ public_key = datafile('publickey.pem')
+ certs = {'foo': public_key}
+ audience = ('https://www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com')
- try:
- crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
- self.fail()
- except crypt.AppIdentityError as e:
- self.assertTrue(expected_error in str(e))
+ try:
+ crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
+ self.fail()
+ except crypt.AppIdentityError as e:
+ self.assertTrue(expected_error in str(e))
- def _create_signed_jwt(self):
- private_key = datafile('privatekey.%s' % self.format)
- signer = self.signer.from_string(private_key)
- audience = 'some_audience_address@testing.gserviceaccount.com'
- now = int(time.time())
+ def _create_signed_jwt(self):
+ private_key = datafile('privatekey.%s' % self.format)
+ signer = self.signer.from_string(private_key)
+ audience = 'some_audience_address@testing.gserviceaccount.com'
+ now = int(time.time())
- return crypt.make_signed_jwt(signer, {
+ return crypt.make_signed_jwt(signer, {
'aud': audience,
'iat': now,
'exp': now + 300,
'user': 'billy bob',
'metadata': {'meta': 'data'},
- })
+ })
- def test_verify_id_token(self):
- jwt = self._create_signed_jwt()
- public_key = datafile('publickey.pem')
- certs = {'foo': public_key}
- audience = 'some_audience_address@testing.gserviceaccount.com'
- contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
- self.assertEqual('billy bob', contents['user'])
- self.assertEqual('data', contents['metadata']['meta'])
+ def test_verify_id_token(self):
+ jwt = self._create_signed_jwt()
+ public_key = datafile('publickey.pem')
+ certs = {'foo': public_key}
+ audience = 'some_audience_address@testing.gserviceaccount.com'
+ contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
+ self.assertEqual('billy bob', contents['user'])
+ self.assertEqual('data', contents['metadata']['meta'])
- def test_verify_id_token_with_certs_uri(self):
- jwt = self._create_signed_jwt()
+ def test_verify_id_token_with_certs_uri(self):
+ jwt = self._create_signed_jwt()
- http = HttpMockSequence([
+ http = HttpMockSequence([
({'status': '200'}, datafile('certs.json')),
- ])
+ ])
- contents = verify_id_token(
+ contents = verify_id_token(
jwt, 'some_audience_address@testing.gserviceaccount.com', http=http)
- self.assertEqual('billy bob', contents['user'])
- self.assertEqual('data', contents['metadata']['meta'])
+ self.assertEqual('billy bob', contents['user'])
+ self.assertEqual('data', contents['metadata']['meta'])
- def test_verify_id_token_with_certs_uri_fails(self):
- jwt = self._create_signed_jwt()
+ def test_verify_id_token_with_certs_uri_fails(self):
+ jwt = self._create_signed_jwt()
- http = HttpMockSequence([
+ http = HttpMockSequence([
({'status': '404'}, datafile('certs.json')),
- ])
+ ])
- self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
+ self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
'some_audience_address@testing.gserviceaccount.com',
http=http)
- def test_verify_id_token_bad_tokens(self):
- private_key = datafile('privatekey.%s' % self.format)
+ def test_verify_id_token_bad_tokens(self):
+ private_key = datafile('privatekey.%s' % self.format)
- # Wrong number of segments
- self._check_jwt_failure('foo', 'Wrong number of segments')
+ # Wrong number of segments
+ self._check_jwt_failure('foo', 'Wrong number of segments')
- # Not json
- self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
+ # Not json
+ self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
- # Bad signature
- jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
- self._check_jwt_failure(jwt, 'Invalid token signature')
+ # Bad signature
+ jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
+ self._check_jwt_failure(jwt, 'Invalid token signature')
- # No expiration
- signer = self.signer.from_string(private_key)
- audience = ('https:#www.googleapis.com/auth/id?client_id='
+ # No expiration
+ signer = self.signer.from_string(private_key)
+ audience = ('https:#www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com')
- jwt = crypt.make_signed_jwt(signer, {
+ jwt = crypt.make_signed_jwt(signer, {
'aud': audience,
'iat': time.time(),
- })
- self._check_jwt_failure(jwt, 'No exp field in token')
+ })
+ self._check_jwt_failure(jwt, 'No exp field in token')
- # No issued at
- jwt = crypt.make_signed_jwt(signer, {
+ # No issued at
+ jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience',
'exp': time.time() + 400,
- })
- self._check_jwt_failure(jwt, 'No iat field in token')
+ })
+ self._check_jwt_failure(jwt, 'No iat field in token')
- # Too early
- jwt = crypt.make_signed_jwt(signer, {
+ # Too early
+ jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience',
'iat': time.time() + 301,
'exp': time.time() + 400,
- })
- self._check_jwt_failure(jwt, 'Token used too early')
+ })
+ self._check_jwt_failure(jwt, 'Token used too early')
- # Too late
- jwt = crypt.make_signed_jwt(signer, {
+ # Too late
+ jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience',
'iat': time.time() - 500,
'exp': time.time() - 301,
- })
- self._check_jwt_failure(jwt, 'Token used too late')
+ })
+ self._check_jwt_failure(jwt, 'Token used too late')
- # Wrong target
- jwt = crypt.make_signed_jwt(signer, {
+ # Wrong target
+ jwt = crypt.make_signed_jwt(signer, {
'aud': 'somebody else',
'iat': time.time(),
'exp': time.time() + 300,
- })
- self._check_jwt_failure(jwt, 'Wrong recipient')
+ })
+ self._check_jwt_failure(jwt, 'Wrong recipient')
- def test_from_string_non_509_cert(self):
- # Use a private key instead of a certificate to test the other branch
- # of from_string().
- public_key = datafile('privatekey.pem')
- verifier = self.verifier.from_string(public_key, is_x509_cert=False)
- self.assertTrue(isinstance(verifier, self.verifier))
+ def test_from_string_non_509_cert(self):
+ # Use a private key instead of a certificate to test the other branch
+ # of from_string().
+ public_key = datafile('privatekey.pem')
+ verifier = self.verifier.from_string(public_key, is_x509_cert=False)
+ self.assertTrue(isinstance(verifier, self.verifier))
class PEMCryptTestsPyCrypto(CryptTests):
- def setUp(self):
- self.format = 'pem'
- self.signer = crypt.PyCryptoSigner
- self.verifier = crypt.PyCryptoVerifier
+ def setUp(self):
+ self.format = 'pem'
+ self.signer = crypt.PyCryptoSigner
+ self.verifier = crypt.PyCryptoVerifier
class PEMCryptTestsOpenSSL(CryptTests):
- def setUp(self):
- self.format = 'pem'
- self.signer = crypt.OpenSSLSigner
- self.verifier = crypt.OpenSSLVerifier
+ def setUp(self):
+ self.format = 'pem'
+ self.signer = crypt.OpenSSLSigner
+ self.verifier = crypt.OpenSSLVerifier
class SignedJwtAssertionCredentialsTests(unittest.TestCase):
- def setUp(self):
- self.format = 'p12'
- crypt.Signer = crypt.OpenSSLSigner
+ def setUp(self):
+ self.format = 'p12'
+ crypt.Signer = crypt.OpenSSLSigner
- def test_credentials_good(self):
- private_key = datafile('privatekey.%s' % self.format)
- credentials = SignedJwtAssertionCredentials(
+ def test_credentials_good(self):
+ private_key = datafile('privatekey.%s' % self.format)
+ credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
- http = HttpMockSequence([
+ http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'),
- ])
- http = credentials.authorize(http)
- resp, content = http.request('http://example.org')
- self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
+ ])
+ http = credentials.authorize(http)
+ resp, content = http.request('http://example.org')
+ self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
- def test_credentials_to_from_json(self):
- private_key = datafile('privatekey.%s' % self.format)
- credentials = SignedJwtAssertionCredentials(
+ def test_credentials_to_from_json(self):
+ private_key = datafile('privatekey.%s' % self.format)
+ credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
- json = credentials.to_json()
- restored = Credentials.new_from_json(json)
- self.assertEqual(credentials.private_key, restored.private_key)
- self.assertEqual(credentials.private_key_password,
+ json = credentials.to_json()
+ restored = Credentials.new_from_json(json)
+ self.assertEqual(credentials.private_key, restored.private_key)
+ self.assertEqual(credentials.private_key_password,
restored.private_key_password)
- self.assertEqual(credentials.kwargs, restored.kwargs)
+ self.assertEqual(credentials.kwargs, restored.kwargs)
- def _credentials_refresh(self, credentials):
- http = HttpMockSequence([
+ def _credentials_refresh(self, credentials):
+ http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '401'}, b''),
({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'),
- ])
- http = credentials.authorize(http)
- _, content = http.request('http://example.org')
- return content
+ ])
+ http = credentials.authorize(http)
+ _, content = http.request('http://example.org')
+ return content
- def test_credentials_refresh_without_storage(self):
- private_key = datafile('privatekey.%s' % self.format)
- credentials = SignedJwtAssertionCredentials(
+ def test_credentials_refresh_without_storage(self):
+ private_key = datafile('privatekey.%s' % self.format)
+ credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
- content = self._credentials_refresh(credentials)
+ content = self._credentials_refresh(credentials)
- self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
+ self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
- def test_credentials_refresh_with_storage(self):
- private_key = datafile('privatekey.%s' % self.format)
- credentials = SignedJwtAssertionCredentials(
+ def test_credentials_refresh_with_storage(self):
+ private_key = datafile('privatekey.%s' % self.format)
+ credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
- (filehandle, filename) = tempfile.mkstemp()
- os.close(filehandle)
- store = Storage(filename)
- store.put(credentials)
- credentials.set_store(store)
+ (filehandle, filename) = tempfile.mkstemp()
+ os.close(filehandle)
+ store = Storage(filename)
+ store.put(credentials)
+ credentials.set_store(store)
- content = self._credentials_refresh(credentials)
+ content = self._credentials_refresh(credentials)
- self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
- os.unlink(filename)
+ self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
+ os.unlink(filename)
class PEMSignedJwtAssertionCredentialsOpenSSLTests(
SignedJwtAssertionCredentialsTests):
- def setUp(self):
- self.format = 'pem'
- crypt.Signer = crypt.OpenSSLSigner
+ def setUp(self):
+ self.format = 'pem'
+ crypt.Signer = crypt.OpenSSLSigner
class PEMSignedJwtAssertionCredentialsPyCryptoTests(
SignedJwtAssertionCredentialsTests):
- def setUp(self):
- self.format = 'pem'
- crypt.Signer = crypt.PyCryptoSigner
+ def setUp(self):
+ self.format = 'pem'
+ crypt.Signer = crypt.PyCryptoSigner
class PKCSSignedJwtAssertionCredentialsPyCryptoTests(unittest.TestCase):
- def test_for_failure(self):
- crypt.Signer = crypt.PyCryptoSigner
- private_key = datafile('privatekey.p12')
- credentials = SignedJwtAssertionCredentials(
+ def test_for_failure(self):
+ crypt.Signer = crypt.PyCryptoSigner
+ private_key = datafile('privatekey.p12')
+ credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
- try:
- credentials._generate_assertion()
- self.fail()
- except NotImplementedError:
- pass
+ try:
+ credentials._generate_assertion()
+ self.fail()
+ except NotImplementedError:
+ pass
class TestHasOpenSSLFlag(unittest.TestCase):
- def test_true(self):
- self.assertEqual(True, HAS_OPENSSL)
- self.assertEqual(True, HAS_CRYPTO)
+ def test_true(self):
+ self.assertEqual(True, HAS_OPENSSL)
+ self.assertEqual(True, HAS_CRYPTO)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_keyring.py b/tests/test_keyring.py
index ee6ba26..8b0ea0d 100644
--- a/tests/test_keyring.py
+++ b/tests/test_keyring.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Tests for oauth2client.keyring_storage tests.
Unit tests for oauth2client.keyring_storage.
@@ -33,59 +32,59 @@ from oauth2client.keyring_storage import Storage
class OAuth2ClientKeyringTests(unittest.TestCase):
- def test_non_existent_credentials_storage(self):
- with mock.patch.object(keyring, 'get_password',
+ def test_non_existent_credentials_storage(self):
+ with mock.patch.object(keyring, 'get_password',
return_value=None,
autospec=True) as get_password:
- s = Storage('my_unit_test', 'me')
- credentials = s.get()
- self.assertEquals(None, credentials)
- get_password.assert_called_once_with('my_unit_test', 'me')
+ s = Storage('my_unit_test', 'me')
+ credentials = s.get()
+ self.assertEquals(None, credentials)
+ get_password.assert_called_once_with('my_unit_test', 'me')
- def test_malformed_credentials_in_storage(self):
- with mock.patch.object(keyring, 'get_password',
+ def test_malformed_credentials_in_storage(self):
+ with mock.patch.object(keyring, 'get_password',
return_value='{',
autospec=True) as get_password:
- s = Storage('my_unit_test', 'me')
- credentials = s.get()
- self.assertEquals(None, credentials)
- get_password.assert_called_once_with('my_unit_test', 'me')
+ s = Storage('my_unit_test', 'me')
+ credentials = s.get()
+ self.assertEquals(None, credentials)
+ get_password.assert_called_once_with('my_unit_test', 'me')
- def test_json_credentials_storage(self):
- access_token = 'foo'
- client_id = 'some_client_id'
- client_secret = 'cOuDdkfjxxnv+'
- refresh_token = '1/0/a.df219fjls0'
- token_expiry = datetime.datetime.utcnow()
- user_agent = 'refresh_checker/1.0'
+ def test_json_credentials_storage(self):
+ access_token = 'foo'
+ client_id = 'some_client_id'
+ client_secret = 'cOuDdkfjxxnv+'
+ refresh_token = '1/0/a.df219fjls0'
+ token_expiry = datetime.datetime.utcnow()
+ user_agent = 'refresh_checker/1.0'
- credentials = OAuth2Credentials(
+ credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
user_agent)
- # Setting autospec on a mock with an iterable side_effect is
- # currently broken (http://bugs.python.org/issue17826), so instead
- # we patch twice.
- with mock.patch.object(keyring, 'get_password',
+ # Setting autospec on a mock with an iterable side_effect is
+ # currently broken (http://bugs.python.org/issue17826), so instead
+ # we patch twice.
+ with mock.patch.object(keyring, 'get_password',
return_value=None,
autospec=True) as get_password:
- with mock.patch.object(keyring, 'set_password',
+ with mock.patch.object(keyring, 'set_password',
return_value=None,
autospec=True) as set_password:
- s = Storage('my_unit_test', 'me')
- self.assertEquals(None, s.get())
+ s = Storage('my_unit_test', 'me')
+ self.assertEquals(None, s.get())
- s.put(credentials)
+ s.put(credentials)
- set_password.assert_called_once_with(
+ set_password.assert_called_once_with(
'my_unit_test', 'me', credentials.to_json())
- get_password.assert_called_once_with('my_unit_test', 'me')
+ get_password.assert_called_once_with('my_unit_test', 'me')
- with mock.patch.object(keyring, 'get_password',
+ with mock.patch.object(keyring, 'get_password',
return_value=credentials.to_json(),
autospec=True) as get_password:
- restored = s.get()
- self.assertEqual('foo', restored.access_token)
- self.assertEqual('some_client_id', restored.client_id)
- get_password.assert_called_once_with('my_unit_test', 'me')
+ restored = s.get()
+ self.assertEqual('foo', restored.access_token)
+ self.assertEqual('some_client_id', restored.client_id)
+ get_password.assert_called_once_with('my_unit_test', 'me')
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
index 61565e6..d010827 100644
--- a/tests/test_oauth2client.py
+++ b/tests/test_oauth2client.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Oauth2client tests
Unit tests for oauth2client.
@@ -88,488 +87,488 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
# TODO(craigcitro): This is duplicated from
# googleapiclient.test_discovery; consolidate these definitions.
def assertUrisEqual(testcase, expected, actual):
- """Test that URIs are the same, up to reordering of query parameters."""
- expected = urllib.parse.urlparse(expected)
- actual = urllib.parse.urlparse(actual)
- testcase.assertEqual(expected.scheme, actual.scheme)
- testcase.assertEqual(expected.netloc, actual.netloc)
- testcase.assertEqual(expected.path, actual.path)
- testcase.assertEqual(expected.params, actual.params)
- testcase.assertEqual(expected.fragment, actual.fragment)
- expected_query = urllib.parse.parse_qs(expected.query)
- actual_query = urllib.parse.parse_qs(actual.query)
- for name in expected_query.keys():
- testcase.assertEqual(expected_query[name], actual_query[name])
- for name in actual_query.keys():
- testcase.assertEqual(expected_query[name], actual_query[name])
+ """Test that URIs are the same, up to reordering of query parameters."""
+ expected = urllib.parse.urlparse(expected)
+ actual = urllib.parse.urlparse(actual)
+ testcase.assertEqual(expected.scheme, actual.scheme)
+ testcase.assertEqual(expected.netloc, actual.netloc)
+ testcase.assertEqual(expected.path, actual.path)
+ testcase.assertEqual(expected.params, actual.params)
+ testcase.assertEqual(expected.fragment, actual.fragment)
+ expected_query = urllib.parse.parse_qs(expected.query)
+ actual_query = urllib.parse.parse_qs(actual.query)
+ for name in expected_query.keys():
+ testcase.assertEqual(expected_query[name], actual_query[name])
+ for name in actual_query.keys():
+ testcase.assertEqual(expected_query[name], actual_query[name])
def datafile(filename):
- return os.path.join(DATA_DIR, filename)
+ return os.path.join(DATA_DIR, filename)
def load_and_cache(existing_file, fakename, cache_mock):
- client_type, client_info = _loadfile(datafile(existing_file))
- cache_mock.cache[fakename] = {client_type: client_info}
+ client_type, client_info = _loadfile(datafile(existing_file))
+ cache_mock.cache[fakename] = {client_type: client_info}
class CacheMock(object):
def __init__(self):
- self.cache = {}
+ self.cache = {}
def get(self, key, namespace=''):
- # ignoring namespace for easier testing
- return self.cache.get(key, None)
+ # ignoring namespace for easier testing
+ return self.cache.get(key, None)
def set(self, key, value, namespace=''):
- # ignoring namespace for easier testing
- self.cache[key] = value
+ # ignoring namespace for easier testing
+ self.cache[key] = value
class CredentialsTests(unittest.TestCase):
- def test_to_from_json(self):
- credentials = Credentials()
- json = credentials.to_json()
- restored = Credentials.new_from_json(json)
+ def test_to_from_json(self):
+ credentials = Credentials()
+ json = credentials.to_json()
+ restored = Credentials.new_from_json(json)
class MockResponse(object):
- """Mock the response of urllib2.urlopen() call."""
+ """Mock the response of urllib2.urlopen() call."""
- def __init__(self, headers):
- self._headers = headers
+ def __init__(self, headers):
+ self._headers = headers
- def info(self):
- class Info:
- def __init__(self, headers):
- self.headers = headers
+ def info(self):
+ class Info:
+ def __init__(self, headers):
+ self.headers = headers
- def get(self, key, default=None):
- return self.headers.get(key, default)
+ def get(self, key, default=None):
+ return self.headers.get(key, default)
- return Info(self._headers)
+ return Info(self._headers)
@contextlib.contextmanager
def mock_module_import(module):
- """Place a dummy objects in sys.modules to mock an import test."""
- parts = module.split('.')
- entries = ['.'.join(parts[:i+1]) for i in range(len(parts))]
- for entry in entries:
- sys.modules[entry] = object()
-
- try:
- yield
-
- finally:
+ """Place a dummy objects in sys.modules to mock an import test."""
+ parts = module.split('.')
+ entries = ['.'.join(parts[:i + 1]) for i in range(len(parts))]
for entry in entries:
- del sys.modules[entry]
+ sys.modules[entry] = object()
+
+ try:
+ yield
+
+ finally:
+ for entry in entries:
+ del sys.modules[entry]
class GoogleCredentialsTests(unittest.TestCase):
- def setUp(self):
- self.env_server_software = os.environ.get('SERVER_SOFTWARE', None)
- self.env_google_application_credentials = (
+ def setUp(self):
+ self.env_server_software = os.environ.get('SERVER_SOFTWARE', None)
+ self.env_google_application_credentials = (
os.environ.get(GOOGLE_APPLICATION_CREDENTIALS, None))
- self.env_appdata = os.environ.get('APPDATA', None)
- self.os_name = os.name
- from oauth2client import client
- client.SETTINGS.env_name = None
+ self.env_appdata = os.environ.get('APPDATA', None)
+ self.os_name = os.name
+ from oauth2client import client
+ client.SETTINGS.env_name = None
- def tearDown(self):
- self.reset_env('SERVER_SOFTWARE', self.env_server_software)
- self.reset_env(GOOGLE_APPLICATION_CREDENTIALS,
+ def tearDown(self):
+ self.reset_env('SERVER_SOFTWARE', self.env_server_software)
+ self.reset_env(GOOGLE_APPLICATION_CREDENTIALS,
self.env_google_application_credentials)
- self.reset_env('APPDATA', self.env_appdata)
- os.name = self.os_name
+ self.reset_env('APPDATA', self.env_appdata)
+ os.name = self.os_name
- def reset_env(self, env, value):
- """Set the environment variable 'env' to 'value'."""
- if value is not None:
- os.environ[env] = value
- else:
- os.environ.pop(env, '')
+ def reset_env(self, env, value):
+ """Set the environment variable 'env' to 'value'."""
+ if value is not None:
+ os.environ[env] = value
+ else:
+ os.environ.pop(env, '')
- def validate_service_account_credentials(self, credentials):
- self.assertTrue(isinstance(credentials, _ServiceAccountCredentials))
- self.assertEqual('123', credentials._service_account_id)
- self.assertEqual('dummy@google.com', credentials._service_account_email)
- self.assertEqual('ABCDEF', credentials._private_key_id)
- self.assertEqual('', credentials._scopes)
+ def validate_service_account_credentials(self, credentials):
+ self.assertTrue(isinstance(credentials, _ServiceAccountCredentials))
+ self.assertEqual('123', credentials._service_account_id)
+ self.assertEqual('dummy@google.com', credentials._service_account_email)
+ self.assertEqual('ABCDEF', credentials._private_key_id)
+ self.assertEqual('', credentials._scopes)
- def validate_google_credentials(self, credentials):
- self.assertTrue(isinstance(credentials, GoogleCredentials))
- self.assertEqual(None, credentials.access_token)
- self.assertEqual('123', credentials.client_id)
- self.assertEqual('secret', credentials.client_secret)
- self.assertEqual('alabalaportocala', credentials.refresh_token)
- self.assertEqual(None, credentials.token_expiry)
- self.assertEqual(GOOGLE_TOKEN_URI, credentials.token_uri)
- self.assertEqual('Python client library', credentials.user_agent)
+ def validate_google_credentials(self, credentials):
+ self.assertTrue(isinstance(credentials, GoogleCredentials))
+ self.assertEqual(None, credentials.access_token)
+ self.assertEqual('123', credentials.client_id)
+ self.assertEqual('secret', credentials.client_secret)
+ self.assertEqual('alabalaportocala', credentials.refresh_token)
+ self.assertEqual(None, credentials.token_expiry)
+ self.assertEqual(GOOGLE_TOKEN_URI, credentials.token_uri)
+ self.assertEqual('Python client library', credentials.user_agent)
- def get_a_google_credentials_object(self):
- return GoogleCredentials(None, None, None, None, None, None, None, None)
+ def get_a_google_credentials_object(self):
+ return GoogleCredentials(None, None, None, None, None, None, None, None)
- def test_create_scoped_required(self):
- self.assertFalse(
+ def test_create_scoped_required(self):
+ self.assertFalse(
self.get_a_google_credentials_object().create_scoped_required())
- def test_create_scoped(self):
- credentials = self.get_a_google_credentials_object()
- self.assertEqual(credentials, credentials.create_scoped(None))
- self.assertEqual(credentials,
+ def test_create_scoped(self):
+ credentials = self.get_a_google_credentials_object()
+ self.assertEqual(credentials, credentials.create_scoped(None))
+ self.assertEqual(credentials,
credentials.create_scoped(['dummy_scope']))
- def test_environment_check_gae_production(self):
- with mock_module_import('google.appengine'):
- os.environ['SERVER_SOFTWARE'] = 'Google App Engine/XYZ'
- self.assertTrue(_in_gae_environment())
- self.assertFalse(_in_gce_environment())
+ def test_environment_check_gae_production(self):
+ with mock_module_import('google.appengine'):
+ os.environ['SERVER_SOFTWARE'] = 'Google App Engine/XYZ'
+ self.assertTrue(_in_gae_environment())
+ self.assertFalse(_in_gce_environment())
- def test_environment_check_gae_local(self):
- with mock_module_import('google.appengine'):
- os.environ['SERVER_SOFTWARE'] = 'Development/XYZ'
- self.assertTrue(_in_gae_environment())
- self.assertFalse(_in_gce_environment())
+ def test_environment_check_gae_local(self):
+ with mock_module_import('google.appengine'):
+ os.environ['SERVER_SOFTWARE'] = 'Development/XYZ'
+ self.assertTrue(_in_gae_environment())
+ self.assertFalse(_in_gce_environment())
- def test_environment_check_fastpath(self):
- os.environ['SERVER_SOFTWARE'] = 'Development/XYZ'
- with mock_module_import('google.appengine'):
- with mock.patch.object(urllib.request, 'urlopen',
+ def test_environment_check_fastpath(self):
+ os.environ['SERVER_SOFTWARE'] = 'Development/XYZ'
+ with mock_module_import('google.appengine'):
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=MockResponse({}),
autospec=True) as urlopen:
- self.assertTrue(_in_gae_environment())
- self.assertFalse(_in_gce_environment())
- # We already know are in GAE, so we shouldn't actually do the urlopen.
- self.assertFalse(urlopen.called)
+ self.assertTrue(_in_gae_environment())
+ self.assertFalse(_in_gce_environment())
+ # We already know are in GAE, so we shouldn't actually do the urlopen.
+ self.assertFalse(urlopen.called)
- def test_environment_caching(self):
- os.environ['SERVER_SOFTWARE'] = 'Development/XYZ'
- with mock_module_import('google.appengine'):
- self.assertTrue(_in_gae_environment())
- os.environ['SERVER_SOFTWARE'] = ''
- # Even though we no longer pass the environment check, it is cached.
- self.assertTrue(_in_gae_environment())
+ def test_environment_caching(self):
+ os.environ['SERVER_SOFTWARE'] = 'Development/XYZ'
+ with mock_module_import('google.appengine'):
+ self.assertTrue(_in_gae_environment())
+ os.environ['SERVER_SOFTWARE'] = ''
+ # Even though we no longer pass the environment check, it is cached.
+ self.assertTrue(_in_gae_environment())
- def test_environment_check_gae_module_on_gce(self):
- with mock_module_import('google.appengine'):
- os.environ['SERVER_SOFTWARE'] = ''
- response = MockResponse({'Metadata-Flavor': 'Google'})
- with mock.patch.object(urllib.request, 'urlopen',
+ def test_environment_check_gae_module_on_gce(self):
+ with mock_module_import('google.appengine'):
+ os.environ['SERVER_SOFTWARE'] = ''
+ response = MockResponse({'Metadata-Flavor': 'Google'})
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=response,
autospec=True) as urlopen:
- self.assertFalse(_in_gae_environment())
- self.assertTrue(_in_gce_environment())
- urlopen.assert_called_once_with(
+ self.assertFalse(_in_gae_environment())
+ self.assertTrue(_in_gce_environment())
+ urlopen.assert_called_once_with(
'http://169.254.169.254/', timeout=1)
- def test_environment_check_gae_module_unknown(self):
- with mock_module_import('google.appengine'):
- os.environ['SERVER_SOFTWARE'] = ''
- with mock.patch.object(urllib.request, 'urlopen',
+ def test_environment_check_gae_module_unknown(self):
+ with mock_module_import('google.appengine'):
+ os.environ['SERVER_SOFTWARE'] = ''
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=MockResponse({}),
autospec=True) as urlopen:
- self.assertFalse(_in_gae_environment())
- self.assertFalse(_in_gce_environment())
- urlopen.assert_called_once_with(
+ self.assertFalse(_in_gae_environment())
+ self.assertFalse(_in_gce_environment())
+ urlopen.assert_called_once_with(
'http://169.254.169.254/', timeout=1)
- def test_environment_check_gce_production(self):
- os.environ['SERVER_SOFTWARE'] = ''
- response = MockResponse({'Metadata-Flavor': 'Google'})
- with mock.patch.object(urllib.request, 'urlopen',
+ def test_environment_check_gce_production(self):
+ os.environ['SERVER_SOFTWARE'] = ''
+ response = MockResponse({'Metadata-Flavor': 'Google'})
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=response,
autospec=True) as urlopen:
- self.assertFalse(_in_gae_environment())
- self.assertTrue(_in_gce_environment())
- urlopen.assert_called_once_with(
+ self.assertFalse(_in_gae_environment())
+ self.assertTrue(_in_gce_environment())
+ urlopen.assert_called_once_with(
'http://169.254.169.254/', timeout=1)
- def test_environment_check_gce_timeout(self):
- os.environ['SERVER_SOFTWARE'] = ''
- response = MockResponse({'Metadata-Flavor': 'Google'})
- with mock.patch.object(urllib.request, 'urlopen',
+ def test_environment_check_gce_timeout(self):
+ os.environ['SERVER_SOFTWARE'] = ''
+ response = MockResponse({'Metadata-Flavor': 'Google'})
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=response,
autospec=True) as urlopen:
- urlopen.side_effect = socket.timeout()
- self.assertFalse(_in_gce_environment())
- urlopen.assert_called_once_with(
+ urlopen.side_effect = socket.timeout()
+ self.assertFalse(_in_gce_environment())
+ urlopen.assert_called_once_with(
'http://169.254.169.254/', timeout=1)
- with mock.patch.object(urllib.request, 'urlopen',
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=response,
autospec=True) as urlopen:
- urlopen.side_effect = urllib.error.URLError(socket.timeout())
- self.assertFalse(_in_gce_environment())
- urlopen.assert_called_once_with(
+ urlopen.side_effect = urllib.error.URLError(socket.timeout())
+ self.assertFalse(_in_gce_environment())
+ urlopen.assert_called_once_with(
'http://169.254.169.254/', timeout=1)
- def test_environment_check_unknown(self):
- os.environ['SERVER_SOFTWARE'] = ''
- with mock.patch.object(urllib.request, 'urlopen',
+ def test_environment_check_unknown(self):
+ os.environ['SERVER_SOFTWARE'] = ''
+ with mock.patch.object(urllib.request, 'urlopen',
return_value=MockResponse({}),
autospec=True) as urlopen:
- self.assertFalse(_in_gce_environment())
- self.assertFalse(_in_gae_environment())
- urlopen.assert_called_once_with(
+ self.assertFalse(_in_gce_environment())
+ self.assertFalse(_in_gae_environment())
+ urlopen.assert_called_once_with(
'http://169.254.169.254/', timeout=1)
- def test_get_environment_variable_file(self):
- environment_variable_file = datafile(
+ def test_get_environment_variable_file(self):
+ environment_variable_file = datafile(
os.path.join('gcloud', 'application_default_credentials.json'))
- os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
- self.assertEqual(environment_variable_file,
+ os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
+ self.assertEqual(environment_variable_file,
_get_environment_variable_file())
- def test_get_environment_variable_file_error(self):
- nonexistent_file = datafile('nonexistent')
- os.environ[GOOGLE_APPLICATION_CREDENTIALS] = nonexistent_file
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- _get_environment_variable_file()
- self.fail(nonexistent_file + ' should not exist.')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual('File ' + nonexistent_file +
+ def test_get_environment_variable_file_error(self):
+ nonexistent_file = datafile('nonexistent')
+ os.environ[GOOGLE_APPLICATION_CREDENTIALS] = nonexistent_file
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ _get_environment_variable_file()
+ self.fail(nonexistent_file + ' should not exist.')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual('File ' + nonexistent_file +
' (pointed by ' + GOOGLE_APPLICATION_CREDENTIALS +
' environment variable) does not exist!',
str(error))
- def test_get_well_known_file_on_windows(self):
- ORIGINAL_ISDIR = os.path.isdir
- try:
- os.path.isdir = lambda path: True
- well_known_file = datafile(
+ def test_get_well_known_file_on_windows(self):
+ ORIGINAL_ISDIR = os.path.isdir
+ try:
+ os.path.isdir = lambda path: True
+ well_known_file = datafile(
os.path.join(client._CLOUDSDK_CONFIG_DIRECTORY,
'application_default_credentials.json'))
- os.name = 'nt'
- os.environ['APPDATA'] = DATA_DIR
- self.assertEqual(well_known_file, _get_well_known_file())
- finally:
- os.path.isdir = ORIGINAL_ISDIR
+ os.name = 'nt'
+ os.environ['APPDATA'] = DATA_DIR
+ self.assertEqual(well_known_file, _get_well_known_file())
+ finally:
+ os.path.isdir = ORIGINAL_ISDIR
- def test_get_well_known_file_with_custom_config_dir(self):
- ORIGINAL_ENVIRON = os.environ
- ORIGINAL_ISDIR = os.path.isdir
- CUSTOM_DIR = 'CUSTOM_DIR'
- EXPECTED_FILE = os.path.join(CUSTOM_DIR,
+ def test_get_well_known_file_with_custom_config_dir(self):
+ ORIGINAL_ENVIRON = os.environ
+ ORIGINAL_ISDIR = os.path.isdir
+ CUSTOM_DIR = 'CUSTOM_DIR'
+ EXPECTED_FILE = os.path.join(CUSTOM_DIR,
'application_default_credentials.json')
- try:
- os.environ = {client._CLOUDSDK_CONFIG_ENV_VAR: CUSTOM_DIR}
- os.path.isdir = lambda path: True
- well_known_file = _get_well_known_file()
- self.assertEqual(well_known_file, EXPECTED_FILE)
- finally:
- os.environ = ORIGINAL_ENVIRON
- os.path.isdir = ORIGINAL_ISDIR
+ try:
+ os.environ = {client._CLOUDSDK_CONFIG_ENV_VAR: CUSTOM_DIR}
+ os.path.isdir = lambda path: True
+ well_known_file = _get_well_known_file()
+ self.assertEqual(well_known_file, EXPECTED_FILE)
+ finally:
+ os.environ = ORIGINAL_ENVIRON
+ os.path.isdir = ORIGINAL_ISDIR
- def test_get_application_default_credential_from_file_service_account(self):
- credentials_file = datafile(
+ def test_get_application_default_credential_from_file_service_account(self):
+ credentials_file = datafile(
os.path.join('gcloud', 'application_default_credentials.json'))
- credentials = _get_application_default_credential_from_file(
+ credentials = _get_application_default_credential_from_file(
credentials_file)
- self.validate_service_account_credentials(credentials)
+ self.validate_service_account_credentials(credentials)
- def test_save_to_well_known_file_service_account(self):
- credential_file = datafile(
+ def test_save_to_well_known_file_service_account(self):
+ credential_file = datafile(
os.path.join('gcloud', 'application_default_credentials.json'))
- credentials = _get_application_default_credential_from_file(
+ credentials = _get_application_default_credential_from_file(
credential_file)
- temp_credential_file = datafile(
+ temp_credential_file = datafile(
os.path.join('gcloud', 'temp_well_known_file_service_account.json'))
- save_to_well_known_file(credentials, temp_credential_file)
- with open(temp_credential_file) as f:
- d = json.load(f)
- self.assertEqual('service_account', d['type'])
- self.assertEqual('123', d['client_id'])
- self.assertEqual('dummy@google.com', d['client_email'])
- self.assertEqual('ABCDEF', d['private_key_id'])
- os.remove(temp_credential_file)
+ save_to_well_known_file(credentials, temp_credential_file)
+ with open(temp_credential_file) as f:
+ d = json.load(f)
+ self.assertEqual('service_account', d['type'])
+ self.assertEqual('123', d['client_id'])
+ self.assertEqual('dummy@google.com', d['client_email'])
+ self.assertEqual('ABCDEF', d['private_key_id'])
+ os.remove(temp_credential_file)
- def test_save_well_known_file_with_non_existent_config_dir(self):
- credential_file = datafile(
+ def test_save_well_known_file_with_non_existent_config_dir(self):
+ credential_file = datafile(
os.path.join('gcloud', 'application_default_credentials.json'))
- credentials = _get_application_default_credential_from_file(
+ credentials = _get_application_default_credential_from_file(
credential_file)
- ORIGINAL_ISDIR = os.path.isdir
- try:
- os.path.isdir = lambda path: False
- self.assertRaises(OSError, save_to_well_known_file, credentials)
- finally:
- os.path.isdir = ORIGINAL_ISDIR
+ ORIGINAL_ISDIR = os.path.isdir
+ try:
+ os.path.isdir = lambda path: False
+ self.assertRaises(OSError, save_to_well_known_file, credentials)
+ finally:
+ os.path.isdir = ORIGINAL_ISDIR
- def test_get_application_default_credential_from_file_authorized_user(self):
- credentials_file = datafile(
+ def test_get_application_default_credential_from_file_authorized_user(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_authorized_user.json'))
- credentials = _get_application_default_credential_from_file(
+ credentials = _get_application_default_credential_from_file(
credentials_file)
- self.validate_google_credentials(credentials)
+ self.validate_google_credentials(credentials)
- def test_save_to_well_known_file_authorized_user(self):
- credentials_file = datafile(
+ def test_save_to_well_known_file_authorized_user(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_authorized_user.json'))
- credentials = _get_application_default_credential_from_file(
+ credentials = _get_application_default_credential_from_file(
credentials_file)
- temp_credential_file = datafile(
+ temp_credential_file = datafile(
os.path.join('gcloud', 'temp_well_known_file_authorized_user.json'))
- save_to_well_known_file(credentials, temp_credential_file)
- with open(temp_credential_file) as f:
- d = json.load(f)
- self.assertEqual('authorized_user', d['type'])
- self.assertEqual('123', d['client_id'])
- self.assertEqual('secret', d['client_secret'])
- self.assertEqual('alabalaportocala', d['refresh_token'])
- os.remove(temp_credential_file)
+ save_to_well_known_file(credentials, temp_credential_file)
+ with open(temp_credential_file) as f:
+ d = json.load(f)
+ self.assertEqual('authorized_user', d['type'])
+ self.assertEqual('123', d['client_id'])
+ self.assertEqual('secret', d['client_secret'])
+ self.assertEqual('alabalaportocala', d['refresh_token'])
+ os.remove(temp_credential_file)
- def test_get_application_default_credential_from_malformed_file_1(self):
- credentials_file = datafile(
+ def test_get_application_default_credential_from_malformed_file_1(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_1.json'))
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- _get_application_default_credential_from_file(credentials_file)
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual("'type' field should be defined "
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ _get_application_default_credential_from_file(credentials_file)
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual("'type' field should be defined "
"(and have one of the '" + AUTHORIZED_USER +
"' or '" + SERVICE_ACCOUNT + "' values)",
str(error))
- def test_get_application_default_credential_from_malformed_file_2(self):
- credentials_file = datafile(
+ def test_get_application_default_credential_from_malformed_file_2(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_2.json'))
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- _get_application_default_credential_from_file(credentials_file)
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual('The following field(s) must be defined: private_key_id',
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ _get_application_default_credential_from_file(credentials_file)
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual('The following field(s) must be defined: private_key_id',
str(error))
- def test_get_application_default_credential_from_malformed_file_3(self):
- credentials_file = datafile(
+ def test_get_application_default_credential_from_malformed_file_3(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_3.json'))
- self.assertRaises(ValueError, _get_application_default_credential_from_file,
+ self.assertRaises(ValueError, _get_application_default_credential_from_file,
credentials_file)
- def test_raise_exception_for_missing_fields(self):
- missing_fields = ['first', 'second', 'third']
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- _raise_exception_for_missing_fields(missing_fields)
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual('The following field(s) must be defined: ' +
+ def test_raise_exception_for_missing_fields(self):
+ missing_fields = ['first', 'second', 'third']
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ _raise_exception_for_missing_fields(missing_fields)
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual('The following field(s) must be defined: ' +
', '.join(missing_fields),
str(error))
- def test_raise_exception_for_reading_json(self):
- credential_file = 'any_file'
- extra_help = ' be good'
- error = ApplicationDefaultCredentialsError('stuff happens')
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- _raise_exception_for_reading_json(credential_file, extra_help, error)
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as ex:
- self.assertEqual('An error was encountered while reading '
- 'json file: '+ credential_file +
+ def test_raise_exception_for_reading_json(self):
+ credential_file = 'any_file'
+ extra_help = ' be good'
+ error = ApplicationDefaultCredentialsError('stuff happens')
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ _raise_exception_for_reading_json(credential_file, extra_help, error)
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as ex:
+ self.assertEqual('An error was encountered while reading '
+ 'json file: ' + credential_file +
extra_help + ': ' + str(error),
str(ex))
- def test_get_application_default_from_environment_variable_service_account(
+ def test_get_application_default_from_environment_variable_service_account(
self):
- os.environ['SERVER_SOFTWARE'] = ''
- environment_variable_file = datafile(
+ os.environ['SERVER_SOFTWARE'] = ''
+ environment_variable_file = datafile(
os.path.join('gcloud', 'application_default_credentials.json'))
- os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
- self.validate_service_account_credentials(
+ os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
+ self.validate_service_account_credentials(
GoogleCredentials.get_application_default())
- def test_env_name(self):
- from oauth2client import client
- self.assertEqual(None, client.SETTINGS.env_name)
- self.test_get_application_default_from_environment_variable_service_account()
- self.assertEqual(DEFAULT_ENV_NAME, client.SETTINGS.env_name)
+ def test_env_name(self):
+ from oauth2client import client
+ self.assertEqual(None, client.SETTINGS.env_name)
+ self.test_get_application_default_from_environment_variable_service_account()
+ self.assertEqual(DEFAULT_ENV_NAME, client.SETTINGS.env_name)
- def test_get_application_default_from_environment_variable_authorized_user(
+ def test_get_application_default_from_environment_variable_authorized_user(
self):
- os.environ['SERVER_SOFTWARE'] = ''
- environment_variable_file = datafile(
+ os.environ['SERVER_SOFTWARE'] = ''
+ environment_variable_file = datafile(
os.path.join('gcloud',
'application_default_credentials_authorized_user.json'))
- os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
- self.validate_google_credentials(
+ os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
+ self.validate_google_credentials(
GoogleCredentials.get_application_default())
- def test_get_application_default_from_environment_variable_malformed_file(
+ def test_get_application_default_from_environment_variable_malformed_file(
self):
- os.environ['SERVER_SOFTWARE'] = ''
- environment_variable_file = datafile(
+ os.environ['SERVER_SOFTWARE'] = ''
+ environment_variable_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_3.json'))
- os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- GoogleCredentials.get_application_default()
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertTrue(str(error).startswith(
+ os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ GoogleCredentials.get_application_default()
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertTrue(str(error).startswith(
'An error was encountered while reading json file: ' +
environment_variable_file + ' (pointed to by ' +
GOOGLE_APPLICATION_CREDENTIALS + ' environment variable):'))
- def test_get_application_default_environment_not_set_up(self):
- # It is normal for this test to fail if run inside
- # a Google Compute Engine VM or after 'gcloud auth login' command
- # has been executed on a non Windows machine.
- os.environ['SERVER_SOFTWARE'] = ''
- os.environ[GOOGLE_APPLICATION_CREDENTIALS] = ''
- os.environ['APPDATA'] = ''
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- VALID_CONFIG_DIR = client._CLOUDSDK_CONFIG_DIRECTORY
- ORIGINAL_ISDIR = os.path.isdir
- try:
- os.path.isdir = lambda path: True
- client._CLOUDSDK_CONFIG_DIRECTORY = 'BOGUS_CONFIG_DIR'
- GoogleCredentials.get_application_default()
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual(ADC_HELP_MSG, str(error))
- finally:
- os.path.isdir = ORIGINAL_ISDIR
- client._CLOUDSDK_CONFIG_DIRECTORY = VALID_CONFIG_DIR
+ def test_get_application_default_environment_not_set_up(self):
+ # It is normal for this test to fail if run inside
+ # a Google Compute Engine VM or after 'gcloud auth login' command
+ # has been executed on a non Windows machine.
+ os.environ['SERVER_SOFTWARE'] = ''
+ os.environ[GOOGLE_APPLICATION_CREDENTIALS] = ''
+ os.environ['APPDATA'] = ''
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ VALID_CONFIG_DIR = client._CLOUDSDK_CONFIG_DIRECTORY
+ ORIGINAL_ISDIR = os.path.isdir
+ try:
+ os.path.isdir = lambda path: True
+ client._CLOUDSDK_CONFIG_DIRECTORY = 'BOGUS_CONFIG_DIR'
+ GoogleCredentials.get_application_default()
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual(ADC_HELP_MSG, str(error))
+ finally:
+ os.path.isdir = ORIGINAL_ISDIR
+ client._CLOUDSDK_CONFIG_DIRECTORY = VALID_CONFIG_DIR
- def test_from_stream_service_account(self):
- credentials_file = datafile(
+ def test_from_stream_service_account(self):
+ credentials_file = datafile(
os.path.join('gcloud', 'application_default_credentials.json'))
- credentials = (
+ credentials = (
self.get_a_google_credentials_object().from_stream(credentials_file))
- self.validate_service_account_credentials(credentials)
+ self.validate_service_account_credentials(credentials)
- def test_from_stream_authorized_user(self):
- credentials_file = datafile(
+ def test_from_stream_authorized_user(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_authorized_user.json'))
- credentials = (
+ credentials = (
self.get_a_google_credentials_object().from_stream(credentials_file))
- self.validate_google_credentials(credentials)
+ self.validate_google_credentials(credentials)
- def test_from_stream_malformed_file_1(self):
- credentials_file = datafile(
+ def test_from_stream_malformed_file_1(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_1.json'))
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- self.get_a_google_credentials_object().from_stream(credentials_file)
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual("An error was encountered while reading json file: " +
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ self.get_a_google_credentials_object().from_stream(credentials_file)
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual("An error was encountered while reading json file: " +
credentials_file +
" (provided as parameter to the from_stream() method): "
"'type' field should be defined (and have one of the '" +
@@ -577,107 +576,108 @@ class GoogleCredentialsTests(unittest.TestCase):
"' values)",
str(error))
- def test_from_stream_malformed_file_2(self):
- credentials_file = datafile(
+ def test_from_stream_malformed_file_2(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_2.json'))
- # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
- try:
- self.get_a_google_credentials_object().from_stream(credentials_file)
- self.fail('An exception was expected!')
- except ApplicationDefaultCredentialsError as error:
- self.assertEqual('An error was encountered while reading json file: ' +
+ # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+
+ try:
+ self.get_a_google_credentials_object().from_stream(credentials_file)
+ self.fail('An exception was expected!')
+ except ApplicationDefaultCredentialsError as error:
+ self.assertEqual('An error was encountered while reading json file: ' +
credentials_file +
' (provided as parameter to the from_stream() method): '
'The following field(s) must be defined: '
'private_key_id',
str(error))
- def test_from_stream_malformed_file_3(self):
- credentials_file = datafile(
+ def test_from_stream_malformed_file_3(self):
+ credentials_file = datafile(
os.path.join('gcloud',
'application_default_credentials_malformed_3.json'))
- self.assertRaises(
+ self.assertRaises(
ApplicationDefaultCredentialsError,
self.get_a_google_credentials_object().from_stream, credentials_file)
class DummyDeleteStorage(Storage):
- delete_called = False
+ delete_called = False
- def locked_delete(self):
- self.delete_called = True
+ def locked_delete(self):
+ self.delete_called = True
def _token_revoke_test_helper(testcase, status, revoke_raise,
valid_bool_value, token_attr):
- current_store = getattr(testcase.credentials, 'store', None)
+ current_store = getattr(testcase.credentials, 'store', None)
- dummy_store = DummyDeleteStorage()
- testcase.credentials.set_store(dummy_store)
+ dummy_store = DummyDeleteStorage()
+ testcase.credentials.set_store(dummy_store)
- actual_do_revoke = testcase.credentials._do_revoke
- testcase.token_from_revoke = None
- def do_revoke_stub(http_request, token):
- testcase.token_from_revoke = token
- return actual_do_revoke(http_request, token)
- testcase.credentials._do_revoke = do_revoke_stub
+ actual_do_revoke = testcase.credentials._do_revoke
+ testcase.token_from_revoke = None
- http = HttpMock(headers={'status': status})
- if revoke_raise:
- testcase.assertRaises(TokenRevokeError, testcase.credentials.revoke, http)
- else:
- testcase.credentials.revoke(http)
+ def do_revoke_stub(http_request, token):
+ testcase.token_from_revoke = token
+ return actual_do_revoke(http_request, token)
+ testcase.credentials._do_revoke = do_revoke_stub
- testcase.assertEqual(getattr(testcase.credentials, token_attr),
+ http = HttpMock(headers={'status': status})
+ if revoke_raise:
+ testcase.assertRaises(TokenRevokeError, testcase.credentials.revoke, http)
+ else:
+ testcase.credentials.revoke(http)
+
+ testcase.assertEqual(getattr(testcase.credentials, token_attr),
testcase.token_from_revoke)
- testcase.assertEqual(valid_bool_value, testcase.credentials.invalid)
- testcase.assertEqual(valid_bool_value, dummy_store.delete_called)
+ testcase.assertEqual(valid_bool_value, testcase.credentials.invalid)
+ testcase.assertEqual(valid_bool_value, dummy_store.delete_called)
- testcase.credentials.set_store(current_store)
+ testcase.credentials.set_store(current_store)
class BasicCredentialsTests(unittest.TestCase):
- def setUp(self):
- access_token = 'foo'
- client_id = 'some_client_id'
- client_secret = 'cOuDdkfjxxnv+'
- refresh_token = '1/0/a.df219fjls0'
- token_expiry = datetime.datetime.utcnow()
- user_agent = 'refresh_checker/1.0'
- self.credentials = OAuth2Credentials(
+ def setUp(self):
+ access_token = 'foo'
+ client_id = 'some_client_id'
+ client_secret = 'cOuDdkfjxxnv+'
+ refresh_token = '1/0/a.df219fjls0'
+ token_expiry = datetime.datetime.utcnow()
+ user_agent = 'refresh_checker/1.0'
+ self.credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
user_agent, revoke_uri=GOOGLE_REVOKE_URI, scopes='foo',
token_info_uri=GOOGLE_TOKEN_INFO_URI)
- # Provoke a failure if @util.positional is not respected.
- self.old_positional_enforcement = (
+ # Provoke a failure if @util.positional is not respected.
+ self.old_positional_enforcement = (
oauth2client_util.positional_parameters_enforcement)
- oauth2client_util.positional_parameters_enforcement = (
+ oauth2client_util.positional_parameters_enforcement = (
oauth2client_util.POSITIONAL_EXCEPTION)
- def tearDown(self):
- oauth2client_util.positional_parameters_enforcement = (
+ def tearDown(self):
+ oauth2client_util.positional_parameters_enforcement = (
self.old_positional_enforcement)
- def test_token_refresh_success(self):
- for status_code in REFRESH_STATUS_CODES:
- token_response = {'access_token': '1/3w', 'expires_in': 3600}
- http = HttpMockSequence([
+ def test_token_refresh_success(self):
+ for status_code in REFRESH_STATUS_CODES:
+ token_response = {'access_token': '1/3w', 'expires_in': 3600}
+ http = HttpMockSequence([
({'status': status_code}, b''),
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
({'status': '200'}, 'echo_request_headers'),
- ])
- http = self.credentials.authorize(http)
- resp, content = http.request('http://example.com')
- self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response, self.credentials.token_response)
+ ])
+ http = self.credentials.authorize(http)
+ resp, content = http.request('http://example.com')
+ self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response, self.credentials.token_response)
- def test_recursive_authorize(self):
- """Tests that OAuth2Credentials does not introduce new method constraints.
+ def test_recursive_authorize(self):
+ """Tests that OAuth2Credentials does not introduce new method constraints.
Formerly, OAuth2Credentials.authorize monkeypatched the request method of
its httplib2.Http argument with a wrapper annotated with
@@ -687,331 +687,332 @@ class BasicCredentialsTests(unittest.TestCase):
even respect that requirement. So before the removal of the annotation, this
test would fail.
"""
- token_response = {'access_token': '1/3w', 'expires_in': 3600}
- encoded_response = json.dumps(token_response).encode('utf-8')
- http = HttpMockSequence([
+ token_response = {'access_token': '1/3w', 'expires_in': 3600}
+ encoded_response = json.dumps(token_response).encode('utf-8')
+ http = HttpMockSequence([
({'status': '200'}, encoded_response),
- ])
- http = self.credentials.authorize(http)
- http = self.credentials.authorize(http)
- http.request('http://example.com')
+ ])
+ http = self.credentials.authorize(http)
+ http = self.credentials.authorize(http)
+ http.request('http://example.com')
- def test_token_refresh_failure(self):
- for status_code in REFRESH_STATUS_CODES:
- http = HttpMockSequence([
+ def test_token_refresh_failure(self):
+ for status_code in REFRESH_STATUS_CODES:
+ http = HttpMockSequence([
({'status': status_code}, b''),
({'status': '400'}, b'{"error":"access_denied"}'),
])
- http = self.credentials.authorize(http)
- try:
- http.request('http://example.com')
- self.fail('should raise AccessTokenRefreshError exception')
- except AccessTokenRefreshError:
- pass
- self.assertTrue(self.credentials.access_token_expired)
- self.assertEqual(None, self.credentials.token_response)
+ http = self.credentials.authorize(http)
+ try:
+ http.request('http://example.com')
+ self.fail('should raise AccessTokenRefreshError exception')
+ except AccessTokenRefreshError:
+ pass
+ self.assertTrue(self.credentials.access_token_expired)
+ self.assertEqual(None, self.credentials.token_response)
- def test_token_revoke_success(self):
- _token_revoke_test_helper(
+ def test_token_revoke_success(self):
+ _token_revoke_test_helper(
self, '200', revoke_raise=False,
valid_bool_value=True, token_attr='refresh_token')
- def test_token_revoke_failure(self):
- _token_revoke_test_helper(
+ def test_token_revoke_failure(self):
+ _token_revoke_test_helper(
self, '400', revoke_raise=True,
valid_bool_value=False, token_attr='refresh_token')
- def test_token_revoke_fallback(self):
- original_credentials = self.credentials.to_json()
- self.credentials.refresh_token = None
- _token_revoke_test_helper(
+ def test_token_revoke_fallback(self):
+ original_credentials = self.credentials.to_json()
+ self.credentials.refresh_token = None
+ _token_revoke_test_helper(
self, '200', revoke_raise=False,
valid_bool_value=True, token_attr='access_token')
- self.credentials = self.credentials.from_json(original_credentials)
+ self.credentials = self.credentials.from_json(original_credentials)
- def test_non_401_error_response(self):
- http = HttpMockSequence([
+ def test_non_401_error_response(self):
+ http = HttpMockSequence([
({'status': '400'}, b''),
])
- http = self.credentials.authorize(http)
- resp, content = http.request('http://example.com')
- self.assertEqual(400, resp.status)
- self.assertEqual(None, self.credentials.token_response)
+ http = self.credentials.authorize(http)
+ resp, content = http.request('http://example.com')
+ self.assertEqual(400, resp.status)
+ self.assertEqual(None, self.credentials.token_response)
- def test_to_from_json(self):
- json = self.credentials.to_json()
- instance = OAuth2Credentials.from_json(json)
- self.assertEqual(OAuth2Credentials, type(instance))
- instance.token_expiry = None
- self.credentials.token_expiry = None
+ def test_to_from_json(self):
+ json = self.credentials.to_json()
+ instance = OAuth2Credentials.from_json(json)
+ self.assertEqual(OAuth2Credentials, type(instance))
+ instance.token_expiry = None
+ self.credentials.token_expiry = None
- self.assertEqual(instance.__dict__, self.credentials.__dict__)
+ self.assertEqual(instance.__dict__, self.credentials.__dict__)
- def test_from_json_token_expiry(self):
- data = json.loads(self.credentials.to_json())
- data['token_expiry'] = None
- instance = OAuth2Credentials.from_json(json.dumps(data))
- self.assertTrue(isinstance(instance, OAuth2Credentials))
+ def test_from_json_token_expiry(self):
+ data = json.loads(self.credentials.to_json())
+ data['token_expiry'] = None
+ instance = OAuth2Credentials.from_json(json.dumps(data))
+ self.assertTrue(isinstance(instance, OAuth2Credentials))
- def test_unicode_header_checks(self):
- access_token = u'foo'
- client_id = u'some_client_id'
- client_secret = u'cOuDdkfjxxnv+'
- refresh_token = u'1/0/a.df219fjls0'
- token_expiry = str(datetime.datetime.utcnow())
- token_uri = str(GOOGLE_TOKEN_URI)
- revoke_uri = str(GOOGLE_REVOKE_URI)
- user_agent = u'refresh_checker/1.0'
- credentials = OAuth2Credentials(access_token, client_id, client_secret,
+ def test_unicode_header_checks(self):
+ access_token = u'foo'
+ client_id = u'some_client_id'
+ client_secret = u'cOuDdkfjxxnv+'
+ refresh_token = u'1/0/a.df219fjls0'
+ token_expiry = str(datetime.datetime.utcnow())
+ token_uri = str(GOOGLE_TOKEN_URI)
+ revoke_uri = str(GOOGLE_REVOKE_URI)
+ user_agent = u'refresh_checker/1.0'
+ credentials = OAuth2Credentials(access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent, revoke_uri=revoke_uri)
- # First, test that we correctly encode basic objects, making sure
- # to include a bytes object. Note that oauth2client will normalize
- # everything to bytes, no matter what python version we're in.
- http = credentials.authorize(HttpMock(headers={'status': '200'}))
- headers = {u'foo': 3, b'bar': True, 'baz': b'abc'}
- cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'}
- http.request(u'http://example.com', method=u'GET', headers=headers)
- for k, v in cleaned_headers.items():
- self.assertTrue(k in http.headers)
- self.assertEqual(v, http.headers[k])
+ # First, test that we correctly encode basic objects, making sure
+ # to include a bytes object. Note that oauth2client will normalize
+ # everything to bytes, no matter what python version we're in.
+ http = credentials.authorize(HttpMock(headers={'status': '200'}))
+ headers = {u'foo': 3, b'bar': True, 'baz': b'abc'}
+ cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'}
+ http.request(u'http://example.com', method=u'GET', headers=headers)
+ for k, v in cleaned_headers.items():
+ self.assertTrue(k in http.headers)
+ self.assertEqual(v, http.headers[k])
- # Next, test that we do fail on unicode.
- unicode_str = six.unichr(40960) + 'abcd'
- self.assertRaises(
+ # Next, test that we do fail on unicode.
+ unicode_str = six.unichr(40960) + 'abcd'
+ self.assertRaises(
NonAsciiHeaderError,
http.request,
u'http://example.com', method=u'GET', headers={u'foo': unicode_str})
- def test_no_unicode_in_request_params(self):
- access_token = u'foo'
- client_id = u'some_client_id'
- client_secret = u'cOuDdkfjxxnv+'
- refresh_token = u'1/0/a.df219fjls0'
- token_expiry = str(datetime.datetime.utcnow())
- token_uri = str(GOOGLE_TOKEN_URI)
- revoke_uri = str(GOOGLE_REVOKE_URI)
- user_agent = u'refresh_checker/1.0'
- credentials = OAuth2Credentials(access_token, client_id, client_secret,
+ def test_no_unicode_in_request_params(self):
+ access_token = u'foo'
+ client_id = u'some_client_id'
+ client_secret = u'cOuDdkfjxxnv+'
+ refresh_token = u'1/0/a.df219fjls0'
+ token_expiry = str(datetime.datetime.utcnow())
+ token_uri = str(GOOGLE_TOKEN_URI)
+ revoke_uri = str(GOOGLE_REVOKE_URI)
+ user_agent = u'refresh_checker/1.0'
+ credentials = OAuth2Credentials(access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent, revoke_uri=revoke_uri)
- http = HttpMock(headers={'status': '200'})
- http = credentials.authorize(http)
- http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
- for k, v in six.iteritems(http.headers):
- self.assertTrue(isinstance(k, six.binary_type))
- self.assertTrue(isinstance(v, six.binary_type))
+ http = HttpMock(headers={'status': '200'})
+ http = credentials.authorize(http)
+ http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
+ for k, v in six.iteritems(http.headers):
+ self.assertTrue(isinstance(k, six.binary_type))
+ self.assertTrue(isinstance(v, six.binary_type))
- # Test again with unicode strings that can't simply be converted to ASCII.
- try:
- http.request(
+ # Test again with unicode strings that can't simply be converted to ASCII.
+ try:
+ http.request(
u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'})
- self.fail('Expected exception to be raised.')
- except NonAsciiHeaderError:
- pass
+ self.fail('Expected exception to be raised.')
+ except NonAsciiHeaderError:
+ pass
- self.credentials.token_response = 'foobar'
- instance = OAuth2Credentials.from_json(self.credentials.to_json())
- self.assertEqual('foobar', instance.token_response)
+ self.credentials.token_response = 'foobar'
+ instance = OAuth2Credentials.from_json(self.credentials.to_json())
+ self.assertEqual('foobar', instance.token_response)
- def test_get_access_token(self):
- S = 2 # number of seconds in which the token expires
- token_response_first = {'access_token': 'first_token', 'expires_in': S}
- token_response_second = {'access_token': 'second_token', 'expires_in': S}
- http = HttpMockSequence([
+ def test_get_access_token(self):
+ S = 2 # number of seconds in which the token expires
+ token_response_first = {'access_token': 'first_token', 'expires_in': S}
+ token_response_second = {'access_token': 'second_token', 'expires_in': S}
+ http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode('utf-8')),
({'status': '200'}, json.dumps(token_response_second).encode('utf-8')),
- ])
+ ])
- token = self.credentials.get_access_token(http=http)
- self.assertEqual('first_token', token.access_token)
- self.assertEqual(S - 1, token.expires_in)
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response_first, self.credentials.token_response)
+ token = self.credentials.get_access_token(http=http)
+ self.assertEqual('first_token', token.access_token)
+ self.assertEqual(S - 1, token.expires_in)
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response_first, self.credentials.token_response)
- token = self.credentials.get_access_token(http=http)
- self.assertEqual('first_token', token.access_token)
- self.assertEqual(S - 1, token.expires_in)
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response_first, self.credentials.token_response)
+ token = self.credentials.get_access_token(http=http)
+ self.assertEqual('first_token', token.access_token)
+ self.assertEqual(S - 1, token.expires_in)
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response_first, self.credentials.token_response)
- time.sleep(S + 0.5) # some margin to avoid flakiness
- self.assertTrue(self.credentials.access_token_expired)
+ time.sleep(S + 0.5) # some margin to avoid flakiness
+ self.assertTrue(self.credentials.access_token_expired)
- token = self.credentials.get_access_token(http=http)
- self.assertEqual('second_token', token.access_token)
- self.assertEqual(S - 1, token.expires_in)
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response_second, self.credentials.token_response)
+ token = self.credentials.get_access_token(http=http)
+ self.assertEqual('second_token', token.access_token)
+ self.assertEqual(S - 1, token.expires_in)
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response_second, self.credentials.token_response)
- def test_has_scopes(self):
- self.assertTrue(self.credentials.has_scopes('foo'))
- self.assertTrue(self.credentials.has_scopes(['foo']))
- self.assertFalse(self.credentials.has_scopes('bar'))
- self.assertFalse(self.credentials.has_scopes(['bar']))
+ def test_has_scopes(self):
+ self.assertTrue(self.credentials.has_scopes('foo'))
+ self.assertTrue(self.credentials.has_scopes(['foo']))
+ self.assertFalse(self.credentials.has_scopes('bar'))
+ self.assertFalse(self.credentials.has_scopes(['bar']))
- self.credentials.scopes = set(['foo', 'bar'])
- self.assertTrue(self.credentials.has_scopes('foo'))
- self.assertTrue(self.credentials.has_scopes('bar'))
- self.assertFalse(self.credentials.has_scopes('baz'))
- self.assertTrue(self.credentials.has_scopes(['foo', 'bar']))
- self.assertFalse(self.credentials.has_scopes(['foo', 'baz']))
+ self.credentials.scopes = set(['foo', 'bar'])
+ self.assertTrue(self.credentials.has_scopes('foo'))
+ self.assertTrue(self.credentials.has_scopes('bar'))
+ self.assertFalse(self.credentials.has_scopes('baz'))
+ self.assertTrue(self.credentials.has_scopes(['foo', 'bar']))
+ self.assertFalse(self.credentials.has_scopes(['foo', 'baz']))
- self.credentials.scopes = set([])
- self.assertFalse(self.credentials.has_scopes('foo'))
+ self.credentials.scopes = set([])
+ self.assertFalse(self.credentials.has_scopes('foo'))
- def test_retrieve_scopes(self):
- info_response_first = {'scope': 'foo bar'}
- info_response_second = {'error_description': 'abcdef'}
- http = HttpMockSequence([
+ def test_retrieve_scopes(self):
+ info_response_first = {'scope': 'foo bar'}
+ info_response_second = {'error_description': 'abcdef'}
+ http = HttpMockSequence([
({'status': '200'}, json.dumps(info_response_first).encode('utf-8')),
({'status': '400'}, json.dumps(info_response_second).encode('utf-8')),
({'status': '500'}, b''),
- ])
+ ])
- self.credentials.retrieve_scopes(http)
- self.assertEqual(set(['foo', 'bar']), self.credentials.scopes)
+ self.credentials.retrieve_scopes(http)
+ self.assertEqual(set(['foo', 'bar']), self.credentials.scopes)
- self.assertRaises(
+ self.assertRaises(
Error,
self.credentials.retrieve_scopes,
http)
- self.assertRaises(
+ self.assertRaises(
Error,
self.credentials.retrieve_scopes,
http)
+
class AccessTokenCredentialsTests(unittest.TestCase):
- def setUp(self):
- access_token = 'foo'
- user_agent = 'refresh_checker/1.0'
- self.credentials = AccessTokenCredentials(access_token, user_agent,
+ def setUp(self):
+ access_token = 'foo'
+ user_agent = 'refresh_checker/1.0'
+ self.credentials = AccessTokenCredentials(access_token, user_agent,
revoke_uri=GOOGLE_REVOKE_URI)
- def test_token_refresh_success(self):
- for status_code in REFRESH_STATUS_CODES:
- http = HttpMockSequence([
+ def test_token_refresh_success(self):
+ for status_code in REFRESH_STATUS_CODES:
+ http = HttpMockSequence([
({'status': status_code}, b''),
])
- http = self.credentials.authorize(http)
- try:
- resp, content = http.request('http://example.com')
- self.fail('should throw exception if token expires')
- except AccessTokenCredentialsError:
- pass
- except Exception:
- self.fail('should only throw AccessTokenCredentialsError')
+ http = self.credentials.authorize(http)
+ try:
+ resp, content = http.request('http://example.com')
+ self.fail('should throw exception if token expires')
+ except AccessTokenCredentialsError:
+ pass
+ except Exception:
+ self.fail('should only throw AccessTokenCredentialsError')
- def test_token_revoke_success(self):
- _token_revoke_test_helper(
+ def test_token_revoke_success(self):
+ _token_revoke_test_helper(
self, '200', revoke_raise=False,
valid_bool_value=True, token_attr='access_token')
- def test_token_revoke_failure(self):
- _token_revoke_test_helper(
+ def test_token_revoke_failure(self):
+ _token_revoke_test_helper(
self, '400', revoke_raise=True,
valid_bool_value=False, token_attr='access_token')
- def test_non_401_error_response(self):
- http = HttpMockSequence([
+ def test_non_401_error_response(self):
+ http = HttpMockSequence([
({'status': '400'}, b''),
])
- http = self.credentials.authorize(http)
- resp, content = http.request('http://example.com')
- self.assertEqual(400, resp.status)
+ http = self.credentials.authorize(http)
+ resp, content = http.request('http://example.com')
+ self.assertEqual(400, resp.status)
- def test_auth_header_sent(self):
- http = HttpMockSequence([
+ def test_auth_header_sent(self):
+ http = HttpMockSequence([
({'status': '200'}, 'echo_request_headers'),
])
- http = self.credentials.authorize(http)
- resp, content = http.request('http://example.com')
- self.assertEqual(b'Bearer foo', content[b'Authorization'])
+ http = self.credentials.authorize(http)
+ resp, content = http.request('http://example.com')
+ self.assertEqual(b'Bearer foo', content[b'Authorization'])
class TestAssertionCredentials(unittest.TestCase):
- assertion_text = 'This is the assertion'
- assertion_type = 'http://www.google.com/assertionType'
+ assertion_text = 'This is the assertion'
+ assertion_type = 'http://www.google.com/assertionType'
- class AssertionCredentialsTestImpl(AssertionCredentials):
+ class AssertionCredentialsTestImpl(AssertionCredentials):
- def _generate_assertion(self):
- return TestAssertionCredentials.assertion_text
+ def _generate_assertion(self):
+ return TestAssertionCredentials.assertion_text
- def setUp(self):
- user_agent = 'fun/2.0'
- self.credentials = self.AssertionCredentialsTestImpl(self.assertion_type,
+ def setUp(self):
+ user_agent = 'fun/2.0'
+ self.credentials = self.AssertionCredentialsTestImpl(self.assertion_type,
user_agent=user_agent)
- def test_assertion_body(self):
- body = urllib.parse.parse_qs(
+ def test_assertion_body(self):
+ body = urllib.parse.parse_qs(
self.credentials._generate_refresh_request_body())
- self.assertEqual(self.assertion_text, body['assertion'][0])
- self.assertEqual('urn:ietf:params:oauth:grant-type:jwt-bearer',
+ self.assertEqual(self.assertion_text, body['assertion'][0])
+ self.assertEqual('urn:ietf:params:oauth:grant-type:jwt-bearer',
body['grant_type'][0])
- def test_assertion_refresh(self):
- http = HttpMockSequence([
+ def test_assertion_refresh(self):
+ http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w"}'),
({'status': '200'}, 'echo_request_headers'),
])
- http = self.credentials.authorize(http)
- resp, content = http.request('http://example.com')
- self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
+ http = self.credentials.authorize(http)
+ resp, content = http.request('http://example.com')
+ self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
- def test_token_revoke_success(self):
- _token_revoke_test_helper(
+ def test_token_revoke_success(self):
+ _token_revoke_test_helper(
self, '200', revoke_raise=False,
valid_bool_value=True, token_attr='access_token')
- def test_token_revoke_failure(self):
- _token_revoke_test_helper(
+ def test_token_revoke_failure(self):
+ _token_revoke_test_helper(
self, '400', revoke_raise=True,
valid_bool_value=False, token_attr='access_token')
class UpdateQueryParamsTest(unittest.TestCase):
- def test_update_query_params_no_params(self):
- uri = 'http://www.google.com'
- updated = _update_query_params(uri, {'a': 'b'})
- self.assertEqual(updated, uri + '?a=b')
+ def test_update_query_params_no_params(self):
+ uri = 'http://www.google.com'
+ updated = _update_query_params(uri, {'a': 'b'})
+ self.assertEqual(updated, uri + '?a=b')
- def test_update_query_params_existing_params(self):
- uri = 'http://www.google.com?x=y'
- updated = _update_query_params(uri, {'a': 'b', 'c': 'd&'})
- hardcoded_update = uri + '&a=b&c=d%26'
- assertUrisEqual(self, updated, hardcoded_update)
+ def test_update_query_params_existing_params(self):
+ uri = 'http://www.google.com?x=y'
+ updated = _update_query_params(uri, {'a': 'b', 'c': 'd&'})
+ hardcoded_update = uri + '&a=b&c=d%26'
+ assertUrisEqual(self, updated, hardcoded_update)
class ExtractIdTokenTest(unittest.TestCase):
- """Tests _extract_id_token()."""
+ """Tests _extract_id_token()."""
- def test_extract_success(self):
- body = {'foo': 'bar'}
- body_json = json.dumps(body).encode('ascii')
- payload = base64.urlsafe_b64encode(body_json).strip(b'=')
- jwt = b'stuff.' + payload + b'.signature'
+ def test_extract_success(self):
+ body = {'foo': 'bar'}
+ body_json = json.dumps(body).encode('ascii')
+ payload = base64.urlsafe_b64encode(body_json).strip(b'=')
+ jwt = b'stuff.' + payload + b'.signature'
- extracted = _extract_id_token(jwt)
- self.assertEqual(extracted, body)
+ extracted = _extract_id_token(jwt)
+ self.assertEqual(extracted, body)
- def test_extract_failure(self):
- body = {'foo': 'bar'}
- body_json = json.dumps(body).encode('ascii')
- payload = base64.urlsafe_b64encode(body_json).strip(b'=')
- jwt = b'stuff.' + payload
+ def test_extract_failure(self):
+ body = {'foo': 'bar'}
+ body_json = json.dumps(body).encode('ascii')
+ payload = base64.urlsafe_b64encode(body_json).strip(b'=')
+ jwt = b'stuff.' + payload
- self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt)
+ self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt)
class OAuth2WebServerFlowTest(unittest.TestCase):
- def setUp(self):
- self.flow = OAuth2WebServerFlow(
+ def setUp(self):
+ self.flow = OAuth2WebServerFlow(
client_id='client_id+1',
client_secret='secret+1',
scope='foo',
@@ -1020,21 +1021,21 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
revoke_uri='dummy_revoke_uri',
)
- def test_construct_authorize_url(self):
- authorize_url = self.flow.step1_get_authorize_url(state='state+1')
+ def test_construct_authorize_url(self):
+ authorize_url = self.flow.step1_get_authorize_url(state='state+1')
- parsed = urllib.parse.urlparse(authorize_url)
- q = urllib.parse.parse_qs(parsed[4])
- self.assertEqual('client_id+1', q['client_id'][0])
- self.assertEqual('code', q['response_type'][0])
- self.assertEqual('foo', q['scope'][0])
- self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
- self.assertEqual('offline', q['access_type'][0])
- self.assertEqual('state+1', q['state'][0])
+ parsed = urllib.parse.urlparse(authorize_url)
+ q = urllib.parse.parse_qs(parsed[4])
+ self.assertEqual('client_id+1', q['client_id'][0])
+ self.assertEqual('code', q['response_type'][0])
+ self.assertEqual('foo', q['scope'][0])
+ self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
+ self.assertEqual('offline', q['access_type'][0])
+ self.assertEqual('state+1', q['state'][0])
- def test_override_flow_via_kwargs(self):
- """Passing kwargs to override defaults."""
- flow = OAuth2WebServerFlow(
+ def test_override_flow_via_kwargs(self):
+ """Passing kwargs to override defaults."""
+ flow = OAuth2WebServerFlow(
client_id='client_id+1',
client_secret='secret+1',
scope='foo',
@@ -1043,95 +1044,95 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
access_type='online',
response_type='token'
)
- authorize_url = flow.step1_get_authorize_url()
+ authorize_url = flow.step1_get_authorize_url()
- parsed = urllib.parse.urlparse(authorize_url)
- q = urllib.parse.parse_qs(parsed[4])
- self.assertEqual('client_id+1', q['client_id'][0])
- self.assertEqual('token', q['response_type'][0])
- self.assertEqual('foo', q['scope'][0])
- self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
- self.assertEqual('online', q['access_type'][0])
+ parsed = urllib.parse.urlparse(authorize_url)
+ q = urllib.parse.parse_qs(parsed[4])
+ self.assertEqual('client_id+1', q['client_id'][0])
+ self.assertEqual('token', q['response_type'][0])
+ self.assertEqual('foo', q['scope'][0])
+ self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0])
+ self.assertEqual('online', q['access_type'][0])
- def test_scope_is_required(self):
- self.assertRaises(TypeError, OAuth2WebServerFlow, 'client_id+1')
+ def test_scope_is_required(self):
+ self.assertRaises(TypeError, OAuth2WebServerFlow, 'client_id+1')
- def test_exchange_failure(self):
- http = HttpMockSequence([
+ def test_exchange_failure(self):
+ http = HttpMockSequence([
({'status': '400'}, b'{"error":"invalid_request"}'),
])
- try:
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.fail('should raise exception if exchange doesn\'t get 200')
- except FlowExchangeError:
- pass
+ try:
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.fail('should raise exception if exchange doesn\'t get 200')
+ except FlowExchangeError:
+ pass
- def test_urlencoded_exchange_failure(self):
- http = HttpMockSequence([
+ def test_urlencoded_exchange_failure(self):
+ http = HttpMockSequence([
({'status': '400'}, b'error=invalid_request'),
- ])
+ ])
- try:
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.fail('should raise exception if exchange doesn\'t get 200')
- except FlowExchangeError as e:
- self.assertEqual('invalid_request', str(e))
+ try:
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.fail('should raise exception if exchange doesn\'t get 200')
+ except FlowExchangeError as e:
+ self.assertEqual('invalid_request', str(e))
- def test_exchange_failure_with_json_error(self):
- # Some providers have 'error' attribute as a JSON object
- # in place of regular string.
- # This test makes sure no strange object-to-string coversion
- # exceptions are being raised instead of FlowExchangeError.
- http = HttpMockSequence([
+ def test_exchange_failure_with_json_error(self):
+ # Some providers have 'error' attribute as a JSON object
+ # in place of regular string.
+ # This test makes sure no strange object-to-string coversion
+ # exceptions are being raised instead of FlowExchangeError.
+ http = HttpMockSequence([
({'status': '400'},
b""" {"error": {
"type": "OAuthException",
"message": "Error validating verification code."} }"""),
])
- try:
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.fail('should raise exception if exchange doesn\'t get 200')
- except FlowExchangeError as e:
- pass
+ try:
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.fail('should raise exception if exchange doesn\'t get 200')
+ except FlowExchangeError as e:
+ pass
- def test_exchange_success(self):
- http = HttpMockSequence([
+ def test_exchange_success(self):
+ http = HttpMockSequence([
({'status': '200'},
b"""{ "access_token":"SlAV32hkKG",
"expires_in":3600,
"refresh_token":"8xLOxBtZp8" }"""),
])
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.assertEqual('SlAV32hkKG', credentials.access_token)
- self.assertNotEqual(None, credentials.token_expiry)
- self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
- self.assertEqual('dummy_revoke_uri', credentials.revoke_uri)
- self.assertEqual(set(['foo']), credentials.scopes)
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.assertEqual('SlAV32hkKG', credentials.access_token)
+ self.assertNotEqual(None, credentials.token_expiry)
+ self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
+ self.assertEqual('dummy_revoke_uri', credentials.revoke_uri)
+ self.assertEqual(set(['foo']), credentials.scopes)
- def test_exchange_dictlike(self):
- class FakeDict(object):
- def __init__(self, d):
- self.d = d
+ def test_exchange_dictlike(self):
+ class FakeDict(object):
+ def __init__(self, d):
+ self.d = d
- def __getitem__(self, name):
- return self.d[name]
+ def __getitem__(self, name):
+ return self.d[name]
- def __contains__(self, name):
- return name in self.d
+ def __contains__(self, name):
+ return name in self.d
- code = 'some random code'
+ code = 'some random code'
not_a_dict = FakeDict({'code': code})
payload = (b'{'
b' "access_token":"SlAV32hkKG",'
b' "expires_in":3600,'
b' "refresh_token":"8xLOxBtZp8"'
b'}')
- http = HttpMockSequence([({'status': '200'}, payload),])
+ http = HttpMockSequence([({'status': '200'}, payload), ])
- credentials = self.flow.step2_exchange(not_a_dict, http=http)
+ credentials = self.flow.step2_exchange(not_a_dict, http=http)
self.assertEqual('SlAV32hkKG', credentials.access_token)
self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
@@ -1140,9 +1141,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
request_code = urllib.parse.parse_qs(http.requests[0]['body'])['code'][0]
self.assertEqual(code, request_code)
- def test_exchange_using_authorization_header(self):
- auth_header = 'Basic Y2xpZW50X2lkKzE6c2VjcmV0KzE=',
- flow = OAuth2WebServerFlow(
+ def test_exchange_using_authorization_header(self):
+ auth_header = 'Basic Y2xpZW50X2lkKzE6c2VjcmV0KzE=',
+ flow = OAuth2WebServerFlow(
client_id='client_id+1',
authorization_header=auth_header,
scope='foo',
@@ -1150,223 +1151,223 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
user_agent='unittest-sample/1.0',
revoke_uri='dummy_revoke_uri',
)
- http = HttpMockSequence([
+ http = HttpMockSequence([
({'status': '200'}, b'access_token=SlAV32hkKG'),
- ])
+ ])
- credentials = flow.step2_exchange('some random code', http=http)
- self.assertEqual('SlAV32hkKG', credentials.access_token)
+ credentials = flow.step2_exchange('some random code', http=http)
+ self.assertEqual('SlAV32hkKG', credentials.access_token)
- test_request = http.requests[0]
- # Did we pass the Authorization header?
- self.assertEqual(test_request['headers']['Authorization'], auth_header)
- # Did we omit client_secret from POST body?
- self.assertTrue('client_secret' not in test_request['body'])
+ test_request = http.requests[0]
+ # Did we pass the Authorization header?
+ self.assertEqual(test_request['headers']['Authorization'], auth_header)
+ # Did we omit client_secret from POST body?
+ self.assertTrue('client_secret' not in test_request['body'])
- def test_urlencoded_exchange_success(self):
- http = HttpMockSequence([
+ def test_urlencoded_exchange_success(self):
+ http = HttpMockSequence([
({'status': '200'}, b'access_token=SlAV32hkKG&expires_in=3600'),
- ])
+ ])
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.assertEqual('SlAV32hkKG', credentials.access_token)
- self.assertNotEqual(None, credentials.token_expiry)
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.assertEqual('SlAV32hkKG', credentials.access_token)
+ self.assertNotEqual(None, credentials.token_expiry)
- def test_urlencoded_expires_param(self):
- http = HttpMockSequence([
+ def test_urlencoded_expires_param(self):
+ http = HttpMockSequence([
# Note the 'expires=3600' where you'd normally
# have if named 'expires_in'
({'status': '200'}, b'access_token=SlAV32hkKG&expires=3600'),
- ])
+ ])
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.assertNotEqual(None, credentials.token_expiry)
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.assertNotEqual(None, credentials.token_expiry)
- def test_exchange_no_expires_in(self):
- http = HttpMockSequence([
+ def test_exchange_no_expires_in(self):
+ http = HttpMockSequence([
({'status': '200'}, b"""{ "access_token":"SlAV32hkKG",
"refresh_token":"8xLOxBtZp8" }"""),
])
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.assertEqual(None, credentials.token_expiry)
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.assertEqual(None, credentials.token_expiry)
- def test_urlencoded_exchange_no_expires_in(self):
- http = HttpMockSequence([
+ def test_urlencoded_exchange_no_expires_in(self):
+ http = HttpMockSequence([
# This might be redundant but just to make sure
# urlencoded access_token gets parsed correctly
({'status': '200'}, b'access_token=SlAV32hkKG'),
- ])
+ ])
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.assertEqual(None, credentials.token_expiry)
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.assertEqual(None, credentials.token_expiry)
- def test_exchange_fails_if_no_code(self):
- http = HttpMockSequence([
+ def test_exchange_fails_if_no_code(self):
+ http = HttpMockSequence([
({'status': '200'}, b"""{ "access_token":"SlAV32hkKG",
"refresh_token":"8xLOxBtZp8" }"""),
])
- code = {'error': 'thou shall not pass'}
- try:
- credentials = self.flow.step2_exchange(code, http=http)
- self.fail('should raise exception if no code in dictionary.')
- except FlowExchangeError as e:
- self.assertTrue('shall not pass' in str(e))
+ code = {'error': 'thou shall not pass'}
+ try:
+ credentials = self.flow.step2_exchange(code, http=http)
+ self.fail('should raise exception if no code in dictionary.')
+ except FlowExchangeError as e:
+ self.assertTrue('shall not pass' in str(e))
- def test_exchange_id_token_fail(self):
- http = HttpMockSequence([
+ def test_exchange_id_token_fail(self):
+ http = HttpMockSequence([
({'status': '200'}, b"""{ "access_token":"SlAV32hkKG",
"refresh_token":"8xLOxBtZp8",
"id_token": "stuff.payload"}"""),
])
- self.assertRaises(VerifyJwtTokenError, self.flow.step2_exchange,
+ self.assertRaises(VerifyJwtTokenError, self.flow.step2_exchange,
'some random code', http=http)
- def test_exchange_id_token(self):
- body = {'foo': 'bar'}
- body_json = json.dumps(body).encode('ascii')
- payload = base64.urlsafe_b64encode(body_json).strip(b'=')
- jwt = (base64.urlsafe_b64encode(b'stuff') + b'.' + payload + b'.' +
+ def test_exchange_id_token(self):
+ body = {'foo': 'bar'}
+ body_json = json.dumps(body).encode('ascii')
+ payload = base64.urlsafe_b64encode(body_json).strip(b'=')
+ jwt = (base64.urlsafe_b64encode(b'stuff') + b'.' + payload + b'.' +
base64.urlsafe_b64encode(b'signature'))
- http = HttpMockSequence([
+ http = HttpMockSequence([
({'status': '200'}, ("""{ "access_token":"SlAV32hkKG",
"refresh_token":"8xLOxBtZp8",
"id_token": "%s"}""" % jwt).encode('utf-8')),
])
- credentials = self.flow.step2_exchange('some random code', http=http)
- self.assertEqual(credentials.id_token, body)
+ credentials = self.flow.step2_exchange('some random code', http=http)
+ self.assertEqual(credentials.id_token, body)
class FlowFromCachedClientsecrets(unittest.TestCase):
- def test_flow_from_clientsecrets_cached(self):
- cache_mock = CacheMock()
- load_and_cache('client_secrets.json', 'some_secrets', cache_mock)
+ def test_flow_from_clientsecrets_cached(self):
+ cache_mock = CacheMock()
+ load_and_cache('client_secrets.json', 'some_secrets', cache_mock)
- flow = flow_from_clientsecrets(
+ flow = flow_from_clientsecrets(
'some_secrets', '', redirect_uri='oob', cache=cache_mock)
- self.assertEqual('foo_client_secret', flow.client_secret)
+ self.assertEqual('foo_client_secret', flow.client_secret)
class CredentialsFromCodeTests(unittest.TestCase):
- def setUp(self):
- self.client_id = 'client_id_abc'
- self.client_secret = 'secret_use_code'
- self.scope = 'foo'
- self.code = '12345abcde'
- self.redirect_uri = 'postmessage'
+ def setUp(self):
+ self.client_id = 'client_id_abc'
+ self.client_secret = 'secret_use_code'
+ self.scope = 'foo'
+ self.code = '12345abcde'
+ self.redirect_uri = 'postmessage'
- def test_exchange_code_for_token(self):
- token = 'asdfghjkl'
- payload = json.dumps({'access_token': token, 'expires_in': 3600})
- http = HttpMockSequence([
+ def test_exchange_code_for_token(self):
+ token = 'asdfghjkl'
+ payload = json.dumps({'access_token': token, 'expires_in': 3600})
+ http = HttpMockSequence([
({'status': '200'}, payload.encode('utf-8')),
- ])
- credentials = credentials_from_code(self.client_id, self.client_secret,
+ ])
+ credentials = credentials_from_code(self.client_id, self.client_secret,
self.scope, self.code, redirect_uri=self.redirect_uri,
http=http)
- self.assertEqual(credentials.access_token, token)
- self.assertNotEqual(None, credentials.token_expiry)
- self.assertEqual(set(['foo']), credentials.scopes)
+ self.assertEqual(credentials.access_token, token)
+ self.assertNotEqual(None, credentials.token_expiry)
+ self.assertEqual(set(['foo']), credentials.scopes)
- def test_exchange_code_for_token_fail(self):
- http = HttpMockSequence([
+ def test_exchange_code_for_token_fail(self):
+ http = HttpMockSequence([
({'status': '400'}, b'{"error":"invalid_request"}'),
])
- try:
- credentials = credentials_from_code(self.client_id, self.client_secret,
+ try:
+ credentials = credentials_from_code(self.client_id, self.client_secret,
self.scope, self.code, redirect_uri=self.redirect_uri,
http=http)
- self.fail('should raise exception if exchange doesn\'t get 200')
- except FlowExchangeError:
- pass
+ self.fail('should raise exception if exchange doesn\'t get 200')
+ except FlowExchangeError:
+ pass
- def test_exchange_code_and_file_for_token(self):
- http = HttpMockSequence([
+ def test_exchange_code_and_file_for_token(self):
+ http = HttpMockSequence([
({'status': '200'},
b"""{ "access_token":"asdfghjkl",
"expires_in":3600 }"""),
- ])
- credentials = credentials_from_clientsecrets_and_code(
+ ])
+ credentials = credentials_from_clientsecrets_and_code(
datafile('client_secrets.json'), self.scope,
self.code, http=http)
- self.assertEqual(credentials.access_token, 'asdfghjkl')
- self.assertNotEqual(None, credentials.token_expiry)
- self.assertEqual(set(['foo']), credentials.scopes)
+ self.assertEqual(credentials.access_token, 'asdfghjkl')
+ self.assertNotEqual(None, credentials.token_expiry)
+ self.assertEqual(set(['foo']), credentials.scopes)
- def test_exchange_code_and_cached_file_for_token(self):
- http = HttpMockSequence([
+ def test_exchange_code_and_cached_file_for_token(self):
+ http = HttpMockSequence([
({'status': '200'}, b'{ "access_token":"asdfghjkl"}'),
])
- cache_mock = CacheMock()
- load_and_cache('client_secrets.json', 'some_secrets', cache_mock)
+ cache_mock = CacheMock()
+ load_and_cache('client_secrets.json', 'some_secrets', cache_mock)
- credentials = credentials_from_clientsecrets_and_code(
+ credentials = credentials_from_clientsecrets_and_code(
'some_secrets', self.scope,
self.code, http=http, cache=cache_mock)
- self.assertEqual(credentials.access_token, 'asdfghjkl')
- self.assertEqual(set(['foo']), credentials.scopes)
+ self.assertEqual(credentials.access_token, 'asdfghjkl')
+ self.assertEqual(set(['foo']), credentials.scopes)
- def test_exchange_code_and_file_for_token_fail(self):
- http = HttpMockSequence([
+ def test_exchange_code_and_file_for_token_fail(self):
+ http = HttpMockSequence([
({'status': '400'}, b'{"error":"invalid_request"}'),
])
- try:
- credentials = credentials_from_clientsecrets_and_code(
+ try:
+ credentials = credentials_from_clientsecrets_and_code(
datafile('client_secrets.json'), self.scope,
self.code, http=http)
- self.fail('should raise exception if exchange doesn\'t get 200')
- except FlowExchangeError:
- pass
+ self.fail('should raise exception if exchange doesn\'t get 200')
+ except FlowExchangeError:
+ pass
class MemoryCacheTests(unittest.TestCase):
- def test_get_set_delete(self):
- m = MemoryCache()
- self.assertEqual(None, m.get('foo'))
- self.assertEqual(None, m.delete('foo'))
- m.set('foo', 'bar')
- self.assertEqual('bar', m.get('foo'))
- m.delete('foo')
- self.assertEqual(None, m.get('foo'))
+ def test_get_set_delete(self):
+ m = MemoryCache()
+ self.assertEqual(None, m.get('foo'))
+ self.assertEqual(None, m.delete('foo'))
+ m.set('foo', 'bar')
+ self.assertEqual('bar', m.get('foo'))
+ m.delete('foo')
+ self.assertEqual(None, m.get('foo'))
class Test__save_private_file(unittest.TestCase):
- def _save_helper(self, filename):
- contents = []
- contents_str = '[]'
- client._save_private_file(filename, contents)
- with open(filename, 'r') as f:
- stored_contents = f.read()
- self.assertEqual(stored_contents, contents_str)
+ def _save_helper(self, filename):
+ contents = []
+ contents_str = '[]'
+ client._save_private_file(filename, contents)
+ with open(filename, 'r') as f:
+ stored_contents = f.read()
+ self.assertEqual(stored_contents, contents_str)
- stat_mode = os.stat(filename).st_mode
- # Octal 777, only last 3 positions matter for permissions mask.
- stat_mode &= 0o777
- self.assertEqual(stat_mode, 0o600)
+ stat_mode = os.stat(filename).st_mode
+ # Octal 777, only last 3 positions matter for permissions mask.
+ stat_mode &= 0o777
+ self.assertEqual(stat_mode, 0o600)
- def test_new(self):
- import tempfile
- filename = tempfile.mktemp()
- self.assertFalse(os.path.exists(filename))
- self._save_helper(filename)
+ def test_new(self):
+ import tempfile
+ filename = tempfile.mktemp()
+ self.assertFalse(os.path.exists(filename))
+ self._save_helper(filename)
- def test_existing(self):
- import tempfile
- filename = tempfile.mktemp()
- with open(filename, 'w') as f:
- f.write('a bunch of nonsense longer than []')
- self.assertTrue(os.path.exists(filename))
- self._save_helper(filename)
+ def test_existing(self):
+ import tempfile
+ filename = tempfile.mktemp()
+ with open(filename, 'w') as f:
+ f.write('a bunch of nonsense longer than []')
+ self.assertTrue(os.path.exists(filename))
+ self._save_helper(filename)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/test_service_account.py b/tests/test_service_account.py
index 5d1a125..3306341 100644
--- a/tests/test_service_account.py
+++ b/tests/test_service_account.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Oauth2client tests.
Unit tests for service account credentials implemented using RSA.
@@ -31,94 +30,94 @@ from oauth2client.service_account import _ServiceAccountCredentials
def datafile(filename):
- # TODO(orestica): Refactor this using pkgutil.get_data
- f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
- data = f.read()
- f.close()
- return data
+ # TODO(orestica): Refactor this using pkgutil.get_data
+ f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
+ data = f.read()
+ f.close()
+ return data
class ServiceAccountCredentialsTests(unittest.TestCase):
- def setUp(self):
- self.service_account_id = '123'
- self.service_account_email = 'dummy@google.com'
- self.private_key_id = 'ABCDEF'
- self.private_key = datafile('pem_from_pkcs12.pem')
- self.scopes = ['dummy_scope']
- self.credentials = _ServiceAccountCredentials(self.service_account_id,
+ def setUp(self):
+ self.service_account_id = '123'
+ self.service_account_email = 'dummy@google.com'
+ self.private_key_id = 'ABCDEF'
+ self.private_key = datafile('pem_from_pkcs12.pem')
+ self.scopes = ['dummy_scope']
+ self.credentials = _ServiceAccountCredentials(self.service_account_id,
self.service_account_email,
self.private_key_id,
self.private_key,
[])
- def test_sign_blob(self):
- private_key_id, signature = self.credentials.sign_blob('Google')
- self.assertEqual( self.private_key_id, private_key_id)
+ def test_sign_blob(self):
+ private_key_id, signature = self.credentials.sign_blob('Google')
+ self.assertEqual(self.private_key_id, private_key_id)
- pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
+ pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
datafile('publickey_openssl.pem'))
- self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
+ self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
- try:
- rsa.pkcs1.verify(b'Orest', signature, pub_key)
- self.fail('Verification should have failed!')
- except rsa.pkcs1.VerificationError:
- pass # Expected
+ try:
+ rsa.pkcs1.verify(b'Orest', signature, pub_key)
+ self.fail('Verification should have failed!')
+ except rsa.pkcs1.VerificationError:
+ pass # Expected
- try:
- rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
- self.fail('Verification should have failed!')
- except rsa.pkcs1.VerificationError:
- pass # Expected
+ try:
+ rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
+ self.fail('Verification should have failed!')
+ except rsa.pkcs1.VerificationError:
+ pass # Expected
- def test_service_account_email(self):
- self.assertEqual(self.service_account_email,
+ def test_service_account_email(self):
+ self.assertEqual(self.service_account_email,
self.credentials.service_account_email)
- def test_create_scoped_required_without_scopes(self):
- self.assertTrue(self.credentials.create_scoped_required())
+ def test_create_scoped_required_without_scopes(self):
+ self.assertTrue(self.credentials.create_scoped_required())
- def test_create_scoped_required_with_scopes(self):
- self.credentials = _ServiceAccountCredentials(self.service_account_id,
+ def test_create_scoped_required_with_scopes(self):
+ self.credentials = _ServiceAccountCredentials(self.service_account_id,
self.service_account_email,
self.private_key_id,
self.private_key,
self.scopes)
- self.assertFalse(self.credentials.create_scoped_required())
+ self.assertFalse(self.credentials.create_scoped_required())
- def test_create_scoped(self):
- new_credentials = self.credentials.create_scoped(self.scopes)
- self.assertNotEqual(self.credentials, new_credentials)
- self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
- self.assertEqual('dummy_scope', new_credentials._scopes)
+ def test_create_scoped(self):
+ new_credentials = self.credentials.create_scoped(self.scopes)
+ self.assertNotEqual(self.credentials, new_credentials)
+ self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
+ self.assertEqual('dummy_scope', new_credentials._scopes)
- def test_access_token(self):
- S = 2 # number of seconds in which the token expires
- token_response_first = {'access_token': 'first_token', 'expires_in': S}
- token_response_second = {'access_token': 'second_token', 'expires_in': S}
- http = HttpMockSequence([
+ def test_access_token(self):
+ S = 2 # number of seconds in which the token expires
+ token_response_first = {'access_token': 'first_token', 'expires_in': S}
+ token_response_second = {'access_token': 'second_token', 'expires_in': S}
+ http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode('utf-8')),
({'status': '200'}, json.dumps(token_response_second).encode('utf-8')),
- ])
+ ])
- token = self.credentials.get_access_token(http=http)
- self.assertEqual('first_token', token.access_token)
- self.assertEqual(S - 1, token.expires_in)
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response_first, self.credentials.token_response)
+ token = self.credentials.get_access_token(http=http)
+ self.assertEqual('first_token', token.access_token)
+ self.assertEqual(S - 1, token.expires_in)
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response_first, self.credentials.token_response)
- token = self.credentials.get_access_token(http=http)
- self.assertEqual('first_token', token.access_token)
- self.assertEqual(S - 1, token.expires_in)
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response_first, self.credentials.token_response)
+ token = self.credentials.get_access_token(http=http)
+ self.assertEqual('first_token', token.access_token)
+ self.assertEqual(S - 1, token.expires_in)
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response_first, self.credentials.token_response)
- time.sleep(S + 0.5) # some margin to avoid flakiness
- self.assertTrue(self.credentials.access_token_expired)
+ time.sleep(S + 0.5) # some margin to avoid flakiness
+ self.assertTrue(self.credentials.access_token_expired)
- token = self.credentials.get_access_token(http=http)
- self.assertEqual('second_token', token.access_token)
- self.assertEqual(S - 1, token.expires_in)
- self.assertFalse(self.credentials.access_token_expired)
- self.assertEqual(token_response_second, self.credentials.token_response)
+ token = self.credentials.get_access_token(http=http)
+ self.assertEqual('second_token', token.access_token)
+ self.assertEqual(S - 1, token.expires_in)
+ self.assertFalse(self.credentials.access_token_expired)
+ self.assertEqual(token_response_second, self.credentials.token_response)
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 23aca90..6b32e39 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -5,6 +5,7 @@ from oauth2client import tools
from six.moves.urllib import request
import threading
+
class TestClientRedirectServer(unittest.TestCase):
"""Test the ClientRedirectServer and ClientRedirectHandler classes."""
@@ -15,16 +16,15 @@ class TestClientRedirectServer(unittest.TestCase):
httpd = tools.ClientRedirectServer(('localhost', 0), tools.ClientRedirectHandler)
code = 'foo'
url = 'http://localhost:%i?code=%s' % (httpd.server_address[1], code)
- t = threading.Thread(target = httpd.handle_request)
+ t = threading.Thread(target=httpd.handle_request)
t.setDaemon(True)
t.start()
- f = request.urlopen( url )
+ f = request.urlopen(url)
self.assertTrue(f.read())
t.join()
httpd.server_close()
- self.assertEqual(httpd.query_params.get('code'),code)
+ self.assertEqual(httpd.query_params.get('code'), code)
if __name__ == '__main__':
unittest.main()
-
diff --git a/tests/test_util.py b/tests/test_util.py
index b3fc326..7c781dd 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -9,8 +9,8 @@ from oauth2client import util
class ScopeToStringTests(unittest.TestCase):
- def test_iterables(self):
- cases = [
+ def test_iterables(self):
+ cases = [
('', ''),
('', ()),
('', []),
@@ -22,36 +22,37 @@ class ScopeToStringTests(unittest.TestCase):
('a b', ('a', 'b')),
('a b', 'a b'),
('a b', (s for s in ['a', 'b'])),
- ]
- for expected, case in cases:
- self.assertEqual(expected, util.scopes_to_string(case))
+ ]
+ for expected, case in cases:
+ self.assertEqual(expected, util.scopes_to_string(case))
class StringToScopeTests(unittest.TestCase):
- def test_conversion(self):
- cases = [
+ def test_conversion(self):
+ cases = [
(['a', 'b'], ['a', 'b']),
('', []),
('a', ['a']),
('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']),
- ]
+ ]
+
+ for case, expected in cases:
+ self.assertEqual(expected, util.string_to_scopes(case))
- for case, expected in cases:
- self.assertEqual(expected, util.string_to_scopes(case))
class KeyConversionTests(unittest.TestCase):
- def test_key_conversions(self):
- d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
- tuple_key = util.dict_to_tuple_key(d)
+ def test_key_conversions(self):
+ d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
+ tuple_key = util.dict_to_tuple_key(d)
- # the resulting key should be naturally sorted
- self.assertEqual(
+ # the resulting key should be naturally sorted
+ self.assertEqual(
(('another', 'something else'),
('onemore', 'foo'),
('somekey', 'some value')),
tuple_key)
- # check we get the original dictionary back
- self.assertEqual(d, dict(tuple_key))
+ # check we get the original dictionary back
+ self.assertEqual(d, dict(tuple_key))
diff --git a/tests/test_xsrfutil.py b/tests/test_xsrfutil.py
index 5825b5e..6f36bbc 100644
--- a/tests/test_xsrfutil.py
+++ b/tests/test_xsrfutil.py
@@ -34,78 +34,78 @@ TEST_EXTRA_INFO_2 = 'more_extra_info'
class XsrfUtilTests(unittest.TestCase):
- """Test xsrfutil functions."""
+ """Test xsrfutil functions."""
- def testGenerateAndValidateToken(self):
- """Test generating and validating a token."""
- token = xsrfutil.generate_token(TEST_KEY,
+ def testGenerateAndValidateToken(self):
+ """Test generating and validating a token."""
+ token = xsrfutil.generate_token(TEST_KEY,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
when=TEST_TIME)
- # Check that the token is considered valid when it should be.
- self.assertTrue(xsrfutil.validate_token(TEST_KEY,
+ # Check that the token is considered valid when it should be.
+ self.assertTrue(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=TEST_TIME))
- # Should still be valid 15 minutes later.
- later15mins = TEST_TIME + 15*60
- self.assertTrue(xsrfutil.validate_token(TEST_KEY,
+ # Should still be valid 15 minutes later.
+ later15mins = TEST_TIME + 15 * 60
+ self.assertTrue(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
- # But not if beyond the timeout.
- later2hours = TEST_TIME + 2*60*60
- self.assertFalse(xsrfutil.validate_token(TEST_KEY,
+ # But not if beyond the timeout.
+ later2hours = TEST_TIME + 2 * 60 * 60
+ self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later2hours))
- # Or if the key is different.
- self.assertFalse(xsrfutil.validate_token('another key',
+ # Or if the key is different.
+ self.assertFalse(xsrfutil.validate_token('another key',
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
- # Or the user ID....
- self.assertFalse(xsrfutil.validate_token(TEST_KEY,
+ # Or the user ID....
+ self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_2,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
- # Or the action ID...
- self.assertFalse(xsrfutil.validate_token(TEST_KEY,
+ # Or the action ID...
+ self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_2,
current_time=later15mins))
- # Invalid when truncated
- self.assertFalse(xsrfutil.validate_token(TEST_KEY,
+ # Invalid when truncated
+ self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token[:-1],
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
- # Invalid with extra garbage
- self.assertFalse(xsrfutil.validate_token(TEST_KEY,
+ # Invalid with extra garbage
+ self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token + b'x',
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
- # Invalid with token of None
- self.assertFalse(xsrfutil.validate_token(TEST_KEY,
+ # Invalid with token of None
+ self.assertFalse(xsrfutil.validate_token(TEST_KEY,
None,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1))
if __name__ == '__main__':
- unittest.main()
+ unittest.main()