Raw pep8ify changes.
Simply ran pep8ify -w oauth2client/ pep8ify -w tests/
This commit is contained in:
@@ -19,7 +19,7 @@ import six
|
||||
|
||||
|
||||
def _parse_pem_key(raw_key_input):
|
||||
"""Identify and extract PEM keys.
|
||||
"""Identify and extract PEM keys.
|
||||
|
||||
Determines whether the given key is in the format of PEM key, and extracts
|
||||
the relevant part of the key if it is.
|
||||
@@ -30,17 +30,17 @@ def _parse_pem_key(raw_key_input):
|
||||
Returns:
|
||||
string, The actual key if the contents are from a PEM file, or else None.
|
||||
"""
|
||||
offset = raw_key_input.find(b'-----BEGIN ')
|
||||
if offset != -1:
|
||||
return raw_key_input[offset:]
|
||||
offset = raw_key_input.find(b'-----BEGIN ')
|
||||
if offset != -1:
|
||||
return raw_key_input[offset:]
|
||||
|
||||
|
||||
def _json_encode(data):
|
||||
return json.dumps(data, separators=(',', ':'))
|
||||
return json.dumps(data, separators=(',', ':'))
|
||||
|
||||
|
||||
def _to_bytes(value, encoding='ascii'):
|
||||
"""Converts a string value to bytes, if necessary.
|
||||
"""Converts a string value to bytes, if necessary.
|
||||
|
||||
Unfortunately, ``six.b`` is insufficient for this task since in
|
||||
Python2 it does not modify ``unicode`` objects.
|
||||
@@ -60,16 +60,16 @@ def _to_bytes(value, encoding='ascii'):
|
||||
Raises:
|
||||
ValueError if the value could not be converted to bytes.
|
||||
"""
|
||||
result = (value.encode(encoding)
|
||||
result = (value.encode(encoding)
|
||||
if isinstance(value, six.text_type) else value)
|
||||
if isinstance(result, six.binary_type):
|
||||
return result
|
||||
else:
|
||||
raise ValueError('%r could not be converted to bytes' % (value,))
|
||||
if isinstance(result, six.binary_type):
|
||||
return result
|
||||
else:
|
||||
raise ValueError('%r could not be converted to bytes' % (value, ))
|
||||
|
||||
|
||||
def _from_bytes(value):
|
||||
"""Converts bytes to a string value, if necessary.
|
||||
"""Converts bytes to a string value, if necessary.
|
||||
|
||||
Args:
|
||||
value: The string/bytes value to be converted.
|
||||
@@ -81,21 +81,21 @@ def _from_bytes(value):
|
||||
Raises:
|
||||
ValueError if the value could not be converted to unicode.
|
||||
"""
|
||||
result = (value.decode('utf-8')
|
||||
result = (value.decode('utf-8')
|
||||
if isinstance(value, six.binary_type) else value)
|
||||
if isinstance(result, six.text_type):
|
||||
return result
|
||||
else:
|
||||
raise ValueError('%r could not be converted to unicode' % (value,))
|
||||
if isinstance(result, six.text_type):
|
||||
return result
|
||||
else:
|
||||
raise ValueError('%r could not be converted to unicode' % (value, ))
|
||||
|
||||
|
||||
def _urlsafe_b64encode(raw_bytes):
|
||||
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
|
||||
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
|
||||
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
|
||||
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
|
||||
|
||||
|
||||
def _urlsafe_b64decode(b64string):
|
||||
# Guard against unicode strings, which base64 can't handle.
|
||||
b64string = _to_bytes(b64string)
|
||||
padded = b64string + b'=' * (4 - len(b64string) % 4)
|
||||
return base64.urlsafe_b64decode(padded)
|
||||
# Guard against unicode strings, which base64 can't handle.
|
||||
b64string = _to_bytes(b64string)
|
||||
padded = b64string + b'=' * (4 - len(b64string) % 4)
|
||||
return base64.urlsafe_b64decode(padded)
|
||||
|
||||
@@ -22,18 +22,18 @@ from oauth2client._helpers import _to_bytes
|
||||
|
||||
|
||||
class OpenSSLVerifier(object):
|
||||
"""Verifies the signature on a message."""
|
||||
"""Verifies the signature on a message."""
|
||||
|
||||
def __init__(self, pubkey):
|
||||
"""Constructor.
|
||||
def __init__(self, pubkey):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
pubkey, OpenSSL.crypto.PKey, The public key to verify with.
|
||||
"""
|
||||
self._pubkey = pubkey
|
||||
self._pubkey = pubkey
|
||||
|
||||
def verify(self, message, signature):
|
||||
"""Verifies a message against a signature.
|
||||
def verify(self, message, signature):
|
||||
"""Verifies a message against a signature.
|
||||
|
||||
Args:
|
||||
message: string or bytes, The message to verify. If string, will be
|
||||
@@ -45,17 +45,17 @@ class OpenSSLVerifier(object):
|
||||
True if message was signed by the private key associated with the public
|
||||
key that this object was constructed with.
|
||||
"""
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
signature = _to_bytes(signature, encoding='utf-8')
|
||||
try:
|
||||
crypto.verify(self._pubkey, signature, message, 'sha256')
|
||||
return True
|
||||
except crypto.Error:
|
||||
return False
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
signature = _to_bytes(signature, encoding='utf-8')
|
||||
try:
|
||||
crypto.verify(self._pubkey, signature, message, 'sha256')
|
||||
return True
|
||||
except crypto.Error:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@staticmethod
|
||||
def from_string(key_pem, is_x509_cert):
|
||||
"""Construct a Verified instance from a string.
|
||||
"""Construct a Verified instance from a string.
|
||||
|
||||
Args:
|
||||
key_pem: string, public key in PEM format.
|
||||
@@ -68,26 +68,26 @@ class OpenSSLVerifier(object):
|
||||
Raises:
|
||||
OpenSSL.crypto.Error if the key_pem can't be parsed.
|
||||
"""
|
||||
if is_x509_cert:
|
||||
pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
|
||||
else:
|
||||
pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
|
||||
return OpenSSLVerifier(pubkey)
|
||||
if is_x509_cert:
|
||||
pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
|
||||
else:
|
||||
pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
|
||||
return OpenSSLVerifier(pubkey)
|
||||
|
||||
|
||||
class OpenSSLSigner(object):
|
||||
"""Signs messages with a private key."""
|
||||
"""Signs messages with a private key."""
|
||||
|
||||
def __init__(self, pkey):
|
||||
"""Constructor.
|
||||
def __init__(self, pkey):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
|
||||
"""
|
||||
self._key = pkey
|
||||
self._key = pkey
|
||||
|
||||
def sign(self, message):
|
||||
"""Signs a message.
|
||||
def sign(self, message):
|
||||
"""Signs a message.
|
||||
|
||||
Args:
|
||||
message: bytes, Message to be signed.
|
||||
@@ -95,12 +95,12 @@ class OpenSSLSigner(object):
|
||||
Returns:
|
||||
string, The signature of the message for the given key.
|
||||
"""
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
return crypto.sign(self._key, message, 'sha256')
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
return crypto.sign(self._key, message, 'sha256')
|
||||
|
||||
@staticmethod
|
||||
def from_string(key, password=b'notasecret'):
|
||||
"""Construct a Signer instance from a string.
|
||||
@staticmethod
|
||||
def from_string(key, password=b'notasecret'):
|
||||
"""Construct a Signer instance from a string.
|
||||
|
||||
Args:
|
||||
key: string, private key in PKCS12 or PEM format.
|
||||
@@ -112,17 +112,17 @@ class OpenSSLSigner(object):
|
||||
Raises:
|
||||
OpenSSL.crypto.Error if the key can't be parsed.
|
||||
"""
|
||||
parsed_pem_key = _parse_pem_key(key)
|
||||
if parsed_pem_key:
|
||||
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
|
||||
else:
|
||||
password = _to_bytes(password, encoding='utf-8')
|
||||
pkey = crypto.load_pkcs12(key, password).get_privatekey()
|
||||
return OpenSSLSigner(pkey)
|
||||
parsed_pem_key = _parse_pem_key(key)
|
||||
if parsed_pem_key:
|
||||
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
|
||||
else:
|
||||
password = _to_bytes(password, encoding='utf-8')
|
||||
pkey = crypto.load_pkcs12(key, password).get_privatekey()
|
||||
return OpenSSLSigner(pkey)
|
||||
|
||||
|
||||
def pkcs12_key_as_pem(private_key_text, private_key_password):
|
||||
"""Convert the contents of a PKCS12 key to PEM using OpenSSL.
|
||||
"""Convert the contents of a PKCS12 key to PEM using OpenSSL.
|
||||
|
||||
Args:
|
||||
private_key_text: String. Private key.
|
||||
@@ -131,9 +131,9 @@ def pkcs12_key_as_pem(private_key_text, private_key_password):
|
||||
Returns:
|
||||
String. PEM contents of ``private_key_text``.
|
||||
"""
|
||||
decoded_body = base64.b64decode(private_key_text)
|
||||
private_key_password = _to_bytes(private_key_password)
|
||||
decoded_body = base64.b64decode(private_key_text)
|
||||
private_key_password = _to_bytes(private_key_password)
|
||||
|
||||
pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
|
||||
return crypto.dump_privatekey(crypto.FILETYPE_PEM,
|
||||
pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
|
||||
return crypto.dump_privatekey(crypto.FILETYPE_PEM,
|
||||
pkcs12.get_privatekey())
|
||||
|
||||
@@ -25,18 +25,18 @@ from oauth2client._helpers import _urlsafe_b64decode
|
||||
|
||||
|
||||
class PyCryptoVerifier(object):
|
||||
"""Verifies the signature on a message."""
|
||||
"""Verifies the signature on a message."""
|
||||
|
||||
def __init__(self, pubkey):
|
||||
"""Constructor.
|
||||
def __init__(self, pubkey):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
pubkey, OpenSSL.crypto.PKey (or equiv), The public key to verify with.
|
||||
"""
|
||||
self._pubkey = pubkey
|
||||
self._pubkey = pubkey
|
||||
|
||||
def verify(self, message, signature):
|
||||
"""Verifies a message against a signature.
|
||||
def verify(self, message, signature):
|
||||
"""Verifies a message against a signature.
|
||||
|
||||
Args:
|
||||
message: string or bytes, The message to verify. If string, will be
|
||||
@@ -47,13 +47,13 @@ class PyCryptoVerifier(object):
|
||||
True if message was signed by the private key associated with the public
|
||||
key that this object was constructed with.
|
||||
"""
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
return PKCS1_v1_5.new(self._pubkey).verify(
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
return PKCS1_v1_5.new(self._pubkey).verify(
|
||||
SHA256.new(message), signature)
|
||||
|
||||
@staticmethod
|
||||
def from_string(key_pem, is_x509_cert):
|
||||
"""Construct a Verified instance from a string.
|
||||
@staticmethod
|
||||
def from_string(key_pem, is_x509_cert):
|
||||
"""Construct a Verified instance from a string.
|
||||
|
||||
Args:
|
||||
key_pem: string, public key in PEM format.
|
||||
@@ -63,33 +63,33 @@ class PyCryptoVerifier(object):
|
||||
Returns:
|
||||
Verifier instance.
|
||||
"""
|
||||
if is_x509_cert:
|
||||
key_pem = _to_bytes(key_pem)
|
||||
pemLines = key_pem.replace(b' ', b'').split()
|
||||
certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
|
||||
certSeq = DerSequence()
|
||||
certSeq.decode(certDer)
|
||||
tbsSeq = DerSequence()
|
||||
tbsSeq.decode(certSeq[0])
|
||||
pubkey = RSA.importKey(tbsSeq[6])
|
||||
else:
|
||||
pubkey = RSA.importKey(key_pem)
|
||||
return PyCryptoVerifier(pubkey)
|
||||
if is_x509_cert:
|
||||
key_pem = _to_bytes(key_pem)
|
||||
pemLines = key_pem.replace(b' ', b'').split()
|
||||
certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
|
||||
certSeq = DerSequence()
|
||||
certSeq.decode(certDer)
|
||||
tbsSeq = DerSequence()
|
||||
tbsSeq.decode(certSeq[0])
|
||||
pubkey = RSA.importKey(tbsSeq[6])
|
||||
else:
|
||||
pubkey = RSA.importKey(key_pem)
|
||||
return PyCryptoVerifier(pubkey)
|
||||
|
||||
|
||||
class PyCryptoSigner(object):
|
||||
"""Signs messages with a private key."""
|
||||
"""Signs messages with a private key."""
|
||||
|
||||
def __init__(self, pkey):
|
||||
"""Constructor.
|
||||
def __init__(self, pkey):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
|
||||
"""
|
||||
self._key = pkey
|
||||
self._key = pkey
|
||||
|
||||
def sign(self, message):
|
||||
"""Signs a message.
|
||||
def sign(self, message):
|
||||
"""Signs a message.
|
||||
|
||||
Args:
|
||||
message: string, Message to be signed.
|
||||
@@ -97,12 +97,12 @@ class PyCryptoSigner(object):
|
||||
Returns:
|
||||
string, The signature of the message for the given key.
|
||||
"""
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
|
||||
message = _to_bytes(message, encoding='utf-8')
|
||||
return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
|
||||
|
||||
@staticmethod
|
||||
def from_string(key, password='notasecret'):
|
||||
"""Construct a Signer instance from a string.
|
||||
@staticmethod
|
||||
def from_string(key, password='notasecret'):
|
||||
"""Construct a Signer instance from a string.
|
||||
|
||||
Args:
|
||||
key: string, private key in PEM format.
|
||||
@@ -114,13 +114,13 @@ class PyCryptoSigner(object):
|
||||
Raises:
|
||||
NotImplementedError if the key isn't in PEM format.
|
||||
"""
|
||||
parsed_pem_key = _parse_pem_key(key)
|
||||
if parsed_pem_key:
|
||||
pkey = RSA.importKey(parsed_pem_key)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
parsed_pem_key = _parse_pem_key(key)
|
||||
if parsed_pem_key:
|
||||
pkey = RSA.importKey(parsed_pem_key)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'PKCS12 format is not supported by the PyCrypto library. '
|
||||
'Try converting to a "PEM" '
|
||||
'(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) '
|
||||
'or using PyOpenSSL if native code is an option.')
|
||||
return PyCryptoSigner(pkey)
|
||||
return PyCryptoSigner(pkey)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
||||
import json
|
||||
import six
|
||||
|
||||
|
||||
# Properties that make a client_secrets.json file valid.
|
||||
TYPE_WEB = 'web'
|
||||
TYPE_INSTALLED = 'installed'
|
||||
@@ -59,65 +58,65 @@ VALID_CLIENT = {
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
"""Base error for this module."""
|
||||
pass
|
||||
"""Base error for this module."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidClientSecretsError(Error):
|
||||
"""Format of ClientSecrets file is invalid."""
|
||||
pass
|
||||
"""Format of ClientSecrets file is invalid."""
|
||||
pass
|
||||
|
||||
|
||||
def _validate_clientsecrets(obj):
|
||||
_INVALID_FILE_FORMAT_MSG = (
|
||||
_INVALID_FILE_FORMAT_MSG = (
|
||||
'Invalid file format. See '
|
||||
'https://developers.google.com/api-client-library/'
|
||||
'python/guide/aaa_client_secrets')
|
||||
|
||||
if obj is None:
|
||||
raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
|
||||
if len(obj) != 1:
|
||||
raise InvalidClientSecretsError(
|
||||
if obj is None:
|
||||
raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
|
||||
if len(obj) != 1:
|
||||
raise InvalidClientSecretsError(
|
||||
_INVALID_FILE_FORMAT_MSG + ' '
|
||||
'Expected a JSON object with a single property for a "web" or '
|
||||
'"installed" application')
|
||||
client_type = tuple(obj)[0]
|
||||
if client_type not in VALID_CLIENT:
|
||||
raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type,))
|
||||
client_info = obj[client_type]
|
||||
for prop_name in VALID_CLIENT[client_type]['required']:
|
||||
if prop_name not in client_info:
|
||||
raise InvalidClientSecretsError(
|
||||
client_type = tuple(obj)[0]
|
||||
if client_type not in VALID_CLIENT:
|
||||
raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type, ))
|
||||
client_info = obj[client_type]
|
||||
for prop_name in VALID_CLIENT[client_type]['required']:
|
||||
if prop_name not in client_info:
|
||||
raise InvalidClientSecretsError(
|
||||
'Missing property "%s" in a client type of "%s".' % (prop_name,
|
||||
client_type))
|
||||
for prop_name in VALID_CLIENT[client_type]['string']:
|
||||
if client_info[prop_name].startswith('[['):
|
||||
raise InvalidClientSecretsError(
|
||||
for prop_name in VALID_CLIENT[client_type]['string']:
|
||||
if client_info[prop_name].startswith('[['):
|
||||
raise InvalidClientSecretsError(
|
||||
'Property "%s" is not configured.' % prop_name)
|
||||
return client_type, client_info
|
||||
return client_type, client_info
|
||||
|
||||
|
||||
def load(fp):
|
||||
obj = json.load(fp)
|
||||
return _validate_clientsecrets(obj)
|
||||
obj = json.load(fp)
|
||||
return _validate_clientsecrets(obj)
|
||||
|
||||
|
||||
def loads(s):
|
||||
obj = json.loads(s)
|
||||
return _validate_clientsecrets(obj)
|
||||
obj = json.loads(s)
|
||||
return _validate_clientsecrets(obj)
|
||||
|
||||
|
||||
def _loadfile(filename):
|
||||
try:
|
||||
with open(filename, 'r') as fp:
|
||||
obj = json.load(fp)
|
||||
except IOError:
|
||||
raise InvalidClientSecretsError('File not found: "%s"' % filename)
|
||||
return _validate_clientsecrets(obj)
|
||||
try:
|
||||
with open(filename, 'r') as fp:
|
||||
obj = json.load(fp)
|
||||
except IOError:
|
||||
raise InvalidClientSecretsError('File not found: "%s"' % filename)
|
||||
return _validate_clientsecrets(obj)
|
||||
|
||||
|
||||
def loadfile(filename, cache=None):
|
||||
"""Loading of client_secrets JSON file, optionally backed by a cache.
|
||||
"""Loading of client_secrets JSON file, optionally backed by a cache.
|
||||
|
||||
Typical cache storage would be App Engine memcache service,
|
||||
but you can pass in any other cache client that implements
|
||||
@@ -149,15 +148,15 @@ def loadfile(filename, cache=None):
|
||||
JSON contents is validated only during first load. Cache hits are not
|
||||
validated.
|
||||
"""
|
||||
_SECRET_NAMESPACE = 'oauth2client:secrets#ns'
|
||||
_SECRET_NAMESPACE = 'oauth2client:secrets#ns'
|
||||
|
||||
if not cache:
|
||||
return _loadfile(filename)
|
||||
if not cache:
|
||||
return _loadfile(filename)
|
||||
|
||||
obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
|
||||
if obj is None:
|
||||
client_type, client_info = _loadfile(filename)
|
||||
obj = {client_type: client_info}
|
||||
cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
|
||||
obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
|
||||
if obj is None:
|
||||
client_type, client_info = _loadfile(filename)
|
||||
obj = {client_type: client_info}
|
||||
cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
|
||||
|
||||
return next(six.iteritems(obj))
|
||||
return next(six.iteritems(obj))
|
||||
|
||||
@@ -25,51 +25,51 @@ from oauth2client._helpers import _to_bytes
|
||||
from oauth2client._helpers import _urlsafe_b64decode
|
||||
from oauth2client._helpers import _urlsafe_b64encode
|
||||
|
||||
|
||||
CLOCK_SKEW_SECS = 300 # 5 minutes in seconds
|
||||
AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds
|
||||
MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppIdentityError(Exception):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
from oauth2client._openssl_crypt import OpenSSLVerifier
|
||||
from oauth2client._openssl_crypt import OpenSSLSigner
|
||||
from oauth2client._openssl_crypt import pkcs12_key_as_pem
|
||||
from oauth2client._openssl_crypt import OpenSSLVerifier
|
||||
from oauth2client._openssl_crypt import OpenSSLSigner
|
||||
from oauth2client._openssl_crypt import pkcs12_key_as_pem
|
||||
except ImportError:
|
||||
OpenSSLVerifier = None
|
||||
OpenSSLSigner = None
|
||||
def pkcs12_key_as_pem(*args, **kwargs):
|
||||
raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
|
||||
OpenSSLVerifier = None
|
||||
OpenSSLSigner = None
|
||||
|
||||
|
||||
def pkcs12_key_as_pem(*args, **kwargs):
|
||||
raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
|
||||
|
||||
|
||||
try:
|
||||
from oauth2client._pycrypto_crypt import PyCryptoVerifier
|
||||
from oauth2client._pycrypto_crypt import PyCryptoSigner
|
||||
from oauth2client._pycrypto_crypt import PyCryptoVerifier
|
||||
from oauth2client._pycrypto_crypt import PyCryptoSigner
|
||||
except ImportError:
|
||||
PyCryptoVerifier = None
|
||||
PyCryptoSigner = None
|
||||
PyCryptoVerifier = None
|
||||
PyCryptoSigner = None
|
||||
|
||||
|
||||
if OpenSSLSigner:
|
||||
Signer = OpenSSLSigner
|
||||
Verifier = OpenSSLVerifier
|
||||
Signer = OpenSSLSigner
|
||||
Verifier = OpenSSLVerifier
|
||||
elif PyCryptoSigner:
|
||||
Signer = PyCryptoSigner
|
||||
Verifier = PyCryptoVerifier
|
||||
Signer = PyCryptoSigner
|
||||
Verifier = PyCryptoVerifier
|
||||
else:
|
||||
raise ImportError('No encryption library found. Please install either '
|
||||
raise ImportError('No encryption library found. Please install either '
|
||||
'PyOpenSSL, or PyCrypto 2.6 or later')
|
||||
|
||||
|
||||
def make_signed_jwt(signer, payload):
|
||||
"""Make a signed JWT.
|
||||
"""Make a signed JWT.
|
||||
|
||||
See http://self-issued.info/docs/draft-jones-json-web-token.html.
|
||||
|
||||
@@ -80,24 +80,24 @@ def make_signed_jwt(signer, payload):
|
||||
Returns:
|
||||
string, The JWT for the payload.
|
||||
"""
|
||||
header = {'typ': 'JWT', 'alg': 'RS256'}
|
||||
header = {'typ': 'JWT', 'alg': 'RS256'}
|
||||
|
||||
segments = [
|
||||
segments = [
|
||||
_urlsafe_b64encode(_json_encode(header)),
|
||||
_urlsafe_b64encode(_json_encode(payload)),
|
||||
]
|
||||
signing_input = b'.'.join(segments)
|
||||
]
|
||||
signing_input = b'.'.join(segments)
|
||||
|
||||
signature = signer.sign(signing_input)
|
||||
segments.append(_urlsafe_b64encode(signature))
|
||||
signature = signer.sign(signing_input)
|
||||
segments.append(_urlsafe_b64encode(signature))
|
||||
|
||||
logger.debug(str(segments))
|
||||
logger.debug(str(segments))
|
||||
|
||||
return b'.'.join(segments)
|
||||
return b'.'.join(segments)
|
||||
|
||||
|
||||
def verify_signed_jwt_with_certs(jwt, certs, audience):
|
||||
"""Verify a JWT against public certs.
|
||||
"""Verify a JWT against public certs.
|
||||
|
||||
See http://self-issued.info/docs/draft-jones-json-web-token.html.
|
||||
|
||||
@@ -113,61 +113,61 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
|
||||
Raises:
|
||||
AppIdentityError if any checks are failed.
|
||||
"""
|
||||
jwt = _to_bytes(jwt)
|
||||
segments = jwt.split(b'.')
|
||||
jwt = _to_bytes(jwt)
|
||||
segments = jwt.split(b'.')
|
||||
|
||||
if len(segments) != 3:
|
||||
raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
|
||||
signed = segments[0] + b'.' + segments[1]
|
||||
if len(segments) != 3:
|
||||
raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
|
||||
signed = segments[0] + b'.' + segments[1]
|
||||
|
||||
signature = _urlsafe_b64decode(segments[2])
|
||||
signature = _urlsafe_b64decode(segments[2])
|
||||
|
||||
# Parse token.
|
||||
json_body = _urlsafe_b64decode(segments[1])
|
||||
try:
|
||||
parsed = json.loads(_from_bytes(json_body))
|
||||
except:
|
||||
raise AppIdentityError('Can\'t parse token: %s' % json_body)
|
||||
# Parse token.
|
||||
json_body = _urlsafe_b64decode(segments[1])
|
||||
try:
|
||||
parsed = json.loads(_from_bytes(json_body))
|
||||
except:
|
||||
raise AppIdentityError('Can\'t parse token: %s' % json_body)
|
||||
|
||||
# Check signature.
|
||||
verified = False
|
||||
for pem in certs.values():
|
||||
verifier = Verifier.from_string(pem, True)
|
||||
if verifier.verify(signed, signature):
|
||||
verified = True
|
||||
break
|
||||
if not verified:
|
||||
raise AppIdentityError('Invalid token signature: %s' % jwt)
|
||||
# Check signature.
|
||||
verified = False
|
||||
for pem in certs.values():
|
||||
verifier = Verifier.from_string(pem, True)
|
||||
if verifier.verify(signed, signature):
|
||||
verified = True
|
||||
break
|
||||
if not verified:
|
||||
raise AppIdentityError('Invalid token signature: %s' % jwt)
|
||||
|
||||
# Check creation timestamp.
|
||||
iat = parsed.get('iat')
|
||||
if iat is None:
|
||||
raise AppIdentityError('No iat field in token: %s' % json_body)
|
||||
earliest = iat - CLOCK_SKEW_SECS
|
||||
# Check creation timestamp.
|
||||
iat = parsed.get('iat')
|
||||
if iat is None:
|
||||
raise AppIdentityError('No iat field in token: %s' % json_body)
|
||||
earliest = iat - CLOCK_SKEW_SECS
|
||||
|
||||
# Check expiration timestamp.
|
||||
now = int(time.time())
|
||||
exp = parsed.get('exp')
|
||||
if exp is None:
|
||||
raise AppIdentityError('No exp field in token: %s' % json_body)
|
||||
if exp >= now + MAX_TOKEN_LIFETIME_SECS:
|
||||
raise AppIdentityError('exp field too far in future: %s' % json_body)
|
||||
latest = exp + CLOCK_SKEW_SECS
|
||||
# Check expiration timestamp.
|
||||
now = int(time.time())
|
||||
exp = parsed.get('exp')
|
||||
if exp is None:
|
||||
raise AppIdentityError('No exp field in token: %s' % json_body)
|
||||
if exp >= now + MAX_TOKEN_LIFETIME_SECS:
|
||||
raise AppIdentityError('exp field too far in future: %s' % json_body)
|
||||
latest = exp + CLOCK_SKEW_SECS
|
||||
|
||||
if now < earliest:
|
||||
raise AppIdentityError('Token used too early, %d < %d: %s' %
|
||||
if now < earliest:
|
||||
raise AppIdentityError('Token used too early, %d < %d: %s' %
|
||||
(now, earliest, json_body))
|
||||
if now > latest:
|
||||
raise AppIdentityError('Token used too late, %d > %d: %s' %
|
||||
if now > latest:
|
||||
raise AppIdentityError('Token used too late, %d > %d: %s' %
|
||||
(now, latest, json_body))
|
||||
|
||||
# Check audience.
|
||||
if audience is not None:
|
||||
aud = parsed.get('aud')
|
||||
if aud is None:
|
||||
raise AppIdentityError('No aud field in token: %s' % json_body)
|
||||
if aud != audience:
|
||||
raise AppIdentityError('Wrong recipient, %s != %s: %s' %
|
||||
# Check audience.
|
||||
if audience is not None:
|
||||
aud = parsed.get('aud')
|
||||
if aud is None:
|
||||
raise AppIdentityError('No aud field in token: %s' % json_body)
|
||||
if aud != audience:
|
||||
raise AppIdentityError('Wrong recipient, %s != %s: %s' %
|
||||
(aud, audience, json_body))
|
||||
|
||||
return parsed
|
||||
return parsed
|
||||
|
||||
@@ -20,22 +20,20 @@ import os
|
||||
from oauth2client._helpers import _to_bytes
|
||||
from oauth2client import client
|
||||
|
||||
|
||||
DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT'
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
"""Errors for this module."""
|
||||
pass
|
||||
"""Errors for this module."""
|
||||
pass
|
||||
|
||||
|
||||
class CommunicationError(Error):
|
||||
"""Errors for communication with the Developer Shell server."""
|
||||
"""Errors for communication with the Developer Shell server."""
|
||||
|
||||
|
||||
class NoDevshellServer(Error):
|
||||
"""Error when no Developer Shell server can be contacted."""
|
||||
|
||||
"""Error when no Developer Shell server can be contacted."""
|
||||
|
||||
# The request for credential information to the Developer Shell client socket is
|
||||
# always an empty PBLite-formatted JSON object, so just define it as a constant.
|
||||
@@ -43,7 +41,7 @@ CREDENTIAL_INFO_REQUEST_JSON = '[]'
|
||||
|
||||
|
||||
class CredentialInfoResponse(object):
|
||||
"""Credential information response from Developer Shell server.
|
||||
"""Credential information response from Developer Shell server.
|
||||
|
||||
The credential information response from Developer Shell socket is a
|
||||
PBLite-formatted JSON array with fields encoded by their index in the array:
|
||||
@@ -52,46 +50,46 @@ class CredentialInfoResponse(object):
|
||||
* Index 2 - OAuth2 access token. None if there is no valid auth context.
|
||||
"""
|
||||
|
||||
def __init__(self, json_string):
|
||||
"""Initialize the response data from JSON PBLite array."""
|
||||
pbl = json.loads(json_string)
|
||||
if not isinstance(pbl, list):
|
||||
raise ValueError('Not a list: ' + str(pbl))
|
||||
pbl_len = len(pbl)
|
||||
self.user_email = pbl[0] if pbl_len > 0 else None
|
||||
self.project_id = pbl[1] if pbl_len > 1 else None
|
||||
self.access_token = pbl[2] if pbl_len > 2 else None
|
||||
def __init__(self, json_string):
|
||||
"""Initialize the response data from JSON PBLite array."""
|
||||
pbl = json.loads(json_string)
|
||||
if not isinstance(pbl, list):
|
||||
raise ValueError('Not a list: ' + str(pbl))
|
||||
pbl_len = len(pbl)
|
||||
self.user_email = pbl[0] if pbl_len > 0 else None
|
||||
self.project_id = pbl[1] if pbl_len > 1 else None
|
||||
self.access_token = pbl[2] if pbl_len > 2 else None
|
||||
|
||||
|
||||
def _SendRecv():
|
||||
"""Communicate with the Developer Shell server socket."""
|
||||
"""Communicate with the Developer Shell server socket."""
|
||||
|
||||
port = int(os.getenv(DEVSHELL_ENV, 0))
|
||||
if port == 0:
|
||||
raise NoDevshellServer()
|
||||
port = int(os.getenv(DEVSHELL_ENV, 0))
|
||||
if port == 0:
|
||||
raise NoDevshellServer()
|
||||
|
||||
import socket
|
||||
import socket
|
||||
|
||||
sock = socket.socket()
|
||||
sock.connect(('localhost', port))
|
||||
sock = socket.socket()
|
||||
sock.connect(('localhost', port))
|
||||
|
||||
data = CREDENTIAL_INFO_REQUEST_JSON
|
||||
msg = '%s\n%s' % (len(data), data)
|
||||
sock.sendall(_to_bytes(msg, encoding='utf-8'))
|
||||
data = CREDENTIAL_INFO_REQUEST_JSON
|
||||
msg = '%s\n%s' % (len(data), data)
|
||||
sock.sendall(_to_bytes(msg, encoding='utf-8'))
|
||||
|
||||
header = sock.recv(6).decode()
|
||||
if '\n' not in header:
|
||||
raise CommunicationError('saw no newline in the first 6 bytes')
|
||||
len_str, json_str = header.split('\n', 1)
|
||||
to_read = int(len_str) - len(json_str)
|
||||
if to_read > 0:
|
||||
json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
|
||||
header = sock.recv(6).decode()
|
||||
if '\n' not in header:
|
||||
raise CommunicationError('saw no newline in the first 6 bytes')
|
||||
len_str, json_str = header.split('\n', 1)
|
||||
to_read = int(len_str) - len(json_str)
|
||||
if to_read > 0:
|
||||
json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
|
||||
|
||||
return CredentialInfoResponse(json_str)
|
||||
return CredentialInfoResponse(json_str)
|
||||
|
||||
|
||||
class DevshellCredentials(client.GoogleCredentials):
|
||||
"""Credentials object for Google Developer Shell environment.
|
||||
"""Credentials object for Google Developer Shell environment.
|
||||
|
||||
This object will allow a Google Developer Shell session to identify its user
|
||||
to Google and other OAuth 2.0 servers that can verify assertions. It can be
|
||||
@@ -102,8 +100,8 @@ class DevshellCredentials(client.GoogleCredentials):
|
||||
generate and refresh its own access tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, user_agent=None):
|
||||
super(DevshellCredentials, self).__init__(
|
||||
def __init__(self, user_agent=None):
|
||||
super(DevshellCredentials, self).__init__(
|
||||
None, # access_token, initialized below
|
||||
None, # client_id
|
||||
None, # client_secret
|
||||
@@ -111,26 +109,26 @@ class DevshellCredentials(client.GoogleCredentials):
|
||||
None, # token_expiry
|
||||
None, # token_uri
|
||||
user_agent)
|
||||
self._refresh(None)
|
||||
self._refresh(None)
|
||||
|
||||
def _refresh(self, http_request):
|
||||
self.devshell_response = _SendRecv()
|
||||
self.access_token = self.devshell_response.access_token
|
||||
def _refresh(self, http_request):
|
||||
self.devshell_response = _SendRecv()
|
||||
self.access_token = self.devshell_response.access_token
|
||||
|
||||
@property
|
||||
def user_email(self):
|
||||
return self.devshell_response.user_email
|
||||
@property
|
||||
def user_email(self):
|
||||
return self.devshell_response.user_email
|
||||
|
||||
@property
|
||||
def project_id(self):
|
||||
return self.devshell_response.project_id
|
||||
@property
|
||||
def project_id(self):
|
||||
return self.devshell_response.project_id
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data):
|
||||
raise NotImplementedError(
|
||||
@classmethod
|
||||
def from_json(cls, json_data):
|
||||
raise NotImplementedError(
|
||||
'Cannot load Developer Shell credentials from JSON.')
|
||||
|
||||
@property
|
||||
def serialization_data(self):
|
||||
raise NotImplementedError(
|
||||
@property
|
||||
def serialization_data(self):
|
||||
raise NotImplementedError(
|
||||
'Cannot serialize Developer Shell credentials.')
|
||||
|
||||
@@ -27,58 +27,59 @@ import pickle
|
||||
from django.db import models
|
||||
from oauth2client.client import Storage as BaseStorage
|
||||
|
||||
|
||||
class CredentialsField(models.Field):
|
||||
|
||||
__metaclass__ = models.SubfieldBase
|
||||
__metaclass__ = models.SubfieldBase
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'null' not in kwargs:
|
||||
kwargs['null'] = True
|
||||
super(CredentialsField, self).__init__(*args, **kwargs)
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'null' not in kwargs:
|
||||
kwargs['null'] = True
|
||||
super(CredentialsField, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_internal_type(self):
|
||||
return "TextField"
|
||||
def get_internal_type(self):
|
||||
return "TextField"
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, oauth2client.client.Credentials):
|
||||
return value
|
||||
return pickle.loads(base64.b64decode(value))
|
||||
def to_python(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, oauth2client.client.Credentials):
|
||||
return value
|
||||
return pickle.loads(base64.b64decode(value))
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if value is None:
|
||||
return None
|
||||
return base64.b64encode(pickle.dumps(value))
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if value is None:
|
||||
return None
|
||||
return base64.b64encode(pickle.dumps(value))
|
||||
|
||||
|
||||
class FlowField(models.Field):
|
||||
|
||||
__metaclass__ = models.SubfieldBase
|
||||
__metaclass__ = models.SubfieldBase
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'null' not in kwargs:
|
||||
kwargs['null'] = True
|
||||
super(FlowField, self).__init__(*args, **kwargs)
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'null' not in kwargs:
|
||||
kwargs['null'] = True
|
||||
super(FlowField, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_internal_type(self):
|
||||
return "TextField"
|
||||
def get_internal_type(self):
|
||||
return "TextField"
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, oauth2client.client.Flow):
|
||||
return value
|
||||
return pickle.loads(base64.b64decode(value))
|
||||
def to_python(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, oauth2client.client.Flow):
|
||||
return value
|
||||
return pickle.loads(base64.b64decode(value))
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if value is None:
|
||||
return None
|
||||
return base64.b64encode(pickle.dumps(value))
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if value is None:
|
||||
return None
|
||||
return base64.b64encode(pickle.dumps(value))
|
||||
|
||||
|
||||
class Storage(BaseStorage):
|
||||
"""Store and retrieve a single credential to and from
|
||||
"""Store and retrieve a single credential to and from
|
||||
the datastore.
|
||||
|
||||
This Storage helper presumes the Credentials
|
||||
@@ -86,8 +87,8 @@ class Storage(BaseStorage):
|
||||
on a db model class.
|
||||
"""
|
||||
|
||||
def __init__(self, model_class, key_name, key_value, property_name):
|
||||
"""Constructor for Storage.
|
||||
def __init__(self, model_class, key_name, key_value, property_name):
|
||||
"""Constructor for Storage.
|
||||
|
||||
Args:
|
||||
model: db.Model, model class
|
||||
@@ -95,47 +96,47 @@ class Storage(BaseStorage):
|
||||
key_value: string, key value for the entity that has the credentials
|
||||
property_name: string, name of the property that is an CredentialsProperty
|
||||
"""
|
||||
self.model_class = model_class
|
||||
self.key_name = key_name
|
||||
self.key_value = key_value
|
||||
self.property_name = property_name
|
||||
self.model_class = model_class
|
||||
self.key_name = key_name
|
||||
self.key_value = key_value
|
||||
self.property_name = property_name
|
||||
|
||||
def locked_get(self):
|
||||
"""Retrieve Credential from datastore.
|
||||
def locked_get(self):
|
||||
"""Retrieve Credential from datastore.
|
||||
|
||||
Returns:
|
||||
oauth2client.Credentials
|
||||
"""
|
||||
credential = None
|
||||
credential = None
|
||||
|
||||
query = {self.key_name: self.key_value}
|
||||
entities = self.model_class.objects.filter(**query)
|
||||
if len(entities) > 0:
|
||||
credential = getattr(entities[0], self.property_name)
|
||||
if credential and hasattr(credential, 'set_store'):
|
||||
credential.set_store(self)
|
||||
return credential
|
||||
query = {self.key_name: self.key_value}
|
||||
entities = self.model_class.objects.filter(**query)
|
||||
if len(entities) > 0:
|
||||
credential = getattr(entities[0], self.property_name)
|
||||
if credential and hasattr(credential, 'set_store'):
|
||||
credential.set_store(self)
|
||||
return credential
|
||||
|
||||
def locked_put(self, credentials, overwrite=False):
|
||||
"""Write a Credentials to the datastore.
|
||||
def locked_put(self, credentials, overwrite=False):
|
||||
"""Write a Credentials to the datastore.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
overwrite: Boolean, indicates whether you would like these credentials to
|
||||
overwrite any existing stored credentials.
|
||||
"""
|
||||
args = {self.key_name: self.key_value}
|
||||
args = {self.key_name: self.key_value}
|
||||
|
||||
if overwrite:
|
||||
entity, unused_is_new = self.model_class.objects.get_or_create(**args)
|
||||
else:
|
||||
entity = self.model_class(**args)
|
||||
if overwrite:
|
||||
entity, unused_is_new = self.model_class.objects.get_or_create(**args)
|
||||
else:
|
||||
entity = self.model_class(**args)
|
||||
|
||||
setattr(entity, self.property_name, credentials)
|
||||
entity.save()
|
||||
setattr(entity, self.property_name, credentials)
|
||||
entity.save()
|
||||
|
||||
def locked_delete(self):
|
||||
"""Delete Credentials from the datastore."""
|
||||
def locked_delete(self):
|
||||
"""Delete Credentials from the datastore."""
|
||||
|
||||
query = {self.key_name: self.key_value}
|
||||
entities = self.model_class.objects.filter(**query).delete()
|
||||
query = {self.key_name: self.key_value}
|
||||
entities = self.model_class.objects.filter(**query).delete()
|
||||
|
||||
@@ -28,37 +28,37 @@ from oauth2client.client import Storage as BaseStorage
|
||||
|
||||
|
||||
class CredentialsFileSymbolicLinkError(Exception):
|
||||
"""Credentials files must not be symbolic links."""
|
||||
"""Credentials files must not be symbolic links."""
|
||||
|
||||
|
||||
class Storage(BaseStorage):
|
||||
"""Store and retrieve a single credential to and from a file."""
|
||||
"""Store and retrieve a single credential to and from a file."""
|
||||
|
||||
def __init__(self, filename):
|
||||
self._filename = filename
|
||||
self._lock = threading.Lock()
|
||||
def __init__(self, filename):
|
||||
self._filename = filename
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _validate_file(self):
|
||||
if os.path.islink(self._filename):
|
||||
raise CredentialsFileSymbolicLinkError(
|
||||
def _validate_file(self):
|
||||
if os.path.islink(self._filename):
|
||||
raise CredentialsFileSymbolicLinkError(
|
||||
'File: %s is a symbolic link.' % self._filename)
|
||||
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
|
||||
This lock is not reentrant."""
|
||||
self._lock.acquire()
|
||||
self._lock.acquire()
|
||||
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
|
||||
Trying to release a lock that isn't held will result in a
|
||||
RuntimeError.
|
||||
"""
|
||||
self._lock.release()
|
||||
self._lock.release()
|
||||
|
||||
def locked_get(self):
|
||||
"""Retrieve Credential from file.
|
||||
def locked_get(self):
|
||||
"""Retrieve Credential from file.
|
||||
|
||||
Returns:
|
||||
oauth2client.client.Credentials
|
||||
@@ -66,38 +66,38 @@ class Storage(BaseStorage):
|
||||
Raises:
|
||||
CredentialsFileSymbolicLinkError if the file is a symbolic link.
|
||||
"""
|
||||
credentials = None
|
||||
self._validate_file()
|
||||
try:
|
||||
f = open(self._filename, 'rb')
|
||||
content = f.read()
|
||||
f.close()
|
||||
except IOError:
|
||||
return credentials
|
||||
credentials = None
|
||||
self._validate_file()
|
||||
try:
|
||||
f = open(self._filename, 'rb')
|
||||
content = f.read()
|
||||
f.close()
|
||||
except IOError:
|
||||
return credentials
|
||||
|
||||
try:
|
||||
credentials = Credentials.new_from_json(content)
|
||||
credentials.set_store(self)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
credentials = Credentials.new_from_json(content)
|
||||
credentials.set_store(self)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return credentials
|
||||
return credentials
|
||||
|
||||
def _create_file_if_needed(self):
|
||||
"""Create an empty file if necessary.
|
||||
def _create_file_if_needed(self):
|
||||
"""Create an empty file if necessary.
|
||||
|
||||
This method will not initialize the file. Instead it implements a
|
||||
simple version of "touch" to ensure the file has been created.
|
||||
"""
|
||||
if not os.path.exists(self._filename):
|
||||
old_umask = os.umask(0o177)
|
||||
try:
|
||||
open(self._filename, 'a+b').close()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
if not os.path.exists(self._filename):
|
||||
old_umask = os.umask(0o177)
|
||||
try:
|
||||
open(self._filename, 'a+b').close()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
def locked_put(self, credentials):
|
||||
"""Write Credentials to file.
|
||||
def locked_put(self, credentials):
|
||||
"""Write Credentials to file.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
@@ -106,17 +106,17 @@ class Storage(BaseStorage):
|
||||
CredentialsFileSymbolicLinkError if the file is a symbolic link.
|
||||
"""
|
||||
|
||||
self._create_file_if_needed()
|
||||
self._validate_file()
|
||||
f = open(self._filename, 'w')
|
||||
f.write(credentials.to_json())
|
||||
f.close()
|
||||
self._create_file_if_needed()
|
||||
self._validate_file()
|
||||
f = open(self._filename, 'w')
|
||||
f.write(credentials.to_json())
|
||||
f.close()
|
||||
|
||||
def locked_delete(self):
|
||||
"""Delete Credentials file.
|
||||
def locked_delete(self):
|
||||
"""Delete Credentials file.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
|
||||
os.unlink(self._filename)
|
||||
os.unlink(self._filename)
|
||||
|
||||
@@ -189,8 +189,7 @@ from oauth2client.client import Storage
|
||||
from oauth2client import clientsecrets
|
||||
from oauth2client import util
|
||||
|
||||
|
||||
DEFAULT_SCOPES = ('email',)
|
||||
DEFAULT_SCOPES = ('email', )
|
||||
|
||||
|
||||
class UserOAuth2(object):
|
||||
@@ -461,6 +460,7 @@ class UserOAuth2(object):
|
||||
be redirected to the authorization flow. Once complete, the user will
|
||||
be redirected back to the original page.
|
||||
"""
|
||||
|
||||
def curry_wrapper(wrapped_function):
|
||||
@wraps(wrapped_function)
|
||||
def required_wrapper(*args, **kwargs):
|
||||
@@ -519,6 +519,7 @@ class FlaskSessionStorage(Storage):
|
||||
credentials. We strongly recommend using a server-side session
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def locked_get(self):
|
||||
serialized = session.get('google_oauth2_credentials')
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ META = ('http://metadata.google.internal/0.1/meta-data/service-accounts/'
|
||||
|
||||
|
||||
class AppAssertionCredentials(AssertionCredentials):
|
||||
"""Credentials object for Compute Engine Assertion Grants
|
||||
"""Credentials object for Compute Engine Assertion Grants
|
||||
|
||||
This object will allow a Compute Engine instance to identify itself to
|
||||
Google and other OAuth 2.0 servers that can verify assertions. It can be used
|
||||
@@ -48,27 +48,27 @@ class AppAssertionCredentials(AssertionCredentials):
|
||||
generate and refresh its own access tokens.
|
||||
"""
|
||||
|
||||
@util.positional(2)
|
||||
def __init__(self, scope, **kwargs):
|
||||
"""Constructor for AppAssertionCredentials
|
||||
@util.positional(2)
|
||||
def __init__(self, scope, **kwargs):
|
||||
"""Constructor for AppAssertionCredentials
|
||||
|
||||
Args:
|
||||
scope: string or iterable of strings, scope(s) of the credentials being
|
||||
requested.
|
||||
"""
|
||||
self.scope = util.scopes_to_string(scope)
|
||||
self.kwargs = kwargs
|
||||
self.scope = util.scopes_to_string(scope)
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Assertion type is no longer used, but still in the parent class signature.
|
||||
super(AppAssertionCredentials, self).__init__(None)
|
||||
# Assertion type is no longer used, but still in the parent class signature.
|
||||
super(AppAssertionCredentials, self).__init__(None)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_data):
|
||||
data = json.loads(_from_bytes(json_data))
|
||||
return AppAssertionCredentials(data['scope'])
|
||||
@classmethod
|
||||
def from_json(cls, json_data):
|
||||
data = json.loads(_from_bytes(json_data))
|
||||
return AppAssertionCredentials(data['scope'])
|
||||
|
||||
def _refresh(self, http_request):
|
||||
"""Refreshes the access_token.
|
||||
def _refresh(self, http_request):
|
||||
"""Refreshes the access_token.
|
||||
|
||||
Skip all the storage hoops and just refresh using the API.
|
||||
|
||||
@@ -79,29 +79,29 @@ class AppAssertionCredentials(AssertionCredentials):
|
||||
Raises:
|
||||
AccessTokenRefreshError: When the refresh fails.
|
||||
"""
|
||||
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
|
||||
uri = META.replace('{?scope}', query)
|
||||
response, content = http_request(uri)
|
||||
content = _from_bytes(content)
|
||||
if response.status == 200:
|
||||
try:
|
||||
d = json.loads(content)
|
||||
except Exception as e:
|
||||
raise AccessTokenRefreshError(str(e))
|
||||
self.access_token = d['accessToken']
|
||||
else:
|
||||
if response.status == 404:
|
||||
content += (' This can occur if a VM was created'
|
||||
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
|
||||
uri = META.replace('{?scope}', query)
|
||||
response, content = http_request(uri)
|
||||
content = _from_bytes(content)
|
||||
if response.status == 200:
|
||||
try:
|
||||
d = json.loads(content)
|
||||
except Exception as e:
|
||||
raise AccessTokenRefreshError(str(e))
|
||||
self.access_token = d['accessToken']
|
||||
else:
|
||||
if response.status == 404:
|
||||
content += (' This can occur if a VM was created'
|
||||
' with no service account or scopes.')
|
||||
raise AccessTokenRefreshError(content)
|
||||
raise AccessTokenRefreshError(content)
|
||||
|
||||
@property
|
||||
@property
|
||||
def serialization_data(self):
|
||||
raise NotImplementedError(
|
||||
raise NotImplementedError(
|
||||
'Cannot serialize credentials for GCE service accounts.')
|
||||
|
||||
def create_scoped_required(self):
|
||||
return not self.scope
|
||||
def create_scoped_required(self):
|
||||
return not self.scope
|
||||
|
||||
def create_scoped(self, scopes):
|
||||
return AppAssertionCredentials(scopes, **self.kwargs)
|
||||
def create_scoped(self, scopes):
|
||||
return AppAssertionCredentials(scopes, **self.kwargs)
|
||||
|
||||
@@ -28,7 +28,7 @@ from oauth2client.client import Storage as BaseStorage
|
||||
|
||||
|
||||
class Storage(BaseStorage):
|
||||
"""Store and retrieve a single credential to and from the keyring.
|
||||
"""Store and retrieve a single credential to and from the keyring.
|
||||
|
||||
To use this module you must have the keyring module installed. See
|
||||
<http://pypi.python.org/pypi/keyring/>. This is an optional module and is not
|
||||
@@ -48,63 +48,63 @@ class Storage(BaseStorage):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, service_name, user_name):
|
||||
"""Constructor.
|
||||
def __init__(self, service_name, user_name):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
service_name: string, The name of the service under which the credentials
|
||||
are stored.
|
||||
user_name: string, The name of the user to store credentials for.
|
||||
"""
|
||||
self._service_name = service_name
|
||||
self._user_name = user_name
|
||||
self._lock = threading.Lock()
|
||||
self._service_name = service_name
|
||||
self._user_name = user_name
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
|
||||
This lock is not reentrant."""
|
||||
self._lock.acquire()
|
||||
self._lock.acquire()
|
||||
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
|
||||
Trying to release a lock that isn't held will result in a
|
||||
RuntimeError.
|
||||
"""
|
||||
self._lock.release()
|
||||
self._lock.release()
|
||||
|
||||
def locked_get(self):
|
||||
"""Retrieve Credential from file.
|
||||
def locked_get(self):
|
||||
"""Retrieve Credential from file.
|
||||
|
||||
Returns:
|
||||
oauth2client.client.Credentials
|
||||
"""
|
||||
credentials = None
|
||||
content = keyring.get_password(self._service_name, self._user_name)
|
||||
credentials = None
|
||||
content = keyring.get_password(self._service_name, self._user_name)
|
||||
|
||||
if content is not None:
|
||||
try:
|
||||
credentials = Credentials.new_from_json(content)
|
||||
credentials.set_store(self)
|
||||
except ValueError:
|
||||
pass
|
||||
if content is not None:
|
||||
try:
|
||||
credentials = Credentials.new_from_json(content)
|
||||
credentials.set_store(self)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return credentials
|
||||
return credentials
|
||||
|
||||
def locked_put(self, credentials):
|
||||
"""Write Credentials to file.
|
||||
def locked_put(self, credentials):
|
||||
"""Write Credentials to file.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
keyring.set_password(self._service_name, self._user_name,
|
||||
keyring.set_password(self._service_name, self._user_name,
|
||||
credentials.to_json())
|
||||
|
||||
def locked_delete(self):
|
||||
"""Delete Credentials file.
|
||||
def locked_delete(self):
|
||||
"""Delete Credentials file.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
keyring.set_password(self._service_name, self._user_name, '')
|
||||
keyring.set_password(self._service_name, self._user_name, '')
|
||||
|
||||
@@ -45,68 +45,69 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialsFileSymbolicLinkError(Exception):
|
||||
"""Credentials files must not be symbolic links."""
|
||||
"""Credentials files must not be symbolic links."""
|
||||
|
||||
|
||||
class AlreadyLockedException(Exception):
|
||||
"""Trying to lock a file that has already been locked by the LockedFile."""
|
||||
pass
|
||||
"""Trying to lock a file that has already been locked by the LockedFile."""
|
||||
pass
|
||||
|
||||
|
||||
def validate_file(filename):
|
||||
if os.path.islink(filename):
|
||||
raise CredentialsFileSymbolicLinkError(
|
||||
if os.path.islink(filename):
|
||||
raise CredentialsFileSymbolicLinkError(
|
||||
'File: %s is a symbolic link.' % filename)
|
||||
|
||||
class _Opener(object):
|
||||
"""Base class for different locking primitives."""
|
||||
|
||||
def __init__(self, filename, mode, fallback_mode):
|
||||
"""Create an Opener.
|
||||
class _Opener(object):
|
||||
"""Base class for different locking primitives."""
|
||||
|
||||
def __init__(self, filename, mode, fallback_mode):
|
||||
"""Create an Opener.
|
||||
|
||||
Args:
|
||||
filename: string, The pathname of the file.
|
||||
mode: string, The preferred mode to access the file with.
|
||||
fallback_mode: string, The mode to use if locking fails.
|
||||
"""
|
||||
self._locked = False
|
||||
self._filename = filename
|
||||
self._mode = mode
|
||||
self._fallback_mode = fallback_mode
|
||||
self._fh = None
|
||||
self._lock_fd = None
|
||||
self._locked = False
|
||||
self._filename = filename
|
||||
self._mode = mode
|
||||
self._fallback_mode = fallback_mode
|
||||
self._fh = None
|
||||
self._lock_fd = None
|
||||
|
||||
def is_locked(self):
|
||||
"""Was the file locked."""
|
||||
return self._locked
|
||||
def is_locked(self):
|
||||
"""Was the file locked."""
|
||||
return self._locked
|
||||
|
||||
def file_handle(self):
|
||||
"""The file handle to the file. Valid only after opened."""
|
||||
return self._fh
|
||||
def file_handle(self):
|
||||
"""The file handle to the file. Valid only after opened."""
|
||||
return self._fh
|
||||
|
||||
def filename(self):
|
||||
"""The filename that is being locked."""
|
||||
return self._filename
|
||||
def filename(self):
|
||||
"""The filename that is being locked."""
|
||||
return self._filename
|
||||
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
|
||||
Args:
|
||||
timeout: float, How long to try to lock for.
|
||||
delay: float, How long to wait between retries.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
def unlock_and_close(self):
|
||||
"""Unlock and close the file."""
|
||||
pass
|
||||
def unlock_and_close(self):
|
||||
"""Unlock and close the file."""
|
||||
pass
|
||||
|
||||
|
||||
class _PosixOpener(_Opener):
|
||||
"""Lock files using Posix advisory lock files."""
|
||||
"""Lock files using Posix advisory lock files."""
|
||||
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
|
||||
Tries to create a .lock file next to the file we're trying to open.
|
||||
|
||||
@@ -119,141 +120,67 @@ class _PosixOpener(_Opener):
|
||||
IOError: if the open fails.
|
||||
CredentialsFileSymbolicLinkError if the file is a symbolic link.
|
||||
"""
|
||||
if self._locked:
|
||||
raise AlreadyLockedException('File %s is already locked' %
|
||||
if self._locked:
|
||||
raise AlreadyLockedException('File %s is already locked' %
|
||||
self._filename)
|
||||
self._locked = False
|
||||
self._locked = False
|
||||
|
||||
validate_file(self._filename)
|
||||
try:
|
||||
self._fh = open(self._filename, self._mode)
|
||||
except IOError as e:
|
||||
# If we can't access with _mode, try _fallback_mode and don't lock.
|
||||
if e.errno == errno.EACCES:
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
validate_file(self._filename)
|
||||
try:
|
||||
self._fh = open(self._filename, self._mode)
|
||||
except IOError as e:
|
||||
# If we can't access with _mode, try _fallback_mode and don't lock.
|
||||
if e.errno == errno.EACCES:
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
|
||||
lock_filename = self._posix_lockfile(self._filename)
|
||||
lock_filename = self._posix_lockfile(self._filename)
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
self._lock_fd = os.open(lock_filename,
|
||||
os.O_CREAT|os.O_EXCL|os.O_RDWR)
|
||||
self._locked = True
|
||||
break
|
||||
try:
|
||||
self._lock_fd = os.open(lock_filename,
|
||||
os.O_CREAT | os.O_EXCL | os.O_RDWR)
|
||||
self._locked = True
|
||||
break
|
||||
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
if (time.time() - start_time) >= timeout:
|
||||
logger.warn('Could not acquire lock %s in %s seconds',
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
if (time.time() - start_time) >= timeout:
|
||||
logger.warn('Could not acquire lock %s in %s seconds',
|
||||
lock_filename, timeout)
|
||||
# Close the file and open in fallback_mode.
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
time.sleep(delay)
|
||||
|
||||
def unlock_and_close(self):
|
||||
"""Unlock a file by removing the .lock file, and close the handle."""
|
||||
if self._locked:
|
||||
lock_filename = self._posix_lockfile(self._filename)
|
||||
os.close(self._lock_fd)
|
||||
os.unlink(lock_filename)
|
||||
self._locked = False
|
||||
self._lock_fd = None
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
|
||||
def _posix_lockfile(self, filename):
|
||||
"""The name of the lock file to use for posix locking."""
|
||||
return '%s.lock' % filename
|
||||
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
|
||||
class _FcntlOpener(_Opener):
|
||||
"""Open, lock, and unlock a file using fcntl.lockf."""
|
||||
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
|
||||
Args:
|
||||
timeout: float, How long to try to lock for.
|
||||
delay: float, How long to wait between retries
|
||||
|
||||
Raises:
|
||||
AlreadyLockedException: if the lock is already acquired.
|
||||
IOError: if the open fails.
|
||||
CredentialsFileSymbolicLinkError if the file is a symbolic link.
|
||||
"""
|
||||
if self._locked:
|
||||
raise AlreadyLockedException('File %s is already locked' %
|
||||
self._filename)
|
||||
start_time = time.time()
|
||||
|
||||
validate_file(self._filename)
|
||||
try:
|
||||
self._fh = open(self._filename, self._mode)
|
||||
except IOError as e:
|
||||
# If we can't access with _mode, try _fallback_mode and don't lock.
|
||||
if e.errno in (errno.EPERM, errno.EACCES):
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
|
||||
# We opened in _mode, try to lock the file.
|
||||
while True:
|
||||
try:
|
||||
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
|
||||
self._locked = True
|
||||
return
|
||||
except IOError as e:
|
||||
# If not retrying, then just pass on the error.
|
||||
if timeout == 0:
|
||||
raise
|
||||
if e.errno != errno.EACCES:
|
||||
raise
|
||||
# We could not acquire the lock. Try again.
|
||||
if (time.time() - start_time) >= timeout:
|
||||
logger.warn('Could not lock %s in %s seconds',
|
||||
self._filename, timeout)
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
time.sleep(delay)
|
||||
# Close the file and open in fallback_mode.
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
time.sleep(delay)
|
||||
|
||||
def unlock_and_close(self):
|
||||
"""Close and unlock the file using the fcntl.lockf primitive."""
|
||||
if self._locked:
|
||||
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
|
||||
self._locked = False
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
except ImportError:
|
||||
_FcntlOpener = None
|
||||
"""Unlock a file by removing the .lock file, and close the handle."""
|
||||
if self._locked:
|
||||
lock_filename = self._posix_lockfile(self._filename)
|
||||
os.close(self._lock_fd)
|
||||
os.unlink(lock_filename)
|
||||
self._locked = False
|
||||
self._lock_fd = None
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
|
||||
def _posix_lockfile(self, filename):
|
||||
"""The name of the lock file to use for posix locking."""
|
||||
return '%s.lock' % filename
|
||||
|
||||
|
||||
try:
|
||||
import pywintypes
|
||||
import win32con
|
||||
import win32file
|
||||
import fcntl
|
||||
|
||||
class _Win32Opener(_Opener):
|
||||
"""Open, lock, and unlock a file using windows primitives."""
|
||||
|
||||
# Error #33:
|
||||
# 'The process cannot access the file because another process'
|
||||
FILE_IN_USE_ERROR = 33
|
||||
class _FcntlOpener(_Opener):
|
||||
"""Open, lock, and unlock a file using fcntl.lockf."""
|
||||
|
||||
# Error #158:
|
||||
# 'The segment is already unlocked.'
|
||||
FILE_ALREADY_UNLOCKED_ERROR = 158
|
||||
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
|
||||
Args:
|
||||
timeout: float, How long to try to lock for.
|
||||
@@ -264,71 +191,147 @@ try:
|
||||
IOError: if the open fails.
|
||||
CredentialsFileSymbolicLinkError if the file is a symbolic link.
|
||||
"""
|
||||
if self._locked:
|
||||
raise AlreadyLockedException('File %s is already locked' %
|
||||
if self._locked:
|
||||
raise AlreadyLockedException('File %s is already locked' %
|
||||
self._filename)
|
||||
start_time = time.time()
|
||||
start_time = time.time()
|
||||
|
||||
validate_file(self._filename)
|
||||
try:
|
||||
self._fh = open(self._filename, self._mode)
|
||||
except IOError as e:
|
||||
# If we can't access with _mode, try _fallback_mode and don't lock.
|
||||
if e.errno == errno.EACCES:
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
validate_file(self._filename)
|
||||
try:
|
||||
self._fh = open(self._filename, self._mode)
|
||||
except IOError as e:
|
||||
# If we can't access with _mode, try _fallback_mode and don't lock.
|
||||
if e.errno in (errno.EPERM, errno.EACCES):
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
|
||||
# We opened in _mode, try to lock the file.
|
||||
while True:
|
||||
try:
|
||||
hfile = win32file._get_osfhandle(self._fh.fileno())
|
||||
win32file.LockFileEx(
|
||||
# We opened in _mode, try to lock the file.
|
||||
while True:
|
||||
try:
|
||||
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
|
||||
self._locked = True
|
||||
return
|
||||
except IOError as e:
|
||||
# If not retrying, then just pass on the error.
|
||||
if timeout == 0:
|
||||
raise
|
||||
if e.errno != errno.EACCES:
|
||||
raise
|
||||
# We could not acquire the lock. Try again.
|
||||
if (time.time() - start_time) >= timeout:
|
||||
logger.warn('Could not lock %s in %s seconds',
|
||||
self._filename, timeout)
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
time.sleep(delay)
|
||||
|
||||
def unlock_and_close(self):
|
||||
"""Close and unlock the file using the fcntl.lockf primitive."""
|
||||
if self._locked:
|
||||
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
|
||||
self._locked = False
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
except ImportError:
|
||||
_FcntlOpener = None
|
||||
|
||||
|
||||
try:
|
||||
import pywintypes
|
||||
import win32con
|
||||
import win32file
|
||||
|
||||
|
||||
class _Win32Opener(_Opener):
|
||||
"""Open, lock, and unlock a file using windows primitives."""
|
||||
|
||||
# Error #33:
|
||||
# 'The process cannot access the file because another process'
|
||||
FILE_IN_USE_ERROR = 33
|
||||
|
||||
# Error #158:
|
||||
# 'The segment is already unlocked.'
|
||||
FILE_ALREADY_UNLOCKED_ERROR = 158
|
||||
|
||||
def open_and_lock(self, timeout, delay):
|
||||
"""Open the file and lock it.
|
||||
|
||||
Args:
|
||||
timeout: float, How long to try to lock for.
|
||||
delay: float, How long to wait between retries
|
||||
|
||||
Raises:
|
||||
AlreadyLockedException: if the lock is already acquired.
|
||||
IOError: if the open fails.
|
||||
CredentialsFileSymbolicLinkError if the file is a symbolic link.
|
||||
"""
|
||||
if self._locked:
|
||||
raise AlreadyLockedException('File %s is already locked' %
|
||||
self._filename)
|
||||
start_time = time.time()
|
||||
|
||||
validate_file(self._filename)
|
||||
try:
|
||||
self._fh = open(self._filename, self._mode)
|
||||
except IOError as e:
|
||||
# If we can't access with _mode, try _fallback_mode and don't lock.
|
||||
if e.errno == errno.EACCES:
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
|
||||
# We opened in _mode, try to lock the file.
|
||||
while True:
|
||||
try:
|
||||
hfile = win32file._get_osfhandle(self._fh.fileno())
|
||||
win32file.LockFileEx(
|
||||
hfile,
|
||||
(win32con.LOCKFILE_FAIL_IMMEDIATELY|
|
||||
(win32con.LOCKFILE_FAIL_IMMEDIATELY |
|
||||
win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000,
|
||||
pywintypes.OVERLAPPED())
|
||||
self._locked = True
|
||||
return
|
||||
except pywintypes.error as e:
|
||||
if timeout == 0:
|
||||
raise
|
||||
self._locked = True
|
||||
return
|
||||
except pywintypes.error as e:
|
||||
if timeout == 0:
|
||||
raise
|
||||
|
||||
# If the error is not that the file is already in use, raise.
|
||||
if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
|
||||
raise
|
||||
# If the error is not that the file is already in use, raise.
|
||||
if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
|
||||
raise
|
||||
|
||||
# We could not acquire the lock. Try again.
|
||||
if (time.time() - start_time) >= timeout:
|
||||
logger.warn('Could not lock %s in %s seconds' % (
|
||||
# We could not acquire the lock. Try again.
|
||||
if (time.time() - start_time) >= timeout:
|
||||
logger.warn('Could not lock %s in %s seconds' % (
|
||||
self._filename, timeout))
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
time.sleep(delay)
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh = open(self._filename, self._fallback_mode)
|
||||
return
|
||||
time.sleep(delay)
|
||||
|
||||
def unlock_and_close(self):
|
||||
"""Close and unlock the file using the win32 primitive."""
|
||||
if self._locked:
|
||||
try:
|
||||
hfile = win32file._get_osfhandle(self._fh.fileno())
|
||||
win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
|
||||
except pywintypes.error as e:
|
||||
if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
|
||||
raise
|
||||
self._locked = False
|
||||
def unlock_and_close(self):
|
||||
"""Close and unlock the file using the win32 primitive."""
|
||||
if self._locked:
|
||||
try:
|
||||
hfile = win32file._get_osfhandle(self._fh.fileno())
|
||||
win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
|
||||
except pywintypes.error as e:
|
||||
if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
|
||||
raise
|
||||
self._locked = False
|
||||
if self._fh:
|
||||
self._fh.close()
|
||||
self._fh.close()
|
||||
except ImportError:
|
||||
_Win32Opener = None
|
||||
_Win32Opener = None
|
||||
|
||||
|
||||
class LockedFile(object):
|
||||
"""Represent a file that has exclusive access."""
|
||||
"""Represent a file that has exclusive access."""
|
||||
|
||||
@util.positional(4)
|
||||
def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
|
||||
"""Construct a LockedFile.
|
||||
@util.positional(4)
|
||||
def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
|
||||
"""Construct a LockedFile.
|
||||
|
||||
Args:
|
||||
filename: string, The path of the file to open.
|
||||
@@ -336,32 +339,32 @@ class LockedFile(object):
|
||||
fallback_mode: string, The mode to use if locking fails.
|
||||
use_native_locking: bool, Whether or not fcntl/win32 locking is used.
|
||||
"""
|
||||
opener = None
|
||||
if not opener and use_native_locking:
|
||||
if _Win32Opener:
|
||||
opener = _Win32Opener(filename, mode, fallback_mode)
|
||||
if _FcntlOpener:
|
||||
opener = _FcntlOpener(filename, mode, fallback_mode)
|
||||
opener = None
|
||||
if not opener and use_native_locking:
|
||||
if _Win32Opener:
|
||||
opener = _Win32Opener(filename, mode, fallback_mode)
|
||||
if _FcntlOpener:
|
||||
opener = _FcntlOpener(filename, mode, fallback_mode)
|
||||
|
||||
if not opener:
|
||||
opener = _PosixOpener(filename, mode, fallback_mode)
|
||||
if not opener:
|
||||
opener = _PosixOpener(filename, mode, fallback_mode)
|
||||
|
||||
self._opener = opener
|
||||
self._opener = opener
|
||||
|
||||
def filename(self):
|
||||
"""Return the filename we were constructed with."""
|
||||
return self._opener._filename
|
||||
def filename(self):
|
||||
"""Return the filename we were constructed with."""
|
||||
return self._opener._filename
|
||||
|
||||
def file_handle(self):
|
||||
"""Return the file_handle to the opened file."""
|
||||
return self._opener.file_handle()
|
||||
def file_handle(self):
|
||||
"""Return the file_handle to the opened file."""
|
||||
return self._opener.file_handle()
|
||||
|
||||
def is_locked(self):
|
||||
"""Return whether we successfully locked the file."""
|
||||
return self._opener.is_locked()
|
||||
def is_locked(self):
|
||||
"""Return whether we successfully locked the file."""
|
||||
return self._opener.is_locked()
|
||||
|
||||
def open_and_lock(self, timeout=0, delay=0.05):
|
||||
"""Open the file, trying to lock it.
|
||||
def open_and_lock(self, timeout=0, delay=0.05):
|
||||
"""Open the file, trying to lock it.
|
||||
|
||||
Args:
|
||||
timeout: float, The number of seconds to try to acquire the lock.
|
||||
@@ -371,8 +374,8 @@ class LockedFile(object):
|
||||
AlreadyLockedException: if the lock is already acquired.
|
||||
IOError: if the open fails.
|
||||
"""
|
||||
self._opener.open_and_lock(timeout, delay)
|
||||
self._opener.open_and_lock(timeout, delay)
|
||||
|
||||
def unlock_and_close(self):
|
||||
"""Unlock and close a file."""
|
||||
self._opener.unlock_and_close()
|
||||
def unlock_and_close(self):
|
||||
"""Unlock and close a file."""
|
||||
self._opener.unlock_and_close()
|
||||
|
||||
@@ -65,17 +65,17 @@ _multistores_lock = threading.Lock()
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
"""Base error for this module."""
|
||||
"""Base error for this module."""
|
||||
|
||||
|
||||
class NewerCredentialStoreError(Error):
|
||||
"""The credential store is a newer version than supported."""
|
||||
"""The credential store is a newer version than supported."""
|
||||
|
||||
|
||||
@util.positional(4)
|
||||
def get_credential_storage(filename, client_id, user_agent, scope,
|
||||
warn_on_readonly=True):
|
||||
"""Get a Storage instance for a credential.
|
||||
"""Get a Storage instance for a credential.
|
||||
|
||||
Args:
|
||||
filename: The JSON file storing a set of credentials
|
||||
@@ -88,17 +88,17 @@ def get_credential_storage(filename, client_id, user_agent, scope,
|
||||
An object derived from client.Storage for getting/setting the
|
||||
credential.
|
||||
"""
|
||||
# Recreate the legacy key with these specific parameters
|
||||
key = {'clientId': client_id, 'userAgent': user_agent,
|
||||
# Recreate the legacy key with these specific parameters
|
||||
key = {'clientId': client_id, 'userAgent': user_agent,
|
||||
'scope': util.scopes_to_string(scope)}
|
||||
return get_credential_storage_custom_key(
|
||||
return get_credential_storage_custom_key(
|
||||
filename, key, warn_on_readonly=warn_on_readonly)
|
||||
|
||||
|
||||
@util.positional(2)
|
||||
def get_credential_storage_custom_string_key(
|
||||
filename, key_string, warn_on_readonly=True):
|
||||
"""Get a Storage instance for a credential using a single string as a key.
|
||||
"""Get a Storage instance for a credential using a single string as a key.
|
||||
|
||||
Allows you to provide a string as a custom key that will be used for
|
||||
credential storage and retrieval.
|
||||
@@ -112,16 +112,16 @@ def get_credential_storage_custom_string_key(
|
||||
An object derived from client.Storage for getting/setting the
|
||||
credential.
|
||||
"""
|
||||
# Create a key dictionary that can be used
|
||||
key_dict = {'key': key_string}
|
||||
return get_credential_storage_custom_key(
|
||||
# Create a key dictionary that can be used
|
||||
key_dict = {'key': key_string}
|
||||
return get_credential_storage_custom_key(
|
||||
filename, key_dict, warn_on_readonly=warn_on_readonly)
|
||||
|
||||
|
||||
@util.positional(2)
|
||||
def get_credential_storage_custom_key(
|
||||
filename, key_dict, warn_on_readonly=True):
|
||||
"""Get a Storage instance for a credential using a dictionary as a key.
|
||||
"""Get a Storage instance for a credential using a dictionary as a key.
|
||||
|
||||
Allows you to provide a dictionary as a custom key that will be used for
|
||||
credential storage and retrieval.
|
||||
@@ -137,14 +137,14 @@ def get_credential_storage_custom_key(
|
||||
An object derived from client.Storage for getting/setting the
|
||||
credential.
|
||||
"""
|
||||
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
|
||||
key = util.dict_to_tuple_key(key_dict)
|
||||
return multistore._get_storage(key)
|
||||
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
|
||||
key = util.dict_to_tuple_key(key_dict)
|
||||
return multistore._get_storage(key)
|
||||
|
||||
|
||||
@util.positional(1)
|
||||
def get_all_credential_keys(filename, warn_on_readonly=True):
|
||||
"""Gets all the registered credential keys in the given Multistore.
|
||||
"""Gets all the registered credential keys in the given Multistore.
|
||||
|
||||
Args:
|
||||
filename: The JSON file storing a set of credentials
|
||||
@@ -155,17 +155,17 @@ def get_all_credential_keys(filename, warn_on_readonly=True):
|
||||
dictionaries that can be passed into get_credential_storage_custom_key to
|
||||
get the actual credentials.
|
||||
"""
|
||||
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
|
||||
multistore._lock()
|
||||
try:
|
||||
return multistore._get_all_credential_keys()
|
||||
finally:
|
||||
multistore._unlock()
|
||||
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
|
||||
multistore._lock()
|
||||
try:
|
||||
return multistore._get_all_credential_keys()
|
||||
finally:
|
||||
multistore._unlock()
|
||||
|
||||
|
||||
@util.positional(1)
|
||||
def _get_multistore(filename, warn_on_readonly=True):
|
||||
"""A helper method to initialize the multistore with proper locking.
|
||||
"""A helper method to initialize the multistore with proper locking.
|
||||
|
||||
Args:
|
||||
filename: The JSON file storing a set of credentials
|
||||
@@ -174,176 +174,176 @@ def _get_multistore(filename, warn_on_readonly=True):
|
||||
Returns:
|
||||
A multistore object
|
||||
"""
|
||||
filename = os.path.expanduser(filename)
|
||||
_multistores_lock.acquire()
|
||||
try:
|
||||
multistore = _multistores.setdefault(
|
||||
filename = os.path.expanduser(filename)
|
||||
_multistores_lock.acquire()
|
||||
try:
|
||||
multistore = _multistores.setdefault(
|
||||
filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly))
|
||||
finally:
|
||||
_multistores_lock.release()
|
||||
return multistore
|
||||
finally:
|
||||
_multistores_lock.release()
|
||||
return multistore
|
||||
|
||||
|
||||
class _MultiStore(object):
|
||||
"""A file backed store for multiple credentials."""
|
||||
"""A file backed store for multiple credentials."""
|
||||
|
||||
@util.positional(2)
|
||||
def __init__(self, filename, warn_on_readonly=True):
|
||||
"""Initialize the class.
|
||||
@util.positional(2)
|
||||
def __init__(self, filename, warn_on_readonly=True):
|
||||
"""Initialize the class.
|
||||
|
||||
This will create the file if necessary.
|
||||
"""
|
||||
self._file = LockedFile(filename, 'r+', 'r')
|
||||
self._thread_lock = threading.Lock()
|
||||
self._read_only = False
|
||||
self._warn_on_readonly = warn_on_readonly
|
||||
self._file = LockedFile(filename, 'r+', 'r')
|
||||
self._thread_lock = threading.Lock()
|
||||
self._read_only = False
|
||||
self._warn_on_readonly = warn_on_readonly
|
||||
|
||||
self._create_file_if_needed()
|
||||
self._create_file_if_needed()
|
||||
|
||||
# Cache of deserialized store. This is only valid after the
|
||||
# _MultiStore is locked or _refresh_data_cache is called. This is
|
||||
# of the form of:
|
||||
#
|
||||
# ((key, value), (key, value)...) -> OAuth2Credential
|
||||
#
|
||||
# If this is None, then the store hasn't been read yet.
|
||||
self._data = None
|
||||
# Cache of deserialized store. This is only valid after the
|
||||
# _MultiStore is locked or _refresh_data_cache is called. This is
|
||||
# of the form of:
|
||||
#
|
||||
# ((key, value), (key, value)...) -> OAuth2Credential
|
||||
#
|
||||
# If this is None, then the store hasn't been read yet.
|
||||
self._data = None
|
||||
|
||||
class _Storage(BaseStorage):
|
||||
"""A Storage object that knows how to read/write a single credential."""
|
||||
class _Storage(BaseStorage):
|
||||
"""A Storage object that knows how to read/write a single credential."""
|
||||
|
||||
def __init__(self, multistore, key):
|
||||
self._multistore = multistore
|
||||
self._key = key
|
||||
def __init__(self, multistore, key):
|
||||
self._multistore = multistore
|
||||
self._key = key
|
||||
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
def acquire_lock(self):
|
||||
"""Acquires any lock necessary to access this Storage.
|
||||
|
||||
This lock is not reentrant.
|
||||
"""
|
||||
self._multistore._lock()
|
||||
self._multistore._lock()
|
||||
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
def release_lock(self):
|
||||
"""Release the Storage lock.
|
||||
|
||||
Trying to release a lock that isn't held will result in a
|
||||
RuntimeError.
|
||||
"""
|
||||
self._multistore._unlock()
|
||||
self._multistore._unlock()
|
||||
|
||||
def locked_get(self):
|
||||
"""Retrieve credential.
|
||||
def locked_get(self):
|
||||
"""Retrieve credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Returns:
|
||||
oauth2client.client.Credentials
|
||||
"""
|
||||
credential = self._multistore._get_credential(self._key)
|
||||
if credential:
|
||||
credential.set_store(self)
|
||||
return credential
|
||||
credential = self._multistore._get_credential(self._key)
|
||||
if credential:
|
||||
credential.set_store(self)
|
||||
return credential
|
||||
|
||||
def locked_put(self, credentials):
|
||||
"""Write a credential.
|
||||
def locked_put(self, credentials):
|
||||
"""Write a credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
self._multistore._update_credential(self._key, credentials)
|
||||
self._multistore._update_credential(self._key, credentials)
|
||||
|
||||
def locked_delete(self):
|
||||
"""Delete a credential.
|
||||
def locked_delete(self):
|
||||
"""Delete a credential.
|
||||
|
||||
The Storage lock must be held when this is called.
|
||||
|
||||
Args:
|
||||
credentials: Credentials, the credentials to store.
|
||||
"""
|
||||
self._multistore._delete_credential(self._key)
|
||||
self._multistore._delete_credential(self._key)
|
||||
|
||||
def _create_file_if_needed(self):
|
||||
"""Create an empty file if necessary.
|
||||
def _create_file_if_needed(self):
|
||||
"""Create an empty file if necessary.
|
||||
|
||||
This method will not initialize the file. Instead it implements a
|
||||
simple version of "touch" to ensure the file has been created.
|
||||
"""
|
||||
if not os.path.exists(self._file.filename()):
|
||||
old_umask = os.umask(0o177)
|
||||
try:
|
||||
open(self._file.filename(), 'a+b').close()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
if not os.path.exists(self._file.filename()):
|
||||
old_umask = os.umask(0o177)
|
||||
try:
|
||||
open(self._file.filename(), 'a+b').close()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
def _lock(self):
|
||||
"""Lock the entire multistore."""
|
||||
self._thread_lock.acquire()
|
||||
try:
|
||||
self._file.open_and_lock()
|
||||
except IOError as e:
|
||||
if e.errno == errno.ENOSYS:
|
||||
logger.warn('File system does not support locking the credentials '
|
||||
def _lock(self):
|
||||
"""Lock the entire multistore."""
|
||||
self._thread_lock.acquire()
|
||||
try:
|
||||
self._file.open_and_lock()
|
||||
except IOError as e:
|
||||
if e.errno == errno.ENOSYS:
|
||||
logger.warn('File system does not support locking the credentials '
|
||||
'file.')
|
||||
elif e.errno == errno.ENOLCK:
|
||||
logger.warn('File system is out of resources for writing the '
|
||||
elif e.errno == errno.ENOLCK:
|
||||
logger.warn('File system is out of resources for writing the '
|
||||
'credentials file (is your disk full?).')
|
||||
else:
|
||||
raise
|
||||
if not self._file.is_locked():
|
||||
self._read_only = True
|
||||
if self._warn_on_readonly:
|
||||
logger.warn('The credentials file (%s) is not writable. Opening in '
|
||||
else:
|
||||
raise
|
||||
if not self._file.is_locked():
|
||||
self._read_only = True
|
||||
if self._warn_on_readonly:
|
||||
logger.warn('The credentials file (%s) is not writable. Opening in '
|
||||
'read-only mode. Any refreshed credentials will only be '
|
||||
'valid for this run.', self._file.filename())
|
||||
if os.path.getsize(self._file.filename()) == 0:
|
||||
logger.debug('Initializing empty multistore file')
|
||||
# The multistore is empty so write out an empty file.
|
||||
self._data = {}
|
||||
self._write()
|
||||
elif not self._read_only or self._data is None:
|
||||
# Only refresh the data if we are read/write or we haven't
|
||||
# cached the data yet. If we are readonly, we assume is isn't
|
||||
# changing out from under us and that we only have to read it
|
||||
# once. This prevents us from whacking any new access keys that
|
||||
# we have cached in memory but were unable to write out.
|
||||
self._refresh_data_cache()
|
||||
if os.path.getsize(self._file.filename()) == 0:
|
||||
logger.debug('Initializing empty multistore file')
|
||||
# The multistore is empty so write out an empty file.
|
||||
self._data = {}
|
||||
self._write()
|
||||
elif not self._read_only or self._data is None:
|
||||
# Only refresh the data if we are read/write or we haven't
|
||||
# cached the data yet. If we are readonly, we assume is isn't
|
||||
# changing out from under us and that we only have to read it
|
||||
# once. This prevents us from whacking any new access keys that
|
||||
# we have cached in memory but were unable to write out.
|
||||
self._refresh_data_cache()
|
||||
|
||||
def _unlock(self):
|
||||
"""Release the lock on the multistore."""
|
||||
self._file.unlock_and_close()
|
||||
self._thread_lock.release()
|
||||
def _unlock(self):
|
||||
"""Release the lock on the multistore."""
|
||||
self._file.unlock_and_close()
|
||||
self._thread_lock.release()
|
||||
|
||||
def _locked_json_read(self):
|
||||
"""Get the raw content of the multistore file.
|
||||
def _locked_json_read(self):
|
||||
"""Get the raw content of the multistore file.
|
||||
|
||||
The multistore must be locked when this is called.
|
||||
|
||||
Returns:
|
||||
The contents of the multistore decoded as JSON.
|
||||
"""
|
||||
assert self._thread_lock.locked()
|
||||
self._file.file_handle().seek(0)
|
||||
return json.load(self._file.file_handle())
|
||||
assert self._thread_lock.locked()
|
||||
self._file.file_handle().seek(0)
|
||||
return json.load(self._file.file_handle())
|
||||
|
||||
def _locked_json_write(self, data):
|
||||
"""Write a JSON serializable data structure to the multistore.
|
||||
def _locked_json_write(self, data):
|
||||
"""Write a JSON serializable data structure to the multistore.
|
||||
|
||||
The multistore must be locked when this is called.
|
||||
|
||||
Args:
|
||||
data: The data to be serialized and written.
|
||||
"""
|
||||
assert self._thread_lock.locked()
|
||||
if self._read_only:
|
||||
return
|
||||
self._file.file_handle().seek(0)
|
||||
json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
|
||||
self._file.file_handle().truncate()
|
||||
assert self._thread_lock.locked()
|
||||
if self._read_only:
|
||||
return
|
||||
self._file.file_handle().seek(0)
|
||||
json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
|
||||
self._file.file_handle().truncate()
|
||||
|
||||
def _refresh_data_cache(self):
|
||||
"""Refresh the contents of the multistore.
|
||||
def _refresh_data_cache(self):
|
||||
"""Refresh the contents of the multistore.
|
||||
|
||||
The multistore must be locked when this is called.
|
||||
|
||||
@@ -351,41 +351,41 @@ class _MultiStore(object):
|
||||
NewerCredentialStoreError: Raised when a newer client has written the
|
||||
store.
|
||||
"""
|
||||
self._data = {}
|
||||
try:
|
||||
raw_data = self._locked_json_read()
|
||||
except Exception:
|
||||
logger.warn('Credential data store could not be loaded. '
|
||||
self._data = {}
|
||||
try:
|
||||
raw_data = self._locked_json_read()
|
||||
except Exception:
|
||||
logger.warn('Credential data store could not be loaded. '
|
||||
'Will ignore and overwrite.')
|
||||
return
|
||||
return
|
||||
|
||||
version = 0
|
||||
try:
|
||||
version = raw_data['file_version']
|
||||
except Exception:
|
||||
logger.warn('Missing version for credential data store. It may be '
|
||||
version = 0
|
||||
try:
|
||||
version = raw_data['file_version']
|
||||
except Exception:
|
||||
logger.warn('Missing version for credential data store. It may be '
|
||||
'corrupt or an old version. Overwriting.')
|
||||
if version > 1:
|
||||
raise NewerCredentialStoreError(
|
||||
if version > 1:
|
||||
raise NewerCredentialStoreError(
|
||||
'Credential file has file_version of %d. '
|
||||
'Only file_version of 1 is supported.' % version)
|
||||
|
||||
credentials = []
|
||||
try:
|
||||
credentials = raw_data['data']
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
credentials = []
|
||||
try:
|
||||
credentials = raw_data['data']
|
||||
except (TypeError, KeyError):
|
||||
pass
|
||||
|
||||
for cred_entry in credentials:
|
||||
try:
|
||||
(key, credential) = self._decode_credential_from_json(cred_entry)
|
||||
self._data[key] = credential
|
||||
except:
|
||||
# If something goes wrong loading a credential, just ignore it
|
||||
logger.info('Error decoding credential, skipping', exc_info=True)
|
||||
for cred_entry in credentials:
|
||||
try:
|
||||
(key, credential) = self._decode_credential_from_json(cred_entry)
|
||||
self._data[key] = credential
|
||||
except:
|
||||
# If something goes wrong loading a credential, just ignore it
|
||||
logger.info('Error decoding credential, skipping', exc_info=True)
|
||||
|
||||
def _decode_credential_from_json(self, cred_entry):
|
||||
"""Load a credential from our JSON serialization.
|
||||
def _decode_credential_from_json(self, cred_entry):
|
||||
"""Load a credential from our JSON serialization.
|
||||
|
||||
Args:
|
||||
cred_entry: A dict entry from the data member of our format
|
||||
@@ -394,36 +394,36 @@ class _MultiStore(object):
|
||||
(key, cred) where the key is the key tuple and the cred is the
|
||||
OAuth2Credential object.
|
||||
"""
|
||||
raw_key = cred_entry['key']
|
||||
key = util.dict_to_tuple_key(raw_key)
|
||||
credential = None
|
||||
credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
|
||||
return (key, credential)
|
||||
raw_key = cred_entry['key']
|
||||
key = util.dict_to_tuple_key(raw_key)
|
||||
credential = None
|
||||
credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
|
||||
return (key, credential)
|
||||
|
||||
def _write(self):
|
||||
"""Write the cached data back out.
|
||||
def _write(self):
|
||||
"""Write the cached data back out.
|
||||
|
||||
The multistore must be locked.
|
||||
"""
|
||||
raw_data = {'file_version': 1}
|
||||
raw_creds = []
|
||||
raw_data['data'] = raw_creds
|
||||
for (cred_key, cred) in self._data.items():
|
||||
raw_key = dict(cred_key)
|
||||
raw_cred = json.loads(cred.to_json())
|
||||
raw_creds.append({'key': raw_key, 'credential': raw_cred})
|
||||
self._locked_json_write(raw_data)
|
||||
raw_data = {'file_version': 1}
|
||||
raw_creds = []
|
||||
raw_data['data'] = raw_creds
|
||||
for (cred_key, cred) in self._data.items():
|
||||
raw_key = dict(cred_key)
|
||||
raw_cred = json.loads(cred.to_json())
|
||||
raw_creds.append({'key': raw_key, 'credential': raw_cred})
|
||||
self._locked_json_write(raw_data)
|
||||
|
||||
def _get_all_credential_keys(self):
|
||||
"""Gets all the registered credential keys in the multistore.
|
||||
def _get_all_credential_keys(self):
|
||||
"""Gets all the registered credential keys in the multistore.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries corresponding to all the keys currently registered
|
||||
"""
|
||||
return [dict(key) for key in self._data.keys()]
|
||||
return [dict(key) for key in self._data.keys()]
|
||||
|
||||
def _get_credential(self, key):
|
||||
"""Get a credential from the multistore.
|
||||
def _get_credential(self, key):
|
||||
"""Get a credential from the multistore.
|
||||
|
||||
The multistore must be locked.
|
||||
|
||||
@@ -433,10 +433,10 @@ class _MultiStore(object):
|
||||
Returns:
|
||||
The credential specified or None if not present
|
||||
"""
|
||||
return self._data.get(key, None)
|
||||
return self._data.get(key, None)
|
||||
|
||||
def _update_credential(self, key, cred):
|
||||
"""Update a credential and write the multistore.
|
||||
def _update_credential(self, key, cred):
|
||||
"""Update a credential and write the multistore.
|
||||
|
||||
This must be called when the multistore is locked.
|
||||
|
||||
@@ -444,25 +444,25 @@ class _MultiStore(object):
|
||||
key: The key used to retrieve the credential
|
||||
cred: The OAuth2Credential to update/set
|
||||
"""
|
||||
self._data[key] = cred
|
||||
self._write()
|
||||
self._data[key] = cred
|
||||
self._write()
|
||||
|
||||
def _delete_credential(self, key):
|
||||
"""Delete a credential and write the multistore.
|
||||
def _delete_credential(self, key):
|
||||
"""Delete a credential and write the multistore.
|
||||
|
||||
This must be called when the multistore is locked.
|
||||
|
||||
Args:
|
||||
key: The key used to retrieve the credential
|
||||
"""
|
||||
try:
|
||||
del self._data[key]
|
||||
except KeyError:
|
||||
pass
|
||||
self._write()
|
||||
try:
|
||||
del self._data[key]
|
||||
except KeyError:
|
||||
pass
|
||||
self._write()
|
||||
|
||||
def _get_storage(self, key):
|
||||
"""Get a Storage object to get/set a credential.
|
||||
def _get_storage(self, key):
|
||||
"""Get a Storage object to get/set a credential.
|
||||
|
||||
This Storage is a 'view' into the multistore.
|
||||
|
||||
@@ -472,4 +472,4 @@ class _MultiStore(object):
|
||||
Returns:
|
||||
A Storage object that can be used to get/set this cred
|
||||
"""
|
||||
return self._Storage(self, key)
|
||||
return self._Storage(self, key)
|
||||
|
||||
@@ -30,7 +30,6 @@ from oauth2client import util
|
||||
from oauth2client.tools import ClientRedirectHandler
|
||||
from oauth2client.tools import ClientRedirectServer
|
||||
|
||||
|
||||
FLAGS = gflags.FLAGS
|
||||
|
||||
gflags.DEFINE_boolean('auth_local_webserver', True,
|
||||
@@ -48,7 +47,7 @@ gflags.DEFINE_multi_int('auth_host_port', [8080, 8090],
|
||||
|
||||
@util.positional(2)
|
||||
def run(flow, storage, http=None):
|
||||
"""Core code for a command-line application.
|
||||
"""Core code for a command-line application.
|
||||
|
||||
The ``run()`` function is called from your application and runs
|
||||
through all the steps to obtain credentials. It takes a ``Flow``
|
||||
@@ -86,76 +85,76 @@ def run(flow, storage, http=None):
|
||||
Returns:
|
||||
Credentials, the obtained credential.
|
||||
"""
|
||||
logging.warning('This function, oauth2client.tools.run(), and the use of '
|
||||
logging.warning('This function, oauth2client.tools.run(), and the use of '
|
||||
'the gflags library are deprecated and will be removed in a future '
|
||||
'version of the library.')
|
||||
if FLAGS.auth_local_webserver:
|
||||
success = False
|
||||
port_number = 0
|
||||
for port in FLAGS.auth_host_port:
|
||||
port_number = port
|
||||
try:
|
||||
httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
|
||||
if FLAGS.auth_local_webserver:
|
||||
success = False
|
||||
port_number = 0
|
||||
for port in FLAGS.auth_host_port:
|
||||
port_number = port
|
||||
try:
|
||||
httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
|
||||
ClientRedirectHandler)
|
||||
except socket.error as e:
|
||||
pass
|
||||
else:
|
||||
success = True
|
||||
break
|
||||
FLAGS.auth_local_webserver = success
|
||||
except socket.error as e:
|
||||
pass
|
||||
else:
|
||||
success = True
|
||||
break
|
||||
FLAGS.auth_local_webserver = success
|
||||
if not success:
|
||||
print('Failed to start a local webserver listening on either port 8080')
|
||||
print('or port 9090. Please check your firewall settings and locally')
|
||||
print('running programs that may be blocking or using those ports.')
|
||||
print()
|
||||
print('Falling back to --noauth_local_webserver and continuing with')
|
||||
print('authorization.')
|
||||
print()
|
||||
print('Failed to start a local webserver listening on either port 8080')
|
||||
print('or port 9090. Please check your firewall settings and locally')
|
||||
print('running programs that may be blocking or using those ports.')
|
||||
print()
|
||||
print('Falling back to --noauth_local_webserver and continuing with')
|
||||
print('authorization.')
|
||||
print()
|
||||
|
||||
if FLAGS.auth_local_webserver:
|
||||
oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
|
||||
else:
|
||||
oauth_callback = client.OOB_CALLBACK_URN
|
||||
flow.redirect_uri = oauth_callback
|
||||
authorize_url = flow.step1_get_authorize_url()
|
||||
|
||||
if FLAGS.auth_local_webserver:
|
||||
webbrowser.open(authorize_url, new=1, autoraise=True)
|
||||
print('Your browser has been opened to visit:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
print('If your browser is on a different machine then exit and re-run')
|
||||
print('this application with the command-line parameter ')
|
||||
print()
|
||||
print(' --noauth_local_webserver')
|
||||
print()
|
||||
else:
|
||||
print('Go to the following link in your browser:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
|
||||
code = None
|
||||
if FLAGS.auth_local_webserver:
|
||||
httpd.handle_request()
|
||||
if 'error' in httpd.query_params:
|
||||
sys.exit('Authentication request was rejected.')
|
||||
if 'code' in httpd.query_params:
|
||||
code = httpd.query_params['code']
|
||||
if FLAGS.auth_local_webserver:
|
||||
oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
|
||||
else:
|
||||
print('Failed to find "code" in the query parameters of the redirect.')
|
||||
sys.exit('Try running with --noauth_local_webserver.')
|
||||
else:
|
||||
code = input('Enter verification code: ').strip()
|
||||
oauth_callback = client.OOB_CALLBACK_URN
|
||||
flow.redirect_uri = oauth_callback
|
||||
authorize_url = flow.step1_get_authorize_url()
|
||||
|
||||
try:
|
||||
credential = flow.step2_exchange(code, http=http)
|
||||
except client.FlowExchangeError as e:
|
||||
sys.exit('Authentication has failed: %s' % e)
|
||||
if FLAGS.auth_local_webserver:
|
||||
webbrowser.open(authorize_url, new=1, autoraise=True)
|
||||
print('Your browser has been opened to visit:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
print('If your browser is on a different machine then exit and re-run')
|
||||
print('this application with the command-line parameter ')
|
||||
print()
|
||||
print(' --noauth_local_webserver')
|
||||
print()
|
||||
else:
|
||||
print('Go to the following link in your browser:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
|
||||
storage.put(credential)
|
||||
credential.set_store(storage)
|
||||
print('Authentication successful.')
|
||||
code = None
|
||||
if FLAGS.auth_local_webserver:
|
||||
httpd.handle_request()
|
||||
if 'error' in httpd.query_params:
|
||||
sys.exit('Authentication request was rejected.')
|
||||
if 'code' in httpd.query_params:
|
||||
code = httpd.query_params['code']
|
||||
else:
|
||||
print('Failed to find "code" in the query parameters of the redirect.')
|
||||
sys.exit('Try running with --noauth_local_webserver.')
|
||||
else:
|
||||
code = input('Enter verification code: ').strip()
|
||||
|
||||
return credential
|
||||
try:
|
||||
credential = flow.step2_exchange(code, http=http)
|
||||
except client.FlowExchangeError as e:
|
||||
sys.exit('Authentication has failed: %s' % e)
|
||||
|
||||
storage.put(credential)
|
||||
credential.set_store(storage)
|
||||
print('Authentication successful.')
|
||||
|
||||
return credential
|
||||
|
||||
@@ -34,83 +34,83 @@ from oauth2client.client import AssertionCredentials
|
||||
|
||||
|
||||
class _ServiceAccountCredentials(AssertionCredentials):
|
||||
"""Class representing a service account (signed JWT) credential."""
|
||||
"""Class representing a service account (signed JWT) credential."""
|
||||
|
||||
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
|
||||
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
|
||||
|
||||
def __init__(self, service_account_id, service_account_email, private_key_id,
|
||||
def __init__(self, service_account_id, service_account_email, private_key_id,
|
||||
private_key_pkcs8_text, scopes, user_agent=None,
|
||||
token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI,
|
||||
**kwargs):
|
||||
|
||||
super(_ServiceAccountCredentials, self).__init__(
|
||||
super(_ServiceAccountCredentials, self).__init__(
|
||||
None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri)
|
||||
|
||||
self._service_account_id = service_account_id
|
||||
self._service_account_email = service_account_email
|
||||
self._private_key_id = private_key_id
|
||||
self._private_key = _get_private_key(private_key_pkcs8_text)
|
||||
self._private_key_pkcs8_text = private_key_pkcs8_text
|
||||
self._scopes = util.scopes_to_string(scopes)
|
||||
self._user_agent = user_agent
|
||||
self._token_uri = token_uri
|
||||
self._revoke_uri = revoke_uri
|
||||
self._kwargs = kwargs
|
||||
self._service_account_id = service_account_id
|
||||
self._service_account_email = service_account_email
|
||||
self._private_key_id = private_key_id
|
||||
self._private_key = _get_private_key(private_key_pkcs8_text)
|
||||
self._private_key_pkcs8_text = private_key_pkcs8_text
|
||||
self._scopes = util.scopes_to_string(scopes)
|
||||
self._user_agent = user_agent
|
||||
self._token_uri = token_uri
|
||||
self._revoke_uri = revoke_uri
|
||||
self._kwargs = kwargs
|
||||
|
||||
def _generate_assertion(self):
|
||||
"""Generate the assertion that will be used in the request."""
|
||||
def _generate_assertion(self):
|
||||
"""Generate the assertion that will be used in the request."""
|
||||
|
||||
header = {
|
||||
header = {
|
||||
'alg': 'RS256',
|
||||
'typ': 'JWT',
|
||||
'kid': self._private_key_id
|
||||
}
|
||||
}
|
||||
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
'aud': self._token_uri,
|
||||
'scope': self._scopes,
|
||||
'iat': now,
|
||||
'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS,
|
||||
'iss': self._service_account_email
|
||||
}
|
||||
payload.update(self._kwargs)
|
||||
}
|
||||
payload.update(self._kwargs)
|
||||
|
||||
first_segment = _urlsafe_b64encode(_json_encode(header))
|
||||
second_segment = _urlsafe_b64encode(_json_encode(payload))
|
||||
assertion_input = first_segment + b'.' + second_segment
|
||||
first_segment = _urlsafe_b64encode(_json_encode(header))
|
||||
second_segment = _urlsafe_b64encode(_json_encode(payload))
|
||||
assertion_input = first_segment + b'.' + second_segment
|
||||
|
||||
# Sign the assertion.
|
||||
rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
|
||||
signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
|
||||
# Sign the assertion.
|
||||
rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
|
||||
signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
|
||||
|
||||
return assertion_input + b'.' + signature
|
||||
return assertion_input + b'.' + signature
|
||||
|
||||
def sign_blob(self, blob):
|
||||
# Ensure that it is bytes
|
||||
blob = _to_bytes(blob, encoding='utf-8')
|
||||
return (self._private_key_id,
|
||||
def sign_blob(self, blob):
|
||||
# Ensure that it is bytes
|
||||
blob = _to_bytes(blob, encoding='utf-8')
|
||||
return (self._private_key_id,
|
||||
rsa.pkcs1.sign(blob, self._private_key, 'SHA-256'))
|
||||
|
||||
@property
|
||||
def service_account_email(self):
|
||||
return self._service_account_email
|
||||
@property
|
||||
def service_account_email(self):
|
||||
return self._service_account_email
|
||||
|
||||
@property
|
||||
def serialization_data(self):
|
||||
return {
|
||||
@property
|
||||
def serialization_data(self):
|
||||
return {
|
||||
'type': 'service_account',
|
||||
'client_id': self._service_account_id,
|
||||
'client_email': self._service_account_email,
|
||||
'private_key_id': self._private_key_id,
|
||||
'private_key': self._private_key_pkcs8_text
|
||||
}
|
||||
}
|
||||
|
||||
def create_scoped_required(self):
|
||||
return not self._scopes
|
||||
def create_scoped_required(self):
|
||||
return not self._scopes
|
||||
|
||||
def create_scoped(self, scopes):
|
||||
return _ServiceAccountCredentials(self._service_account_id,
|
||||
def create_scoped(self, scopes):
|
||||
return _ServiceAccountCredentials(self._service_account_id,
|
||||
self._service_account_email,
|
||||
self._private_key_id,
|
||||
self._private_key_pkcs8_text,
|
||||
@@ -122,10 +122,10 @@ class _ServiceAccountCredentials(AssertionCredentials):
|
||||
|
||||
|
||||
def _get_private_key(private_key_pkcs8_text):
|
||||
"""Get an RSA private key object from a pkcs8 representation."""
|
||||
private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
|
||||
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
|
||||
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
|
||||
return rsa.PrivateKey.load_pkcs1(
|
||||
"""Get an RSA private key object from a pkcs8 representation."""
|
||||
private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
|
||||
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
|
||||
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
|
||||
return rsa.PrivateKey.load_pkcs1(
|
||||
asn1_private_key.getComponentByName('privateKey').asOctets(),
|
||||
format='DER')
|
||||
|
||||
@@ -35,7 +35,6 @@ from six.moves import input
|
||||
from oauth2client import client
|
||||
from oauth2client import util
|
||||
|
||||
|
||||
_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0
|
||||
|
||||
To make this sample run you will need to populate the client_secrets.json file
|
||||
@@ -47,22 +46,23 @@ with information from the APIs Console <https://code.google.com/apis/console>.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _CreateArgumentParser():
|
||||
try:
|
||||
import argparse
|
||||
except ImportError:
|
||||
return None
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument('--auth_host_name', default='localhost',
|
||||
try:
|
||||
import argparse
|
||||
except ImportError:
|
||||
return None
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument('--auth_host_name', default='localhost',
|
||||
help='Hostname when running a local web server.')
|
||||
parser.add_argument('--noauth_local_webserver', action='store_true',
|
||||
parser.add_argument('--noauth_local_webserver', action='store_true',
|
||||
default=False, help='Do not run a local web server.')
|
||||
parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
|
||||
parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
|
||||
nargs='*', help='Port web server should listen on.')
|
||||
parser.add_argument('--logging_level', default='ERROR',
|
||||
parser.add_argument('--logging_level', default='ERROR',
|
||||
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
help='Set the logging level of detail.')
|
||||
return parser
|
||||
return parser
|
||||
|
||||
# argparser is an ArgumentParser that contains command-line options expected
|
||||
# by tools.run(). Pass it in as part of the 'parents' argument to your own
|
||||
@@ -71,45 +71,45 @@ argparser = _CreateArgumentParser()
|
||||
|
||||
|
||||
class ClientRedirectServer(BaseHTTPServer.HTTPServer):
|
||||
"""A server to handle OAuth 2.0 redirects back to localhost.
|
||||
"""A server to handle OAuth 2.0 redirects back to localhost.
|
||||
|
||||
Waits for a single request and parses the query parameters
|
||||
into query_params and then stops serving.
|
||||
"""
|
||||
query_params = {}
|
||||
query_params = {}
|
||||
|
||||
|
||||
class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
"""A handler for OAuth 2.0 redirects back to localhost.
|
||||
"""A handler for OAuth 2.0 redirects back to localhost.
|
||||
|
||||
Waits for a single request and parses the query parameters
|
||||
into the servers query_params and then stops serving.
|
||||
"""
|
||||
|
||||
def do_GET(self):
|
||||
"""Handle a GET request.
|
||||
def do_GET(self):
|
||||
"""Handle a GET request.
|
||||
|
||||
Parses the query parameters and prints a message
|
||||
if the flow has completed. Note that we can't detect
|
||||
if an error occurred.
|
||||
"""
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
query = self.path.split('?', 1)[-1]
|
||||
query = dict(urllib.parse.parse_qsl(query))
|
||||
self.server.query_params = query
|
||||
self.wfile.write(b"<html><head><title>Authentication Status</title></head>")
|
||||
self.wfile.write(b"<body><p>The authentication flow has completed.</p>")
|
||||
self.wfile.write(b"</body></html>")
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
query = self.path.split('?', 1)[-1]
|
||||
query = dict(urllib.parse.parse_qsl(query))
|
||||
self.server.query_params = query
|
||||
self.wfile.write(b"<html><head><title>Authentication Status</title></head>")
|
||||
self.wfile.write(b"<body><p>The authentication flow has completed.</p>")
|
||||
self.wfile.write(b"</body></html>")
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Do not log messages to stdout while running as command line program."""
|
||||
def log_message(self, format, *args):
|
||||
"""Do not log messages to stdout while running as command line program."""
|
||||
|
||||
|
||||
@util.positional(3)
|
||||
def run_flow(flow, storage, flags, http=None):
|
||||
"""Core code for a command-line application.
|
||||
"""Core code for a command-line application.
|
||||
|
||||
The ``run()`` function is called from your application and runs
|
||||
through all the steps to obtain credentials. It takes a ``Flow``
|
||||
@@ -159,91 +159,91 @@ def run_flow(flow, storage, flags, http=None):
|
||||
Returns:
|
||||
Credentials, the obtained credential.
|
||||
"""
|
||||
logging.getLogger().setLevel(getattr(logging, flags.logging_level))
|
||||
if not flags.noauth_local_webserver:
|
||||
success = False
|
||||
port_number = 0
|
||||
for port in flags.auth_host_port:
|
||||
port_number = port
|
||||
try:
|
||||
httpd = ClientRedirectServer((flags.auth_host_name, port),
|
||||
logging.getLogger().setLevel(getattr(logging, flags.logging_level))
|
||||
if not flags.noauth_local_webserver:
|
||||
success = False
|
||||
port_number = 0
|
||||
for port in flags.auth_host_port:
|
||||
port_number = port
|
||||
try:
|
||||
httpd = ClientRedirectServer((flags.auth_host_name, port),
|
||||
ClientRedirectHandler)
|
||||
except socket.error:
|
||||
pass
|
||||
else:
|
||||
success = True
|
||||
break
|
||||
flags.noauth_local_webserver = not success
|
||||
except socket.error:
|
||||
pass
|
||||
else:
|
||||
success = True
|
||||
break
|
||||
flags.noauth_local_webserver = not success
|
||||
if not success:
|
||||
print('Failed to start a local webserver listening on either port 8080')
|
||||
print('or port 9090. Please check your firewall settings and locally')
|
||||
print('running programs that may be blocking or using those ports.')
|
||||
print()
|
||||
print('Falling back to --noauth_local_webserver and continuing with')
|
||||
print('authorization.')
|
||||
print()
|
||||
print('Failed to start a local webserver listening on either port 8080')
|
||||
print('or port 9090. Please check your firewall settings and locally')
|
||||
print('running programs that may be blocking or using those ports.')
|
||||
print()
|
||||
print('Falling back to --noauth_local_webserver and continuing with')
|
||||
print('authorization.')
|
||||
print()
|
||||
|
||||
if not flags.noauth_local_webserver:
|
||||
oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
|
||||
else:
|
||||
oauth_callback = client.OOB_CALLBACK_URN
|
||||
flow.redirect_uri = oauth_callback
|
||||
authorize_url = flow.step1_get_authorize_url()
|
||||
|
||||
if not flags.noauth_local_webserver:
|
||||
import webbrowser
|
||||
webbrowser.open(authorize_url, new=1, autoraise=True)
|
||||
print('Your browser has been opened to visit:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
print('If your browser is on a different machine then exit and re-run this')
|
||||
print('application with the command-line parameter ')
|
||||
print()
|
||||
print(' --noauth_local_webserver')
|
||||
print()
|
||||
else:
|
||||
print('Go to the following link in your browser:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
|
||||
code = None
|
||||
if not flags.noauth_local_webserver:
|
||||
httpd.handle_request()
|
||||
if 'error' in httpd.query_params:
|
||||
sys.exit('Authentication request was rejected.')
|
||||
if 'code' in httpd.query_params:
|
||||
code = httpd.query_params['code']
|
||||
if not flags.noauth_local_webserver:
|
||||
oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
|
||||
else:
|
||||
print('Failed to find "code" in the query parameters of the redirect.')
|
||||
sys.exit('Try running with --noauth_local_webserver.')
|
||||
else:
|
||||
code = input('Enter verification code: ').strip()
|
||||
oauth_callback = client.OOB_CALLBACK_URN
|
||||
flow.redirect_uri = oauth_callback
|
||||
authorize_url = flow.step1_get_authorize_url()
|
||||
|
||||
try:
|
||||
credential = flow.step2_exchange(code, http=http)
|
||||
except client.FlowExchangeError as e:
|
||||
sys.exit('Authentication has failed: %s' % e)
|
||||
if not flags.noauth_local_webserver:
|
||||
import webbrowser
|
||||
webbrowser.open(authorize_url, new=1, autoraise=True)
|
||||
print('Your browser has been opened to visit:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
print('If your browser is on a different machine then exit and re-run this')
|
||||
print('application with the command-line parameter ')
|
||||
print()
|
||||
print(' --noauth_local_webserver')
|
||||
print()
|
||||
else:
|
||||
print('Go to the following link in your browser:')
|
||||
print()
|
||||
print(' ' + authorize_url)
|
||||
print()
|
||||
|
||||
storage.put(credential)
|
||||
credential.set_store(storage)
|
||||
print('Authentication successful.')
|
||||
code = None
|
||||
if not flags.noauth_local_webserver:
|
||||
httpd.handle_request()
|
||||
if 'error' in httpd.query_params:
|
||||
sys.exit('Authentication request was rejected.')
|
||||
if 'code' in httpd.query_params:
|
||||
code = httpd.query_params['code']
|
||||
else:
|
||||
print('Failed to find "code" in the query parameters of the redirect.')
|
||||
sys.exit('Try running with --noauth_local_webserver.')
|
||||
else:
|
||||
code = input('Enter verification code: ').strip()
|
||||
|
||||
return credential
|
||||
try:
|
||||
credential = flow.step2_exchange(code, http=http)
|
||||
except client.FlowExchangeError as e:
|
||||
sys.exit('Authentication has failed: %s' % e)
|
||||
|
||||
storage.put(credential)
|
||||
credential.set_store(storage)
|
||||
print('Authentication successful.')
|
||||
|
||||
return credential
|
||||
|
||||
|
||||
def message_if_missing(filename):
|
||||
"""Helpful message to display if the CLIENT_SECRETS file is missing."""
|
||||
"""Helpful message to display if the CLIENT_SECRETS file is missing."""
|
||||
|
||||
return _CLIENT_SECRETS_MESSAGE % filename
|
||||
return _CLIENT_SECRETS_MESSAGE % filename
|
||||
|
||||
try:
|
||||
from oauth2client.old_run import run
|
||||
from oauth2client.old_run import FLAGS
|
||||
from oauth2client.old_run import run
|
||||
from oauth2client.old_run import FLAGS
|
||||
except ImportError:
|
||||
def run(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
def run(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'The gflags library must be installed to use tools.run(). '
|
||||
'Please install gflags or preferrably switch to using '
|
||||
'tools.run_flow().')
|
||||
|
||||
@@ -38,7 +38,6 @@ import types
|
||||
import six
|
||||
from six.moves import urllib
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
POSITIONAL_WARNING = 'WARNING'
|
||||
@@ -49,8 +48,9 @@ POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION,
|
||||
|
||||
positional_parameters_enforcement = POSITIONAL_WARNING
|
||||
|
||||
|
||||
def positional(max_positional_args):
|
||||
"""A decorator to declare that only the first N arguments my be positional.
|
||||
"""A decorator to declare that only the first N arguments my be positional.
|
||||
|
||||
This decorator makes it easy to support Python 3 style keyword-only
|
||||
parameters. For example, in Python 3 it is possible to write::
|
||||
@@ -119,33 +119,34 @@ def positional(max_positional_args):
|
||||
POSITIONAL_EXCEPTION.
|
||||
|
||||
"""
|
||||
def positional_decorator(wrapped):
|
||||
@functools.wraps(wrapped)
|
||||
def positional_wrapper(*args, **kwargs):
|
||||
if len(args) > max_positional_args:
|
||||
plural_s = ''
|
||||
if max_positional_args != 1:
|
||||
plural_s = 's'
|
||||
message = '%s() takes at most %d positional argument%s (%d given)' % (
|
||||
wrapped.__name__, max_positional_args, plural_s, len(args))
|
||||
if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
|
||||
raise TypeError(message)
|
||||
elif positional_parameters_enforcement == POSITIONAL_WARNING:
|
||||
logger.warning(message)
|
||||
else: # IGNORE
|
||||
pass
|
||||
return wrapped(*args, **kwargs)
|
||||
return positional_wrapper
|
||||
|
||||
if isinstance(max_positional_args, six.integer_types):
|
||||
return positional_decorator
|
||||
else:
|
||||
args, _, _, defaults = inspect.getargspec(max_positional_args)
|
||||
return positional(len(args) - len(defaults))(max_positional_args)
|
||||
def positional_decorator(wrapped):
|
||||
@functools.wraps(wrapped)
|
||||
def positional_wrapper(*args, **kwargs):
|
||||
if len(args) > max_positional_args:
|
||||
plural_s = ''
|
||||
if max_positional_args != 1:
|
||||
plural_s = 's'
|
||||
message = '%s() takes at most %d positional argument%s (%d given)' % (
|
||||
wrapped.__name__, max_positional_args, plural_s, len(args))
|
||||
if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
|
||||
raise TypeError(message)
|
||||
elif positional_parameters_enforcement == POSITIONAL_WARNING:
|
||||
logger.warning(message)
|
||||
else: # IGNORE
|
||||
pass
|
||||
return wrapped(*args, **kwargs)
|
||||
return positional_wrapper
|
||||
|
||||
if isinstance(max_positional_args, six.integer_types):
|
||||
return positional_decorator
|
||||
else:
|
||||
args, _, _, defaults = inspect.getargspec(max_positional_args)
|
||||
return positional(len(args) - len(defaults))(max_positional_args)
|
||||
|
||||
|
||||
def scopes_to_string(scopes):
|
||||
"""Converts scope value to a string.
|
||||
"""Converts scope value to a string.
|
||||
|
||||
If scopes is a string then it is simply passed through. If scopes is an
|
||||
iterable then a string is returned that is all the individual scopes
|
||||
@@ -157,14 +158,14 @@ def scopes_to_string(scopes):
|
||||
Returns:
|
||||
The scopes formatted as a single string.
|
||||
"""
|
||||
if isinstance(scopes, six.string_types):
|
||||
return scopes
|
||||
else:
|
||||
return ' '.join(scopes)
|
||||
if isinstance(scopes, six.string_types):
|
||||
return scopes
|
||||
else:
|
||||
return ' '.join(scopes)
|
||||
|
||||
|
||||
def string_to_scopes(scopes):
|
||||
"""Converts stringifed scope value to a list.
|
||||
"""Converts stringifed scope value to a list.
|
||||
|
||||
If scopes is a list then it is simply passed through. If scopes is an
|
||||
string then a list of each individual scope is returned.
|
||||
@@ -175,16 +176,16 @@ def string_to_scopes(scopes):
|
||||
Returns:
|
||||
The scopes in a list.
|
||||
"""
|
||||
if not scopes:
|
||||
return []
|
||||
if isinstance(scopes, six.string_types):
|
||||
return scopes.split(' ')
|
||||
else:
|
||||
return scopes
|
||||
if not scopes:
|
||||
return []
|
||||
if isinstance(scopes, six.string_types):
|
||||
return scopes.split(' ')
|
||||
else:
|
||||
return scopes
|
||||
|
||||
|
||||
def dict_to_tuple_key(dictionary):
|
||||
"""Converts a dictionary to a tuple that can be used as an immutable key.
|
||||
"""Converts a dictionary to a tuple that can be used as an immutable key.
|
||||
|
||||
The resulting key is always sorted so that logically equivalent dictionaries
|
||||
always produce an identical tuple for a key.
|
||||
@@ -195,11 +196,11 @@ def dict_to_tuple_key(dictionary):
|
||||
Returns:
|
||||
A tuple representing the dictionary in it's naturally sorted ordering.
|
||||
"""
|
||||
return tuple(sorted(dictionary.items()))
|
||||
return tuple(sorted(dictionary.items()))
|
||||
|
||||
|
||||
def _add_query_parameter(url, name, value):
|
||||
"""Adds a query parameter to a url.
|
||||
"""Adds a query parameter to a url.
|
||||
|
||||
Replaces the current value if it already exists in the URL.
|
||||
|
||||
@@ -211,11 +212,11 @@ def _add_query_parameter(url, name, value):
|
||||
Returns:
|
||||
Updated query parameter. Does not update the url if value is None.
|
||||
"""
|
||||
if value is None:
|
||||
return url
|
||||
else:
|
||||
parsed = list(urllib.parse.urlparse(url))
|
||||
q = dict(urllib.parse.parse_qsl(parsed[4]))
|
||||
q[name] = value
|
||||
parsed[4] = urllib.parse.urlencode(q)
|
||||
return urllib.parse.urlunparse(parsed)
|
||||
if value is None:
|
||||
return url
|
||||
else:
|
||||
parsed = list(urllib.parse.urlparse(url))
|
||||
q = dict(urllib.parse.parse_qsl(parsed[4]))
|
||||
q[name] = value
|
||||
parsed[4] = urllib.parse.urlencode(q)
|
||||
return urllib.parse.urlunparse(parsed)
|
||||
|
||||
@@ -20,7 +20,6 @@ __authors__ = [
|
||||
'"Joe Gregorio" <jcgregorio@google.com>',
|
||||
]
|
||||
|
||||
|
||||
import base64
|
||||
import hmac
|
||||
import time
|
||||
@@ -28,13 +27,11 @@ import time
|
||||
import six
|
||||
from oauth2client import util
|
||||
|
||||
|
||||
# Delimiter character
|
||||
DELIMITER = b':'
|
||||
|
||||
|
||||
# 1 hour in seconds
|
||||
DEFAULT_TIMEOUT_SECS = 1*60*60
|
||||
DEFAULT_TIMEOUT_SECS = 1 * 60 * 60
|
||||
|
||||
|
||||
def _force_bytes(s):
|
||||
@@ -48,7 +45,7 @@ def _force_bytes(s):
|
||||
|
||||
@util.positional(2)
|
||||
def generate_token(key, user_id, action_id="", when=None):
|
||||
"""Generates a URL-safe token for the given user, action, time tuple.
|
||||
"""Generates a URL-safe token for the given user, action, time tuple.
|
||||
|
||||
Args:
|
||||
key: secret key to use.
|
||||
@@ -61,22 +58,22 @@ def generate_token(key, user_id, action_id="", when=None):
|
||||
Returns:
|
||||
A string XSRF protection token.
|
||||
"""
|
||||
when = _force_bytes(when or int(time.time()))
|
||||
digester = hmac.new(_force_bytes(key))
|
||||
digester.update(_force_bytes(user_id))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(_force_bytes(action_id))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(when)
|
||||
digest = digester.digest()
|
||||
when = _force_bytes(when or int(time.time()))
|
||||
digester = hmac.new(_force_bytes(key))
|
||||
digester.update(_force_bytes(user_id))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(_force_bytes(action_id))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(when)
|
||||
digest = digester.digest()
|
||||
|
||||
token = base64.urlsafe_b64encode(digest + DELIMITER + when)
|
||||
return token
|
||||
token = base64.urlsafe_b64encode(digest + DELIMITER + when)
|
||||
return token
|
||||
|
||||
|
||||
@util.positional(3)
|
||||
def validate_token(key, token, user_id, action_id="", current_time=None):
|
||||
"""Validates that the given token authorizes the user for the action.
|
||||
"""Validates that the given token authorizes the user for the action.
|
||||
|
||||
Tokens are invalid if the time of issue is too old or if the token
|
||||
does not match what generateToken outputs (i.e. the token was forged).
|
||||
@@ -92,27 +89,27 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
|
||||
A boolean - True if the user is authorized for the action, False
|
||||
otherwise.
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token)
|
||||
token_time = int(decoded.split(DELIMITER)[-1])
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
# If the token is too old it's not valid.
|
||||
if current_time - token_time > DEFAULT_TIMEOUT_SECS:
|
||||
return False
|
||||
if not token:
|
||||
return False
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token)
|
||||
token_time = int(decoded.split(DELIMITER)[-1])
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
# If the token is too old it's not valid.
|
||||
if current_time - token_time > DEFAULT_TIMEOUT_SECS:
|
||||
return False
|
||||
|
||||
# The given token should match the generated one with the same time.
|
||||
expected_token = generate_token(key, user_id, action_id=action_id,
|
||||
# The given token should match the generated one with the same time.
|
||||
expected_token = generate_token(key, user_id, action_id=action_id,
|
||||
when=token_time)
|
||||
if len(token) != len(expected_token):
|
||||
return False
|
||||
if len(token) != len(expected_token):
|
||||
return False
|
||||
|
||||
# Perform constant time comparison to avoid timing attacks
|
||||
different = 0
|
||||
for x, y in zip(bytearray(token), bytearray(expected_token)):
|
||||
different |= x ^ y
|
||||
return not different
|
||||
# Perform constant time comparison to avoid timing attacks
|
||||
different = 0
|
||||
for x, y in zip(bytearray(token), bytearray(expected_token)):
|
||||
different |= x ^ y
|
||||
return not different
|
||||
|
||||
@@ -16,6 +16,7 @@ __author__ = 'afshar@google.com (Ali Afshar)'
|
||||
|
||||
import oauth2client.util
|
||||
|
||||
|
||||
def setup_package():
|
||||
"""Run on testing package."""
|
||||
oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'
|
||||
"""Run on testing package."""
|
||||
oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'
|
||||
|
||||
@@ -20,47 +20,45 @@ import httplib2
|
||||
|
||||
# TODO(craigcitro): Find a cleaner way to share this code with googleapiclient.
|
||||
|
||||
|
||||
class HttpMock(object):
|
||||
"""Mock of httplib2.Http"""
|
||||
"""Mock of httplib2.Http"""
|
||||
|
||||
def __init__(self, filename=None, headers=None):
|
||||
"""
|
||||
def __init__(self, filename=None, headers=None):
|
||||
"""
|
||||
Args:
|
||||
filename: string, absolute filename to read response from
|
||||
headers: dict, header to return with response
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {'status': '200 OK'}
|
||||
if filename:
|
||||
f = file(filename, 'r')
|
||||
self.data = f.read()
|
||||
f.close()
|
||||
else:
|
||||
self.data = None
|
||||
self.response_headers = headers
|
||||
self.headers = None
|
||||
self.uri = None
|
||||
self.method = None
|
||||
self.body = None
|
||||
self.headers = None
|
||||
if headers is None:
|
||||
headers = {'status': '200 OK'}
|
||||
if filename:
|
||||
f = file(filename, 'r')
|
||||
self.data = f.read()
|
||||
f.close()
|
||||
else:
|
||||
self.data = None
|
||||
self.response_headers = headers
|
||||
self.headers = None
|
||||
self.uri = None
|
||||
self.method = None
|
||||
self.body = None
|
||||
self.headers = None
|
||||
|
||||
|
||||
def request(self, uri,
|
||||
def request(self, uri,
|
||||
method='GET',
|
||||
body=None,
|
||||
headers=None,
|
||||
redirections=1,
|
||||
connection_type=None):
|
||||
self.uri = uri
|
||||
self.method = method
|
||||
self.body = body
|
||||
self.headers = headers
|
||||
return httplib2.Response(self.response_headers), self.data
|
||||
self.uri = uri
|
||||
self.method = method
|
||||
self.body = body
|
||||
self.headers = headers
|
||||
return httplib2.Response(self.response_headers), self.data
|
||||
|
||||
|
||||
class HttpMockSequence(object):
|
||||
"""Mock of httplib2.Http
|
||||
"""Mock of httplib2.Http
|
||||
|
||||
Mocks a sequence of calls to request returning different responses for each
|
||||
call. Create an instance initialized with the desired response headers
|
||||
@@ -83,33 +81,33 @@ class HttpMockSequence(object):
|
||||
'echo_request_uri' means return the request uri in the response body
|
||||
"""
|
||||
|
||||
def __init__(self, iterable):
|
||||
"""
|
||||
def __init__(self, iterable):
|
||||
"""
|
||||
Args:
|
||||
iterable: iterable, a sequence of pairs of (headers, body)
|
||||
"""
|
||||
self._iterable = iterable
|
||||
self.follow_redirects = True
|
||||
self.requests = []
|
||||
self._iterable = iterable
|
||||
self.follow_redirects = True
|
||||
self.requests = []
|
||||
|
||||
def request(self, uri,
|
||||
def request(self, uri,
|
||||
method='GET',
|
||||
body=None,
|
||||
headers=None,
|
||||
redirections=1,
|
||||
connection_type=None):
|
||||
resp, content = self._iterable.pop(0)
|
||||
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
|
||||
# Read any underlying stream before sending the request.
|
||||
body_stream_content = body.read() if getattr(body, 'read', None) else None
|
||||
if content == 'echo_request_headers':
|
||||
content = headers
|
||||
elif content == 'echo_request_headers_as_json':
|
||||
content = json.dumps(headers)
|
||||
elif content == 'echo_request_body':
|
||||
content = body if body_stream_content is None else body_stream_content
|
||||
elif content == 'echo_request_uri':
|
||||
content = uri
|
||||
elif not isinstance(content, bytes):
|
||||
raise TypeError('http content should be bytes: %r' % (content,))
|
||||
return httplib2.Response(resp), content
|
||||
resp, content = self._iterable.pop(0)
|
||||
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
|
||||
# Read any underlying stream before sending the request.
|
||||
body_stream_content = body.read() if getattr(body, 'read', None) else None
|
||||
if content == 'echo_request_headers':
|
||||
content = headers
|
||||
elif content == 'echo_request_headers_as_json':
|
||||
content = json.dumps(headers)
|
||||
elif content == 'echo_request_body':
|
||||
content = body if body_stream_content is None else body_stream_content
|
||||
elif content == 'echo_request_uri':
|
||||
content = uri
|
||||
elif not isinstance(content, bytes):
|
||||
raise TypeError('http content should be bytes: %r' % (content, ))
|
||||
return httplib2.Response(resp), content
|
||||
|
||||
@@ -25,93 +25,93 @@ from oauth2client._helpers import _urlsafe_b64encode
|
||||
|
||||
class Test__parse_pem_key(unittest.TestCase):
|
||||
|
||||
def test_valid_input(self):
|
||||
test_string = b'1234-----BEGIN FOO BAR BAZ'
|
||||
result = _parse_pem_key(test_string)
|
||||
self.assertEqual(result, test_string[4:])
|
||||
def test_valid_input(self):
|
||||
test_string = b'1234-----BEGIN FOO BAR BAZ'
|
||||
result = _parse_pem_key(test_string)
|
||||
self.assertEqual(result, test_string[4:])
|
||||
|
||||
def test_bad_input(self):
|
||||
test_string = b'DOES NOT HAVE DASHES'
|
||||
result = _parse_pem_key(test_string)
|
||||
self.assertEqual(result, None)
|
||||
def test_bad_input(self):
|
||||
test_string = b'DOES NOT HAVE DASHES'
|
||||
result = _parse_pem_key(test_string)
|
||||
self.assertEqual(result, None)
|
||||
|
||||
|
||||
class Test__json_encode(unittest.TestCase):
|
||||
|
||||
def test_dictionary_input(self):
|
||||
# Use only a single key since dictionary hash order
|
||||
# is non-deterministic.
|
||||
data = {u'foo': 10}
|
||||
result = _json_encode(data)
|
||||
self.assertEqual(result, """{"foo":10}""")
|
||||
def test_dictionary_input(self):
|
||||
# Use only a single key since dictionary hash order
|
||||
# is non-deterministic.
|
||||
data = {u'foo': 10}
|
||||
result = _json_encode(data)
|
||||
self.assertEqual(result, """{"foo":10}""")
|
||||
|
||||
def test_list_input(self):
|
||||
data = [42, 1337]
|
||||
result = _json_encode(data)
|
||||
self.assertEqual(result, """[42,1337]""")
|
||||
def test_list_input(self):
|
||||
data = [42, 1337]
|
||||
result = _json_encode(data)
|
||||
self.assertEqual(result, """[42,1337]""")
|
||||
|
||||
|
||||
class Test__to_bytes(unittest.TestCase):
|
||||
|
||||
def test_with_bytes(self):
|
||||
value = b'bytes-val'
|
||||
self.assertEqual(_to_bytes(value), value)
|
||||
def test_with_bytes(self):
|
||||
value = b'bytes-val'
|
||||
self.assertEqual(_to_bytes(value), value)
|
||||
|
||||
def test_with_unicode(self):
|
||||
value = u'string-val'
|
||||
encoded_value = b'string-val'
|
||||
self.assertEqual(_to_bytes(value), encoded_value)
|
||||
def test_with_unicode(self):
|
||||
value = u'string-val'
|
||||
encoded_value = b'string-val'
|
||||
self.assertEqual(_to_bytes(value), encoded_value)
|
||||
|
||||
def test_with_nonstring_type(self):
|
||||
value = object()
|
||||
self.assertRaises(ValueError, _to_bytes, value)
|
||||
def test_with_nonstring_type(self):
|
||||
value = object()
|
||||
self.assertRaises(ValueError, _to_bytes, value)
|
||||
|
||||
|
||||
class Test__from_bytes(unittest.TestCase):
|
||||
|
||||
def test_with_unicode(self):
|
||||
value = u'bytes-val'
|
||||
self.assertEqual(_from_bytes(value), value)
|
||||
def test_with_unicode(self):
|
||||
value = u'bytes-val'
|
||||
self.assertEqual(_from_bytes(value), value)
|
||||
|
||||
def test_with_bytes(self):
|
||||
value = b'string-val'
|
||||
decoded_value = u'string-val'
|
||||
self.assertEqual(_from_bytes(value), decoded_value)
|
||||
def test_with_bytes(self):
|
||||
value = b'string-val'
|
||||
decoded_value = u'string-val'
|
||||
self.assertEqual(_from_bytes(value), decoded_value)
|
||||
|
||||
def test_with_nonstring_type(self):
|
||||
value = object()
|
||||
self.assertRaises(ValueError, _from_bytes, value)
|
||||
def test_with_nonstring_type(self):
|
||||
value = object()
|
||||
self.assertRaises(ValueError, _from_bytes, value)
|
||||
|
||||
|
||||
class Test__urlsafe_b64encode(unittest.TestCase):
|
||||
|
||||
DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
|
||||
DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
|
||||
|
||||
def test_valid_input_bytes(self):
|
||||
test_string = b'deadbeef'
|
||||
result = _urlsafe_b64encode(test_string)
|
||||
self.assertEqual(result, self.DEADBEEF_ENCODED)
|
||||
def test_valid_input_bytes(self):
|
||||
test_string = b'deadbeef'
|
||||
result = _urlsafe_b64encode(test_string)
|
||||
self.assertEqual(result, self.DEADBEEF_ENCODED)
|
||||
|
||||
def test_valid_input_unicode(self):
|
||||
test_string = u'deadbeef'
|
||||
result = _urlsafe_b64encode(test_string)
|
||||
self.assertEqual(result, self.DEADBEEF_ENCODED)
|
||||
def test_valid_input_unicode(self):
|
||||
test_string = u'deadbeef'
|
||||
result = _urlsafe_b64encode(test_string)
|
||||
self.assertEqual(result, self.DEADBEEF_ENCODED)
|
||||
|
||||
|
||||
class Test__urlsafe_b64decode(unittest.TestCase):
|
||||
|
||||
def test_valid_input_bytes(self):
|
||||
test_string = b'ZGVhZGJlZWY'
|
||||
result = _urlsafe_b64decode(test_string)
|
||||
self.assertEqual(result, b'deadbeef')
|
||||
def test_valid_input_bytes(self):
|
||||
test_string = b'ZGVhZGJlZWY'
|
||||
result = _urlsafe_b64decode(test_string)
|
||||
self.assertEqual(result, b'deadbeef')
|
||||
|
||||
def test_valid_input_unicode(self):
|
||||
test_string = b'ZGVhZGJlZWY'
|
||||
result = _urlsafe_b64decode(test_string)
|
||||
self.assertEqual(result, b'deadbeef')
|
||||
def test_valid_input_unicode(self):
|
||||
test_string = b'ZGVhZGJlZWY'
|
||||
result = _urlsafe_b64decode(test_string)
|
||||
self.assertEqual(result, b'deadbeef')
|
||||
|
||||
def test_bad_input(self):
|
||||
import binascii
|
||||
bad_string = b'+'
|
||||
self.assertRaises((TypeError, binascii.Error),
|
||||
def test_bad_input(self):
|
||||
import binascii
|
||||
bad_string = b'+'
|
||||
self.assertRaises((TypeError, binascii.Error),
|
||||
_urlsafe_b64decode, bad_string)
|
||||
|
||||
@@ -22,42 +22,42 @@ from oauth2client.crypt import PyCryptoVerifier
|
||||
|
||||
class TestPyCryptoVerifier(unittest.TestCase):
|
||||
|
||||
PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
|
||||
PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
|
||||
'data', 'publickey.pem')
|
||||
PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
|
||||
PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
|
||||
'data', 'privatekey.pem')
|
||||
|
||||
def _load_public_key_bytes(self):
|
||||
with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
|
||||
return fh.read()
|
||||
def _load_public_key_bytes(self):
|
||||
with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
|
||||
return fh.read()
|
||||
|
||||
def _load_private_key_bytes(self):
|
||||
with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
|
||||
return fh.read()
|
||||
def _load_private_key_bytes(self):
|
||||
with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
|
||||
return fh.read()
|
||||
|
||||
def test_verify_success(self):
|
||||
to_sign = b'foo'
|
||||
signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
|
||||
actual_signature = signer.sign(to_sign)
|
||||
def test_verify_success(self):
|
||||
to_sign = b'foo'
|
||||
signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
|
||||
actual_signature = signer.sign(to_sign)
|
||||
|
||||
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
|
||||
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
|
||||
is_x509_cert=True)
|
||||
self.assertTrue(verifier.verify(to_sign, actual_signature))
|
||||
self.assertTrue(verifier.verify(to_sign, actual_signature))
|
||||
|
||||
def test_verify_failure(self):
|
||||
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
|
||||
def test_verify_failure(self):
|
||||
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
|
||||
is_x509_cert=True)
|
||||
bad_signature = b''
|
||||
self.assertFalse(verifier.verify(b'foo', bad_signature))
|
||||
bad_signature = b''
|
||||
self.assertFalse(verifier.verify(b'foo', bad_signature))
|
||||
|
||||
def test_verify_bad_key(self):
|
||||
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
|
||||
def test_verify_bad_key(self):
|
||||
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
|
||||
is_x509_cert=True)
|
||||
bad_signature = b''
|
||||
self.assertFalse(verifier.verify(b'foo', bad_signature))
|
||||
bad_signature = b''
|
||||
self.assertFalse(verifier.verify(b'foo', bad_signature))
|
||||
|
||||
def test_from_string_unicode_key(self):
|
||||
public_key = self._load_public_key_bytes()
|
||||
public_key = public_key.decode('utf-8')
|
||||
verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
|
||||
self.assertTrue(isinstance(verifier, PyCryptoVerifier))
|
||||
def test_from_string_unicode_key(self):
|
||||
public_key = self._load_public_key_bytes()
|
||||
public_key = public_key.decode('utf-8')
|
||||
verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
|
||||
self.assertTrue(isinstance(verifier, PyCryptoVerifier))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,6 @@
|
||||
|
||||
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
||||
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from io import StringIO
|
||||
@@ -25,22 +24,22 @@ import httplib2
|
||||
|
||||
from oauth2client import clientsecrets
|
||||
|
||||
|
||||
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
|
||||
VALID_FILE = os.path.join(DATA_DIR, 'client_secrets.json')
|
||||
INVALID_FILE = os.path.join(DATA_DIR, 'unfilled_client_secrets.json')
|
||||
NONEXISTENT_FILE = os.path.join(__file__, '..', 'afilethatisntthere.json')
|
||||
|
||||
|
||||
class OAuth2CredentialsTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_validate_error(self):
|
||||
ERRORS = [
|
||||
def test_validate_error(self):
|
||||
ERRORS = [
|
||||
('{}', 'Invalid'),
|
||||
('{"foo": {}}', 'Unknown'),
|
||||
('{"web": {}}', 'Missing'),
|
||||
@@ -56,95 +55,95 @@ class OAuth2CredentialsTests(unittest.TestCase):
|
||||
}
|
||||
""", 'Property'),
|
||||
]
|
||||
for src, match in ERRORS:
|
||||
# Ensure that it is unicode
|
||||
try:
|
||||
src = src.decode('utf-8')
|
||||
except AttributeError:
|
||||
pass
|
||||
# Test load(s)
|
||||
try:
|
||||
clientsecrets.loads(src)
|
||||
self.fail(src + ' should not be a valid client_secrets file.')
|
||||
except clientsecrets.InvalidClientSecretsError as e:
|
||||
self.assertTrue(str(e).startswith(match))
|
||||
for src, match in ERRORS:
|
||||
# Ensure that it is unicode
|
||||
try:
|
||||
src = src.decode('utf-8')
|
||||
except AttributeError:
|
||||
pass
|
||||
# Test load(s)
|
||||
try:
|
||||
clientsecrets.loads(src)
|
||||
self.fail(src + ' should not be a valid client_secrets file.')
|
||||
except clientsecrets.InvalidClientSecretsError as e:
|
||||
self.assertTrue(str(e).startswith(match))
|
||||
|
||||
# Test loads(fp)
|
||||
try:
|
||||
fp = StringIO(src)
|
||||
clientsecrets.load(fp)
|
||||
self.fail(src + ' should not be a valid client_secrets file.')
|
||||
except clientsecrets.InvalidClientSecretsError as e:
|
||||
self.assertTrue(str(e).startswith(match))
|
||||
# Test loads(fp)
|
||||
try:
|
||||
fp = StringIO(src)
|
||||
clientsecrets.load(fp)
|
||||
self.fail(src + ' should not be a valid client_secrets file.')
|
||||
except clientsecrets.InvalidClientSecretsError as e:
|
||||
self.assertTrue(str(e).startswith(match))
|
||||
|
||||
def test_load_by_filename(self):
|
||||
try:
|
||||
clientsecrets._loadfile(NONEXISTENT_FILE)
|
||||
self.fail('should fail to load a missing client_secrets file.')
|
||||
except clientsecrets.InvalidClientSecretsError as e:
|
||||
self.assertTrue(str(e).startswith('File'))
|
||||
def test_load_by_filename(self):
|
||||
try:
|
||||
clientsecrets._loadfile(NONEXISTENT_FILE)
|
||||
self.fail('should fail to load a missing client_secrets file.')
|
||||
except clientsecrets.InvalidClientSecretsError as e:
|
||||
self.assertTrue(str(e).startswith('File'))
|
||||
|
||||
|
||||
class CachedClientsecretsTests(unittest.TestCase):
|
||||
|
||||
class CacheMock(object):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
self.last_get_ns = None
|
||||
self.last_set_ns = None
|
||||
class CacheMock(object):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
self.last_get_ns = None
|
||||
self.last_set_ns = None
|
||||
|
||||
def get(self, key, namespace=''):
|
||||
# ignoring namespace for easier testing
|
||||
self.last_get_ns = namespace
|
||||
return self.cache.get(key, None)
|
||||
def get(self, key, namespace=''):
|
||||
# ignoring namespace for easier testing
|
||||
self.last_get_ns = namespace
|
||||
return self.cache.get(key, None)
|
||||
|
||||
def set(self, key, value, namespace=''):
|
||||
# ignoring namespace for easier testing
|
||||
self.last_set_ns = namespace
|
||||
self.cache[key] = value
|
||||
def set(self, key, value, namespace=''):
|
||||
# ignoring namespace for easier testing
|
||||
self.last_set_ns = namespace
|
||||
self.cache[key] = value
|
||||
|
||||
def setUp(self):
|
||||
self.cache_mock = self.CacheMock()
|
||||
def setUp(self):
|
||||
self.cache_mock = self.CacheMock()
|
||||
|
||||
def test_cache_miss(self):
|
||||
client_type, client_info = clientsecrets.loadfile(
|
||||
def test_cache_miss(self):
|
||||
client_type, client_info = clientsecrets.loadfile(
|
||||
VALID_FILE, cache=self.cache_mock)
|
||||
self.assertEqual('web', client_type)
|
||||
self.assertEqual('foo_client_secret', client_info['client_secret'])
|
||||
self.assertEqual('web', client_type)
|
||||
self.assertEqual('foo_client_secret', client_info['client_secret'])
|
||||
|
||||
cached = self.cache_mock.cache[VALID_FILE]
|
||||
self.assertEqual({client_type: client_info}, cached)
|
||||
cached = self.cache_mock.cache[VALID_FILE]
|
||||
self.assertEqual({client_type: client_info}, cached)
|
||||
|
||||
# make sure we're using non-empty namespace
|
||||
ns = self.cache_mock.last_set_ns
|
||||
self.assertTrue(bool(ns))
|
||||
# make sure they're equal
|
||||
self.assertEqual(ns, self.cache_mock.last_get_ns)
|
||||
# make sure we're using non-empty namespace
|
||||
ns = self.cache_mock.last_set_ns
|
||||
self.assertTrue(bool(ns))
|
||||
# make sure they're equal
|
||||
self.assertEqual(ns, self.cache_mock.last_get_ns)
|
||||
|
||||
def test_cache_hit(self):
|
||||
self.cache_mock.cache[NONEXISTENT_FILE] = { 'web': 'secret info' }
|
||||
def test_cache_hit(self):
|
||||
self.cache_mock.cache[NONEXISTENT_FILE] = {'web': 'secret info'}
|
||||
|
||||
client_type, client_info = clientsecrets.loadfile(
|
||||
client_type, client_info = clientsecrets.loadfile(
|
||||
NONEXISTENT_FILE, cache=self.cache_mock)
|
||||
self.assertEqual('web', client_type)
|
||||
self.assertEqual('secret info', client_info)
|
||||
# make sure we didn't do any set() RPCs
|
||||
self.assertEqual(None, self.cache_mock.last_set_ns)
|
||||
self.assertEqual('web', client_type)
|
||||
self.assertEqual('secret info', client_info)
|
||||
# make sure we didn't do any set() RPCs
|
||||
self.assertEqual(None, self.cache_mock.last_set_ns)
|
||||
|
||||
def test_validation(self):
|
||||
try:
|
||||
clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
|
||||
self.fail('Expected InvalidClientSecretsError to be raised '
|
||||
def test_validation(self):
|
||||
try:
|
||||
clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
|
||||
self.fail('Expected InvalidClientSecretsError to be raised '
|
||||
'while loading %s' % INVALID_FILE)
|
||||
except clientsecrets.InvalidClientSecretsError:
|
||||
pass
|
||||
except clientsecrets.InvalidClientSecretsError:
|
||||
pass
|
||||
|
||||
def test_without_cache(self):
|
||||
# this also ensures loadfile() is backward compatible
|
||||
client_type, client_info = clientsecrets.loadfile(VALID_FILE)
|
||||
self.assertEqual('web', client_type)
|
||||
self.assertEqual('foo_client_secret', client_info['client_secret'])
|
||||
def test_without_cache(self):
|
||||
# this also ensures loadfile() is backward compatible
|
||||
client_type, client_info = clientsecrets.loadfile(VALID_FILE)
|
||||
self.assertEqual('web', client_type)
|
||||
self.assertEqual('foo_client_secret', client_info['client_secret'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -18,10 +18,10 @@ import sys
|
||||
import unittest
|
||||
|
||||
try:
|
||||
reload
|
||||
reload
|
||||
except NameError:
|
||||
# For Python3 (though importlib should be used, silly 3.3).
|
||||
from imp import reload
|
||||
# For Python3 (though importlib should be used, silly 3.3).
|
||||
from imp import reload
|
||||
|
||||
from oauth2client import _helpers
|
||||
from oauth2client.client import HAS_OPENSSL
|
||||
@@ -30,44 +30,44 @@ from oauth2client import crypt
|
||||
|
||||
|
||||
def datafile(filename):
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
|
||||
data = f.read()
|
||||
f.close()
|
||||
return data
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
|
||||
data = f.read()
|
||||
f.close()
|
||||
return data
|
||||
|
||||
|
||||
class Test_pkcs12_key_as_pem(unittest.TestCase):
|
||||
|
||||
def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
|
||||
def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
|
||||
private_key=None):
|
||||
private_key = private_key or datafile(private_key_file)
|
||||
return SignedJwtAssertionCredentials(
|
||||
private_key = private_key or datafile(private_key_file)
|
||||
return SignedJwtAssertionCredentials(
|
||||
'some_account@example.com',
|
||||
private_key,
|
||||
scope='read+write',
|
||||
sub='joe@example.org')
|
||||
|
||||
def _succeeds_helper(self, password=None):
|
||||
self.assertEqual(True, HAS_OPENSSL)
|
||||
def _succeeds_helper(self, password=None):
|
||||
self.assertEqual(True, HAS_OPENSSL)
|
||||
|
||||
credentials = self._make_signed_jwt_creds()
|
||||
if password is None:
|
||||
password = credentials.private_key_password
|
||||
pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
|
||||
pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
|
||||
pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
|
||||
alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
|
||||
self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
|
||||
credentials = self._make_signed_jwt_creds()
|
||||
if password is None:
|
||||
password = credentials.private_key_password
|
||||
pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
|
||||
pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
|
||||
pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
|
||||
alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
|
||||
self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
|
||||
|
||||
def test_succeeds(self):
|
||||
self._succeeds_helper()
|
||||
def test_succeeds(self):
|
||||
self._succeeds_helper()
|
||||
|
||||
def test_succeeds_with_unicode_password(self):
|
||||
password = u'notasecret'
|
||||
self._succeeds_helper(password)
|
||||
def test_succeeds_with_unicode_password(self):
|
||||
password = u'notasecret'
|
||||
self._succeeds_helper(password)
|
||||
|
||||
def test_with_nonsense_key(self):
|
||||
from OpenSSL import crypto
|
||||
credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
|
||||
self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
|
||||
def test_with_nonsense_key(self):
|
||||
from OpenSSL import crypto
|
||||
credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
|
||||
self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
|
||||
credentials.private_key, credentials.private_key_password)
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Tests for oauth2client.devshell."""
|
||||
|
||||
import os
|
||||
@@ -30,110 +29,110 @@ from oauth2client.devshell import NoDevshellServer
|
||||
|
||||
class _AuthReferenceServer(threading.Thread):
|
||||
|
||||
def __init__(self, response=None):
|
||||
super(_AuthReferenceServer, self).__init__(None)
|
||||
self.response = (response or
|
||||
def __init__(self, response=None):
|
||||
super(_AuthReferenceServer, self).__init__(None)
|
||||
self.response = (response or
|
||||
'["joe@example.com", "fooproj", "sometoken"]')
|
||||
|
||||
def __enter__(self):
|
||||
self.start_server()
|
||||
def __enter__(self):
|
||||
self.start_server()
|
||||
|
||||
def start_server(self):
|
||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._socket.bind(('localhost', 0))
|
||||
port = self._socket.getsockname()[1]
|
||||
os.environ[DEVSHELL_ENV] = str(port)
|
||||
self._socket.listen(0)
|
||||
self.daemon = True
|
||||
self.start()
|
||||
return self
|
||||
def start_server(self):
|
||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._socket.bind(('localhost', 0))
|
||||
port = self._socket.getsockname()[1]
|
||||
os.environ[DEVSHELL_ENV] = str(port)
|
||||
self._socket.listen(0)
|
||||
self.daemon = True
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, e_type, value, traceback):
|
||||
self.stop_server()
|
||||
def __exit__(self, e_type, value, traceback):
|
||||
self.stop_server()
|
||||
|
||||
def stop_server(self):
|
||||
del os.environ[DEVSHELL_ENV]
|
||||
self._socket.close()
|
||||
def stop_server(self):
|
||||
del os.environ[DEVSHELL_ENV]
|
||||
self._socket.close()
|
||||
|
||||
def run(self):
|
||||
s = None
|
||||
try:
|
||||
# Do not set the timeout on the socket, leave it in the blocking mode as
|
||||
# setting the timeout seems to cause spurious EAGAIN errors on OSX.
|
||||
self._socket.settimeout(None)
|
||||
def run(self):
|
||||
s = None
|
||||
try:
|
||||
# Do not set the timeout on the socket, leave it in the blocking mode as
|
||||
# setting the timeout seems to cause spurious EAGAIN errors on OSX.
|
||||
self._socket.settimeout(None)
|
||||
|
||||
s, unused_addr = self._socket.accept()
|
||||
resp_buffer = ''
|
||||
resp_1 = s.recv(6).decode()
|
||||
if '\n' not in resp_1:
|
||||
raise Exception('invalid request data')
|
||||
nstr, extra = resp_1.split('\n', 1)
|
||||
resp_buffer = extra
|
||||
n = int(nstr)
|
||||
to_read = n-len(extra)
|
||||
if to_read > 0:
|
||||
resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
|
||||
if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
|
||||
raise Exception('bad request')
|
||||
l = len(self.response)
|
||||
s.sendall(('%d\n%s' % (l, self.response)).encode())
|
||||
finally:
|
||||
if s:
|
||||
s.close()
|
||||
s, unused_addr = self._socket.accept()
|
||||
resp_buffer = ''
|
||||
resp_1 = s.recv(6).decode()
|
||||
if '\n' not in resp_1:
|
||||
raise Exception('invalid request data')
|
||||
nstr, extra = resp_1.split('\n', 1)
|
||||
resp_buffer = extra
|
||||
n = int(nstr)
|
||||
to_read = n - len(extra)
|
||||
if to_read > 0:
|
||||
resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
|
||||
if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
|
||||
raise Exception('bad request')
|
||||
l = len(self.response)
|
||||
s.sendall(('%d\n%s' % (l, self.response)).encode())
|
||||
finally:
|
||||
if s:
|
||||
s.close()
|
||||
|
||||
|
||||
class DevshellCredentialsTests(unittest.TestCase):
|
||||
|
||||
def test_signals_no_server(self):
|
||||
self.assertRaises(NoDevshellServer, DevshellCredentials)
|
||||
def test_signals_no_server(self):
|
||||
self.assertRaises(NoDevshellServer, DevshellCredentials)
|
||||
|
||||
def test_request_response(self):
|
||||
with _AuthReferenceServer():
|
||||
response = _SendRecv()
|
||||
self.assertEqual(response.user_email, 'joe@example.com')
|
||||
self.assertEqual(response.project_id, 'fooproj')
|
||||
self.assertEqual(response.access_token, 'sometoken')
|
||||
def test_request_response(self):
|
||||
with _AuthReferenceServer():
|
||||
response = _SendRecv()
|
||||
self.assertEqual(response.user_email, 'joe@example.com')
|
||||
self.assertEqual(response.project_id, 'fooproj')
|
||||
self.assertEqual(response.access_token, 'sometoken')
|
||||
|
||||
def test_no_refresh_token(self):
|
||||
with _AuthReferenceServer():
|
||||
creds = DevshellCredentials()
|
||||
self.assertEquals(None, creds.refresh_token)
|
||||
def test_no_refresh_token(self):
|
||||
with _AuthReferenceServer():
|
||||
creds = DevshellCredentials()
|
||||
self.assertEquals(None, creds.refresh_token)
|
||||
|
||||
def test_reads_credentials(self):
|
||||
with _AuthReferenceServer():
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual('joe@example.com', creds.user_email)
|
||||
self.assertEqual('fooproj', creds.project_id)
|
||||
self.assertEqual('sometoken', creds.access_token)
|
||||
def test_reads_credentials(self):
|
||||
with _AuthReferenceServer():
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual('joe@example.com', creds.user_email)
|
||||
self.assertEqual('fooproj', creds.project_id)
|
||||
self.assertEqual('sometoken', creds.access_token)
|
||||
|
||||
def test_handles_skipped_fields(self):
|
||||
with _AuthReferenceServer('["joe@example.com"]'):
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual('joe@example.com', creds.user_email)
|
||||
self.assertEqual(None, creds.project_id)
|
||||
self.assertEqual(None, creds.access_token)
|
||||
def test_handles_skipped_fields(self):
|
||||
with _AuthReferenceServer('["joe@example.com"]'):
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual('joe@example.com', creds.user_email)
|
||||
self.assertEqual(None, creds.project_id)
|
||||
self.assertEqual(None, creds.access_token)
|
||||
|
||||
def test_handles_tiny_response(self):
|
||||
with _AuthReferenceServer('[]'):
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual(None, creds.user_email)
|
||||
self.assertEqual(None, creds.project_id)
|
||||
self.assertEqual(None, creds.access_token)
|
||||
def test_handles_tiny_response(self):
|
||||
with _AuthReferenceServer('[]'):
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual(None, creds.user_email)
|
||||
self.assertEqual(None, creds.project_id)
|
||||
self.assertEqual(None, creds.access_token)
|
||||
|
||||
def test_handles_ignores_extra_fields(self):
|
||||
with _AuthReferenceServer(
|
||||
def test_handles_ignores_extra_fields(self):
|
||||
with _AuthReferenceServer(
|
||||
'["joe@example.com", "fooproj", "sometoken", "extra"]'):
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual('joe@example.com', creds.user_email)
|
||||
self.assertEqual('fooproj', creds.project_id)
|
||||
self.assertEqual('sometoken', creds.access_token)
|
||||
creds = DevshellCredentials()
|
||||
self.assertEqual('joe@example.com', creds.user_email)
|
||||
self.assertEqual('fooproj', creds.project_id)
|
||||
self.assertEqual('sometoken', creds.access_token)
|
||||
|
||||
def test_refuses_to_save_to_well_known_file(self):
|
||||
ORIGINAL_ISDIR = os.path.isdir
|
||||
try:
|
||||
os.path.isdir = lambda path: True
|
||||
with _AuthReferenceServer():
|
||||
creds = DevshellCredentials()
|
||||
self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
|
||||
finally:
|
||||
os.path.isdir = ORIGINAL_ISDIR
|
||||
def test_refuses_to_save_to_well_known_file(self):
|
||||
ORIGINAL_ISDIR = os.path.isdir
|
||||
try:
|
||||
os.path.isdir = lambda path: True
|
||||
with _AuthReferenceServer():
|
||||
creds = DevshellCredentials()
|
||||
self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
|
||||
finally:
|
||||
os.path.isdir = ORIGINAL_ISDIR
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Discovery document tests
|
||||
|
||||
Unit tests for objects created from discovery documents.
|
||||
@@ -31,10 +30,10 @@ import unittest
|
||||
|
||||
# Ensure that if app engine is available, we use the correct django from it
|
||||
try:
|
||||
from google.appengine.dist import use_library
|
||||
use_library('django', '1.5')
|
||||
from google.appengine.dist import use_library
|
||||
use_library('django', '1.5')
|
||||
except ImportError:
|
||||
pass
|
||||
pass
|
||||
|
||||
from oauth2client.client import Credentials
|
||||
from oauth2client.client import Flow
|
||||
@@ -51,39 +50,39 @@ from oauth2client.django_orm import FlowField
|
||||
|
||||
|
||||
class TestCredentialsField(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.field = CredentialsField()
|
||||
self.credentials = Credentials()
|
||||
self.pickle = base64.b64encode(pickle.dumps(self.credentials))
|
||||
def setUp(self):
|
||||
self.field = CredentialsField()
|
||||
self.credentials = Credentials()
|
||||
self.pickle = base64.b64encode(pickle.dumps(self.credentials))
|
||||
|
||||
def test_field_is_text(self):
|
||||
self.assertEquals(self.field.get_internal_type(), 'TextField')
|
||||
def test_field_is_text(self):
|
||||
self.assertEquals(self.field.get_internal_type(), 'TextField')
|
||||
|
||||
def test_field_unpickled(self):
|
||||
self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
|
||||
def test_field_unpickled(self):
|
||||
self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
|
||||
|
||||
def test_field_pickled(self):
|
||||
prep_value = self.field.get_db_prep_value(self.credentials,
|
||||
def test_field_pickled(self):
|
||||
prep_value = self.field.get_db_prep_value(self.credentials,
|
||||
connection=None)
|
||||
self.assertEqual(prep_value, self.pickle)
|
||||
self.assertEqual(prep_value, self.pickle)
|
||||
|
||||
|
||||
class TestFlowField(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.field = FlowField()
|
||||
self.flow = Flow()
|
||||
self.pickle = base64.b64encode(pickle.dumps(self.flow))
|
||||
def setUp(self):
|
||||
self.field = FlowField()
|
||||
self.flow = Flow()
|
||||
self.pickle = base64.b64encode(pickle.dumps(self.flow))
|
||||
|
||||
def test_field_is_text(self):
|
||||
self.assertEquals(self.field.get_internal_type(), 'TextField')
|
||||
def test_field_is_text(self):
|
||||
self.assertEquals(self.field.get_internal_type(), 'TextField')
|
||||
|
||||
def test_field_unpickled(self):
|
||||
self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
|
||||
def test_field_unpickled(self):
|
||||
self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
|
||||
|
||||
def test_field_pickled(self):
|
||||
prep_value = self.field.get_db_prep_value(self.flow, connection=None)
|
||||
self.assertEqual(prep_value, self.pickle)
|
||||
def test_field_pickled(self):
|
||||
prep_value = self.field.get_db_prep_value(self.flow, connection=None)
|
||||
self.assertEqual(prep_value, self.pickle)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Oauth2client.file tests
|
||||
|
||||
Unit tests for oauth2client.file
|
||||
@@ -42,363 +41,362 @@ from oauth2client.client import AccessTokenCredentials
|
||||
from oauth2client.client import OAuth2Credentials
|
||||
from six.moves import http_client
|
||||
try:
|
||||
# Python2
|
||||
from future_builtins import oct
|
||||
# Python2
|
||||
from future_builtins import oct
|
||||
except:
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
FILENAME = tempfile.mktemp('oauth2client_test.data')
|
||||
|
||||
|
||||
class OAuth2ClientFileTests(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.unlink(FILENAME)
|
||||
except OSError:
|
||||
pass
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.unlink(FILENAME)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
try:
|
||||
os.unlink(FILENAME)
|
||||
except OSError:
|
||||
pass
|
||||
def setUp(self):
|
||||
try:
|
||||
os.unlink(FILENAME)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def create_test_credentials(self, client_id='some_client_id',
|
||||
def create_test_credentials(self, client_id='some_client_id',
|
||||
expiration=None):
|
||||
access_token = 'foo'
|
||||
client_secret = 'cOuDdkfjxxnv+'
|
||||
refresh_token = '1/0/a.df219fjls0'
|
||||
token_expiry = expiration or datetime.datetime.utcnow()
|
||||
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
|
||||
user_agent = 'refresh_checker/1.0'
|
||||
access_token = 'foo'
|
||||
client_secret = 'cOuDdkfjxxnv+'
|
||||
refresh_token = '1/0/a.df219fjls0'
|
||||
token_expiry = expiration or datetime.datetime.utcnow()
|
||||
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
|
||||
user_agent = 'refresh_checker/1.0'
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
credentials = OAuth2Credentials(
|
||||
access_token, client_id, client_secret,
|
||||
refresh_token, token_expiry, token_uri,
|
||||
user_agent)
|
||||
return credentials
|
||||
return credentials
|
||||
|
||||
def test_non_existent_file_storage(self):
|
||||
s = file.Storage(FILENAME)
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
def test_non_existent_file_storage(self):
|
||||
s = file.Storage(FILENAME)
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
|
||||
def test_no_sym_link_credentials(self):
|
||||
if hasattr(os, 'symlink'):
|
||||
SYMFILENAME = FILENAME + '.sym'
|
||||
os.symlink(FILENAME, SYMFILENAME)
|
||||
s = file.Storage(SYMFILENAME)
|
||||
try:
|
||||
s.get()
|
||||
self.fail('Should have raised an exception.')
|
||||
except file.CredentialsFileSymbolicLinkError:
|
||||
pass
|
||||
finally:
|
||||
os.unlink(SYMFILENAME)
|
||||
def test_no_sym_link_credentials(self):
|
||||
if hasattr(os, 'symlink'):
|
||||
SYMFILENAME = FILENAME + '.sym'
|
||||
os.symlink(FILENAME, SYMFILENAME)
|
||||
s = file.Storage(SYMFILENAME)
|
||||
try:
|
||||
s.get()
|
||||
self.fail('Should have raised an exception.')
|
||||
except file.CredentialsFileSymbolicLinkError:
|
||||
pass
|
||||
finally:
|
||||
os.unlink(SYMFILENAME)
|
||||
|
||||
def test_pickle_and_json_interop(self):
|
||||
# Write a file with a pickled OAuth2Credentials.
|
||||
credentials = self.create_test_credentials()
|
||||
def test_pickle_and_json_interop(self):
|
||||
# Write a file with a pickled OAuth2Credentials.
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
f = open(FILENAME, 'wb')
|
||||
pickle.dump(credentials, f)
|
||||
f.close()
|
||||
f = open(FILENAME, 'wb')
|
||||
pickle.dump(credentials, f)
|
||||
f.close()
|
||||
|
||||
# Storage should be not be able to read that object, as the capability to
|
||||
# read and write credentials as pickled objects has been removed.
|
||||
s = file.Storage(FILENAME)
|
||||
read_credentials = s.get()
|
||||
self.assertEquals(None, read_credentials)
|
||||
# Storage should be not be able to read that object, as the capability to
|
||||
# read and write credentials as pickled objects has been removed.
|
||||
s = file.Storage(FILENAME)
|
||||
read_credentials = s.get()
|
||||
self.assertEquals(None, read_credentials)
|
||||
|
||||
# Now write it back out and confirm it has been rewritten as JSON
|
||||
s.put(credentials)
|
||||
with open(FILENAME) as f:
|
||||
data = json.load(f)
|
||||
# Now write it back out and confirm it has been rewritten as JSON
|
||||
s.put(credentials)
|
||||
with open(FILENAME) as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.assertEquals(data['access_token'], 'foo')
|
||||
self.assertEquals(data['_class'], 'OAuth2Credentials')
|
||||
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
|
||||
self.assertEquals(data['access_token'], 'foo')
|
||||
self.assertEquals(data['_class'], 'OAuth2Credentials')
|
||||
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
|
||||
|
||||
def test_token_refresh_store_expired(self):
|
||||
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
def test_token_refresh_store_expired(self):
|
||||
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
|
||||
access_token = '1/3w'
|
||||
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
access_token = '1/3w'
|
||||
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
|
||||
])
|
||||
])
|
||||
|
||||
credentials._refresh(http.request)
|
||||
self.assertEquals(credentials.access_token, access_token)
|
||||
credentials._refresh(http.request)
|
||||
self.assertEquals(credentials.access_token, access_token)
|
||||
|
||||
def test_token_refresh_store_expires_soon(self):
|
||||
# Tests the case where an access token that is valid when it is read from
|
||||
# the store expires before the original request succeeds.
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
def test_token_refresh_store_expires_soon(self):
|
||||
# Tests the case where an access token that is valid when it is read from
|
||||
# the store expires before the original request succeeds.
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
|
||||
access_token = '1/3w'
|
||||
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
access_token = '1/3w'
|
||||
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
|
||||
({'status': str(http_client.OK)},
|
||||
json.dumps(token_response).encode('utf-8')),
|
||||
({'status': str(http_client.OK)},
|
||||
b'Valid response to original request')
|
||||
])
|
||||
])
|
||||
|
||||
credentials.authorize(http)
|
||||
http.request('https://example.com')
|
||||
self.assertEqual(credentials.access_token, access_token)
|
||||
credentials.authorize(http)
|
||||
http.request('https://example.com')
|
||||
self.assertEqual(credentials.access_token, access_token)
|
||||
|
||||
def test_token_refresh_good_store(self):
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
def test_token_refresh_good_store(self):
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
|
||||
credentials._refresh(lambda x: x)
|
||||
self.assertEquals(credentials.access_token, 'bar')
|
||||
credentials._refresh(lambda x: x)
|
||||
self.assertEquals(credentials.access_token, 'bar')
|
||||
|
||||
def test_token_refresh_stream_body(self):
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
def test_token_refresh_stream_body(self):
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
|
||||
valid_access_token = '1/3w'
|
||||
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
valid_access_token = '1/3w'
|
||||
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
|
||||
({'status': str(http_client.OK)},
|
||||
json.dumps(token_response).encode('utf-8')),
|
||||
({'status': str(http_client.OK)}, 'echo_request_body')
|
||||
])
|
||||
])
|
||||
|
||||
body = six.StringIO('streaming body')
|
||||
body = six.StringIO('streaming body')
|
||||
|
||||
credentials.authorize(http)
|
||||
_, content = http.request('https://example.com', body=body)
|
||||
self.assertEqual(content, 'streaming body')
|
||||
self.assertEqual(credentials.access_token, valid_access_token)
|
||||
credentials.authorize(http)
|
||||
_, content = http.request('https://example.com', body=body)
|
||||
self.assertEqual(content, 'streaming body')
|
||||
self.assertEqual(credentials.access_token, valid_access_token)
|
||||
|
||||
def test_credentials_delete(self):
|
||||
credentials = self.create_test_credentials()
|
||||
def test_credentials_delete(self):
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
self.assertNotEquals(None, credentials)
|
||||
s.delete()
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
self.assertNotEquals(None, credentials)
|
||||
s.delete()
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
|
||||
def test_access_token_credentials(self):
|
||||
access_token = 'foo'
|
||||
user_agent = 'refresh_checker/1.0'
|
||||
def test_access_token_credentials(self):
|
||||
access_token = 'foo'
|
||||
user_agent = 'refresh_checker/1.0'
|
||||
|
||||
credentials = AccessTokenCredentials(access_token, user_agent)
|
||||
credentials = AccessTokenCredentials(access_token, user_agent)
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
credentials = s.put(credentials)
|
||||
credentials = s.get()
|
||||
s = file.Storage(FILENAME)
|
||||
credentials = s.put(credentials)
|
||||
credentials = s.get()
|
||||
|
||||
self.assertNotEquals(None, credentials)
|
||||
self.assertEquals('foo', credentials.access_token)
|
||||
mode = os.stat(FILENAME).st_mode
|
||||
self.assertNotEquals(None, credentials)
|
||||
self.assertEquals('foo', credentials.access_token)
|
||||
mode = os.stat(FILENAME).st_mode
|
||||
|
||||
if os.name == 'posix':
|
||||
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
|
||||
if os.name == 'posix':
|
||||
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
|
||||
|
||||
def test_read_only_file_fail_lock(self):
|
||||
credentials = self.create_test_credentials()
|
||||
def test_read_only_file_fail_lock(self):
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
open(FILENAME, 'a+b').close()
|
||||
os.chmod(FILENAME, 0o400)
|
||||
open(FILENAME, 'a+b').close()
|
||||
os.chmod(FILENAME, 0o400)
|
||||
|
||||
store = multistore_file.get_credential_storage(
|
||||
store = multistore_file.get_credential_storage(
|
||||
FILENAME,
|
||||
credentials.client_id,
|
||||
credentials.user_agent,
|
||||
['some-scope', 'some-other-scope'])
|
||||
|
||||
store.put(credentials)
|
||||
if os.name == 'posix':
|
||||
self.assertTrue(store._multistore._read_only)
|
||||
os.chmod(FILENAME, 0o600)
|
||||
store.put(credentials)
|
||||
if os.name == 'posix':
|
||||
self.assertTrue(store._multistore._read_only)
|
||||
os.chmod(FILENAME, 0o600)
|
||||
|
||||
def test_multistore_no_symbolic_link_files(self):
|
||||
if hasattr(os, 'symlink'):
|
||||
SYMFILENAME = FILENAME + 'sym'
|
||||
os.symlink(FILENAME, SYMFILENAME)
|
||||
store = multistore_file.get_credential_storage(
|
||||
def test_multistore_no_symbolic_link_files(self):
|
||||
if hasattr(os, 'symlink'):
|
||||
SYMFILENAME = FILENAME + 'sym'
|
||||
os.symlink(FILENAME, SYMFILENAME)
|
||||
store = multistore_file.get_credential_storage(
|
||||
SYMFILENAME,
|
||||
'some_client_id',
|
||||
'user-agent/1.0',
|
||||
['some-scope', 'some-other-scope'])
|
||||
try:
|
||||
store.get()
|
||||
self.fail('Should have raised an exception.')
|
||||
except locked_file.CredentialsFileSymbolicLinkError:
|
||||
pass
|
||||
finally:
|
||||
os.unlink(SYMFILENAME)
|
||||
try:
|
||||
store.get()
|
||||
self.fail('Should have raised an exception.')
|
||||
except locked_file.CredentialsFileSymbolicLinkError:
|
||||
pass
|
||||
finally:
|
||||
os.unlink(SYMFILENAME)
|
||||
|
||||
def test_multistore_non_existent_file(self):
|
||||
store = multistore_file.get_credential_storage(
|
||||
def test_multistore_non_existent_file(self):
|
||||
store = multistore_file.get_credential_storage(
|
||||
FILENAME,
|
||||
'some_client_id',
|
||||
'user-agent/1.0',
|
||||
['some-scope', 'some-other-scope'])
|
||||
|
||||
credentials = store.get()
|
||||
self.assertEquals(None, credentials)
|
||||
credentials = store.get()
|
||||
self.assertEquals(None, credentials)
|
||||
|
||||
def test_multistore_file(self):
|
||||
credentials = self.create_test_credentials()
|
||||
def test_multistore_file(self):
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
store = multistore_file.get_credential_storage(
|
||||
store = multistore_file.get_credential_storage(
|
||||
FILENAME,
|
||||
credentials.client_id,
|
||||
credentials.user_agent,
|
||||
['some-scope', 'some-other-scope'])
|
||||
|
||||
store.put(credentials)
|
||||
credentials = store.get()
|
||||
store.put(credentials)
|
||||
credentials = store.get()
|
||||
|
||||
self.assertNotEquals(None, credentials)
|
||||
self.assertEquals('foo', credentials.access_token)
|
||||
self.assertNotEquals(None, credentials)
|
||||
self.assertEquals('foo', credentials.access_token)
|
||||
|
||||
store.delete()
|
||||
credentials = store.get()
|
||||
store.delete()
|
||||
credentials = store.get()
|
||||
|
||||
self.assertEquals(None, credentials)
|
||||
self.assertEquals(None, credentials)
|
||||
|
||||
if os.name == 'posix':
|
||||
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
|
||||
if os.name == 'posix':
|
||||
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
|
||||
|
||||
def test_multistore_file_custom_key(self):
|
||||
credentials = self.create_test_credentials()
|
||||
def test_multistore_file_custom_key(self):
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
custom_key = {'myapp': 'testing', 'clientid': 'some client'}
|
||||
store = multistore_file.get_credential_storage_custom_key(
|
||||
custom_key = {'myapp': 'testing', 'clientid': 'some client'}
|
||||
store = multistore_file.get_credential_storage_custom_key(
|
||||
FILENAME, custom_key)
|
||||
|
||||
store.put(credentials)
|
||||
stored_credentials = store.get()
|
||||
store.put(credentials)
|
||||
stored_credentials = store.get()
|
||||
|
||||
self.assertNotEquals(None, stored_credentials)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
self.assertNotEquals(None, stored_credentials)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
|
||||
store.delete()
|
||||
stored_credentials = store.get()
|
||||
store.delete()
|
||||
stored_credentials = store.get()
|
||||
|
||||
self.assertEquals(None, stored_credentials)
|
||||
self.assertEquals(None, stored_credentials)
|
||||
|
||||
def test_multistore_file_custom_string_key(self):
|
||||
credentials = self.create_test_credentials()
|
||||
def test_multistore_file_custom_string_key(self):
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
# store with string key
|
||||
store = multistore_file.get_credential_storage_custom_string_key(
|
||||
# store with string key
|
||||
store = multistore_file.get_credential_storage_custom_string_key(
|
||||
FILENAME, 'mykey')
|
||||
|
||||
store.put(credentials)
|
||||
stored_credentials = store.get()
|
||||
store.put(credentials)
|
||||
stored_credentials = store.get()
|
||||
|
||||
self.assertNotEquals(None, stored_credentials)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
self.assertNotEquals(None, stored_credentials)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
|
||||
# try retrieving with a dictionary
|
||||
store_dict = multistore_file.get_credential_storage_custom_string_key(
|
||||
# try retrieving with a dictionary
|
||||
store_dict = multistore_file.get_credential_storage_custom_string_key(
|
||||
FILENAME, {'key': 'mykey'})
|
||||
stored_credentials = store.get()
|
||||
self.assertNotEquals(None, stored_credentials)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
stored_credentials = store.get()
|
||||
self.assertNotEquals(None, stored_credentials)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
|
||||
store.delete()
|
||||
stored_credentials = store.get()
|
||||
store.delete()
|
||||
stored_credentials = store.get()
|
||||
|
||||
self.assertEquals(None, stored_credentials)
|
||||
self.assertEquals(None, stored_credentials)
|
||||
|
||||
def test_multistore_file_backwards_compatibility(self):
|
||||
credentials = self.create_test_credentials()
|
||||
scopes = ['scope1', 'scope2']
|
||||
def test_multistore_file_backwards_compatibility(self):
|
||||
credentials = self.create_test_credentials()
|
||||
scopes = ['scope1', 'scope2']
|
||||
|
||||
# store the credentials using the legacy key method
|
||||
store = multistore_file.get_credential_storage(
|
||||
# store the credentials using the legacy key method
|
||||
store = multistore_file.get_credential_storage(
|
||||
FILENAME, 'client_id', 'user_agent', scopes)
|
||||
store.put(credentials)
|
||||
store.put(credentials)
|
||||
|
||||
# retrieve the credentials using a custom key that matches the legacy key
|
||||
key = {'clientId': 'client_id', 'userAgent': 'user_agent',
|
||||
# retrieve the credentials using a custom key that matches the legacy key
|
||||
key = {'clientId': 'client_id', 'userAgent': 'user_agent',
|
||||
'scope': util.scopes_to_string(scopes)}
|
||||
store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
|
||||
stored_credentials = store.get()
|
||||
store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
|
||||
stored_credentials = store.get()
|
||||
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
self.assertEqual(credentials.access_token, stored_credentials.access_token)
|
||||
|
||||
def test_multistore_file_get_all_keys(self):
|
||||
# start with no keys
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals([], keys)
|
||||
def test_multistore_file_get_all_keys(self):
|
||||
# start with no keys
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals([], keys)
|
||||
|
||||
# store credentials
|
||||
credentials = self.create_test_credentials(client_id='client1')
|
||||
custom_key = {'myapp': 'testing', 'clientid': 'client1'}
|
||||
store1 = multistore_file.get_credential_storage_custom_key(
|
||||
# store credentials
|
||||
credentials = self.create_test_credentials(client_id='client1')
|
||||
custom_key = {'myapp': 'testing', 'clientid': 'client1'}
|
||||
store1 = multistore_file.get_credential_storage_custom_key(
|
||||
FILENAME, custom_key)
|
||||
store1.put(credentials)
|
||||
store1.put(credentials)
|
||||
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals([custom_key], keys)
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals([custom_key], keys)
|
||||
|
||||
# store more credentials
|
||||
credentials = self.create_test_credentials(client_id='client2')
|
||||
string_key = 'string_key'
|
||||
store2 = multistore_file.get_credential_storage_custom_string_key(
|
||||
# store more credentials
|
||||
credentials = self.create_test_credentials(client_id='client2')
|
||||
string_key = 'string_key'
|
||||
store2 = multistore_file.get_credential_storage_custom_string_key(
|
||||
FILENAME, string_key)
|
||||
store2.put(credentials)
|
||||
store2.put(credentials)
|
||||
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals(2, len(keys))
|
||||
self.assertTrue(custom_key in keys)
|
||||
self.assertTrue({'key': string_key} in keys)
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals(2, len(keys))
|
||||
self.assertTrue(custom_key in keys)
|
||||
self.assertTrue({'key': string_key} in keys)
|
||||
|
||||
# back to no keys
|
||||
store1.delete()
|
||||
store2.delete()
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals([], keys)
|
||||
# back to no keys
|
||||
store1.delete()
|
||||
store2.delete()
|
||||
keys = multistore_file.get_all_credential_keys(FILENAME)
|
||||
self.assertEquals([], keys)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Unit tests for the Flask utilities"""
|
||||
|
||||
__author__ = 'jonwayne@google.com (Jon Wayne Parrott)'
|
||||
@@ -35,6 +34,7 @@ from oauth2client.client import OAuth2Credentials
|
||||
|
||||
class Http2Mock(object):
|
||||
"""Mock httplib2.Http for code exchange / refresh"""
|
||||
|
||||
def __init__(self, status=httplib.OK, **kwargs):
|
||||
self.status = status
|
||||
self.content = {
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Tests for oauth2client.gce.
|
||||
|
||||
Unit tests for oauth2client.gce.
|
||||
@@ -36,86 +35,86 @@ from oauth2client.gce import AppAssertionCredentials
|
||||
|
||||
class AssertionCredentialsTests(unittest.TestCase):
|
||||
|
||||
def _refresh_success_helper(self, bytes_response=False):
|
||||
access_token = u'this-is-a-token'
|
||||
return_val = json.dumps({u'accessToken': access_token})
|
||||
if bytes_response:
|
||||
return_val = _to_bytes(return_val)
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(
|
||||
def _refresh_success_helper(self, bytes_response=False):
|
||||
access_token = u'this-is-a-token'
|
||||
return_val = json.dumps({u'accessToken': access_token})
|
||||
if bytes_response:
|
||||
return_val = _to_bytes(return_val)
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(
|
||||
return_value=(mock.Mock(status=200), return_val))
|
||||
|
||||
scopes = ['http://example.com/a', 'http://example.com/b']
|
||||
credentials = AppAssertionCredentials(scope=scopes)
|
||||
self.assertEquals(None, credentials.access_token)
|
||||
credentials.refresh(http)
|
||||
self.assertEquals(access_token, credentials.access_token)
|
||||
scopes = ['http://example.com/a', 'http://example.com/b']
|
||||
credentials = AppAssertionCredentials(scope=scopes)
|
||||
self.assertEquals(None, credentials.access_token)
|
||||
credentials.refresh(http)
|
||||
self.assertEquals(access_token, credentials.access_token)
|
||||
|
||||
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
|
||||
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
|
||||
'service-accounts/default/acquire')
|
||||
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
|
||||
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
|
||||
http.request.assert_called_once_with(request_uri)
|
||||
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
|
||||
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
|
||||
http.request.assert_called_once_with(request_uri)
|
||||
|
||||
def test_refresh_success(self):
|
||||
self._refresh_success_helper(bytes_response=False)
|
||||
def test_refresh_success(self):
|
||||
self._refresh_success_helper(bytes_response=False)
|
||||
|
||||
def test_refresh_success_bytes(self):
|
||||
self._refresh_success_helper(bytes_response=True)
|
||||
def test_refresh_success_bytes(self):
|
||||
self._refresh_success_helper(bytes_response=True)
|
||||
|
||||
def test_fail_refresh(self):
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
|
||||
def test_fail_refresh(self):
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
|
||||
|
||||
c = AppAssertionCredentials(scope=['http://example.com/a',
|
||||
c = AppAssertionCredentials(scope=['http://example.com/a',
|
||||
'http://example.com/b'])
|
||||
self.assertRaises(AccessTokenRefreshError, c.refresh, http)
|
||||
self.assertRaises(AccessTokenRefreshError, c.refresh, http)
|
||||
|
||||
def test_to_from_json(self):
|
||||
c = AppAssertionCredentials(scope=['http://example.com/a',
|
||||
def test_to_from_json(self):
|
||||
c = AppAssertionCredentials(scope=['http://example.com/a',
|
||||
'http://example.com/b'])
|
||||
json = c.to_json()
|
||||
c2 = Credentials.new_from_json(json)
|
||||
json = c.to_json()
|
||||
c2 = Credentials.new_from_json(json)
|
||||
|
||||
self.assertEqual(c.access_token, c2.access_token)
|
||||
self.assertEqual(c.access_token, c2.access_token)
|
||||
|
||||
def test_create_scoped_required_without_scopes(self):
|
||||
credentials = AppAssertionCredentials([])
|
||||
self.assertTrue(credentials.create_scoped_required())
|
||||
def test_create_scoped_required_without_scopes(self):
|
||||
credentials = AppAssertionCredentials([])
|
||||
self.assertTrue(credentials.create_scoped_required())
|
||||
|
||||
def test_create_scoped_required_with_scopes(self):
|
||||
credentials = AppAssertionCredentials(['dummy_scope'])
|
||||
self.assertFalse(credentials.create_scoped_required())
|
||||
def test_create_scoped_required_with_scopes(self):
|
||||
credentials = AppAssertionCredentials(['dummy_scope'])
|
||||
self.assertFalse(credentials.create_scoped_required())
|
||||
|
||||
def test_create_scoped(self):
|
||||
credentials = AppAssertionCredentials([])
|
||||
new_credentials = credentials.create_scoped(['dummy_scope'])
|
||||
self.assertNotEqual(credentials, new_credentials)
|
||||
self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
|
||||
self.assertEqual('dummy_scope', new_credentials.scope)
|
||||
def test_create_scoped(self):
|
||||
credentials = AppAssertionCredentials([])
|
||||
new_credentials = credentials.create_scoped(['dummy_scope'])
|
||||
self.assertNotEqual(credentials, new_credentials)
|
||||
self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
|
||||
self.assertEqual('dummy_scope', new_credentials.scope)
|
||||
|
||||
def test_get_access_token(self):
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(
|
||||
def test_get_access_token(self):
|
||||
http = mock.MagicMock()
|
||||
http.request = mock.MagicMock(
|
||||
return_value=(mock.Mock(status=200),
|
||||
'{"accessToken": "this-is-a-token"}'))
|
||||
|
||||
credentials = AppAssertionCredentials(['dummy_scope'])
|
||||
token = credentials.get_access_token(http=http)
|
||||
self.assertEqual('this-is-a-token', token.access_token)
|
||||
self.assertEqual(None, token.expires_in)
|
||||
credentials = AppAssertionCredentials(['dummy_scope'])
|
||||
token = credentials.get_access_token(http=http)
|
||||
self.assertEqual('this-is-a-token', token.access_token)
|
||||
self.assertEqual(None, token.expires_in)
|
||||
|
||||
http.request.assert_called_once_with(
|
||||
http.request.assert_called_once_with(
|
||||
'http://metadata.google.internal/0.1/meta-data/service-accounts/'
|
||||
'default/acquire?scope=dummy_scope')
|
||||
|
||||
def test_save_to_well_known_file(self):
|
||||
import os
|
||||
ORIGINAL_ISDIR = os.path.isdir
|
||||
try:
|
||||
os.path.isdir = lambda path: True
|
||||
credentials = AppAssertionCredentials([])
|
||||
self.assertRaises(NotImplementedError, save_to_well_known_file,
|
||||
def test_save_to_well_known_file(self):
|
||||
import os
|
||||
ORIGINAL_ISDIR = os.path.isdir
|
||||
try:
|
||||
os.path.isdir = lambda path: True
|
||||
credentials = AppAssertionCredentials([])
|
||||
self.assertRaises(NotImplementedError, save_to_well_known_file,
|
||||
credentials)
|
||||
finally:
|
||||
os.path.isdir = ORIGINAL_ISDIR
|
||||
finally:
|
||||
os.path.isdir = ORIGINAL_ISDIR
|
||||
|
||||
@@ -19,9 +19,9 @@ import unittest
|
||||
|
||||
class ImportTest(unittest.TestCase):
|
||||
|
||||
def test_tools_import(self):
|
||||
import oauth2client.tools
|
||||
def test_tools_import(self):
|
||||
import oauth2client.tools
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Oauth2client tests
|
||||
|
||||
Unit tests for oauth2client.
|
||||
@@ -42,295 +41,295 @@ from oauth2client.file import Storage
|
||||
|
||||
|
||||
def datafile(filename):
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
|
||||
data = f.read()
|
||||
f.close()
|
||||
return data
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
|
||||
data = f.read()
|
||||
f.close()
|
||||
return data
|
||||
|
||||
|
||||
class CryptTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.format = 'p12'
|
||||
self.signer = crypt.OpenSSLSigner
|
||||
self.verifier = crypt.OpenSSLVerifier
|
||||
def setUp(self):
|
||||
self.format = 'p12'
|
||||
self.signer = crypt.OpenSSLSigner
|
||||
self.verifier = crypt.OpenSSLVerifier
|
||||
|
||||
def test_sign_and_verify(self):
|
||||
self._check_sign_and_verify('privatekey.%s' % self.format)
|
||||
def test_sign_and_verify(self):
|
||||
self._check_sign_and_verify('privatekey.%s' % self.format)
|
||||
|
||||
def test_sign_and_verify_from_converted_pkcs12(self):
|
||||
# Tests that following instructions to convert from PKCS12 to PEM works.
|
||||
if self.format == 'pem':
|
||||
self._check_sign_and_verify('pem_from_pkcs12.pem')
|
||||
def test_sign_and_verify_from_converted_pkcs12(self):
|
||||
# Tests that following instructions to convert from PKCS12 to PEM works.
|
||||
if self.format == 'pem':
|
||||
self._check_sign_and_verify('pem_from_pkcs12.pem')
|
||||
|
||||
def _check_sign_and_verify(self, private_key_file):
|
||||
private_key = datafile(private_key_file)
|
||||
public_key = datafile('publickey.pem')
|
||||
def _check_sign_and_verify(self, private_key_file):
|
||||
private_key = datafile(private_key_file)
|
||||
public_key = datafile('publickey.pem')
|
||||
|
||||
# We pass in a non-bytes password to make sure all branches
|
||||
# are traversed in tests.
|
||||
signer = self.signer.from_string(private_key,
|
||||
# We pass in a non-bytes password to make sure all branches
|
||||
# are traversed in tests.
|
||||
signer = self.signer.from_string(private_key,
|
||||
password=u'notasecret')
|
||||
signature = signer.sign('foo')
|
||||
signature = signer.sign('foo')
|
||||
|
||||
verifier = self.verifier.from_string(public_key, True)
|
||||
self.assertTrue(verifier.verify(b'foo', signature))
|
||||
verifier = self.verifier.from_string(public_key, True)
|
||||
self.assertTrue(verifier.verify(b'foo', signature))
|
||||
|
||||
self.assertFalse(verifier.verify(b'bar', signature))
|
||||
self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
|
||||
self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
|
||||
self.assertFalse(verifier.verify(b'bar', signature))
|
||||
self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
|
||||
self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
|
||||
|
||||
def _check_jwt_failure(self, jwt, expected_error):
|
||||
public_key = datafile('publickey.pem')
|
||||
certs = {'foo': public_key}
|
||||
audience = ('https://www.googleapis.com/auth/id?client_id='
|
||||
def _check_jwt_failure(self, jwt, expected_error):
|
||||
public_key = datafile('publickey.pem')
|
||||
certs = {'foo': public_key}
|
||||
audience = ('https://www.googleapis.com/auth/id?client_id='
|
||||
'external_public_key@testing.gserviceaccount.com')
|
||||
try:
|
||||
crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
|
||||
self.fail()
|
||||
except crypt.AppIdentityError as e:
|
||||
self.assertTrue(expected_error in str(e))
|
||||
try:
|
||||
crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
|
||||
self.fail()
|
||||
except crypt.AppIdentityError as e:
|
||||
self.assertTrue(expected_error in str(e))
|
||||
|
||||
def _create_signed_jwt(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
signer = self.signer.from_string(private_key)
|
||||
audience = 'some_audience_address@testing.gserviceaccount.com'
|
||||
now = int(time.time())
|
||||
def _create_signed_jwt(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
signer = self.signer.from_string(private_key)
|
||||
audience = 'some_audience_address@testing.gserviceaccount.com'
|
||||
now = int(time.time())
|
||||
|
||||
return crypt.make_signed_jwt(signer, {
|
||||
return crypt.make_signed_jwt(signer, {
|
||||
'aud': audience,
|
||||
'iat': now,
|
||||
'exp': now + 300,
|
||||
'user': 'billy bob',
|
||||
'metadata': {'meta': 'data'},
|
||||
})
|
||||
})
|
||||
|
||||
def test_verify_id_token(self):
|
||||
jwt = self._create_signed_jwt()
|
||||
public_key = datafile('publickey.pem')
|
||||
certs = {'foo': public_key}
|
||||
audience = 'some_audience_address@testing.gserviceaccount.com'
|
||||
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
|
||||
self.assertEqual('billy bob', contents['user'])
|
||||
self.assertEqual('data', contents['metadata']['meta'])
|
||||
def test_verify_id_token(self):
|
||||
jwt = self._create_signed_jwt()
|
||||
public_key = datafile('publickey.pem')
|
||||
certs = {'foo': public_key}
|
||||
audience = 'some_audience_address@testing.gserviceaccount.com'
|
||||
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
|
||||
self.assertEqual('billy bob', contents['user'])
|
||||
self.assertEqual('data', contents['metadata']['meta'])
|
||||
|
||||
def test_verify_id_token_with_certs_uri(self):
|
||||
jwt = self._create_signed_jwt()
|
||||
def test_verify_id_token_with_certs_uri(self):
|
||||
jwt = self._create_signed_jwt()
|
||||
|
||||
http = HttpMockSequence([
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, datafile('certs.json')),
|
||||
])
|
||||
])
|
||||
|
||||
contents = verify_id_token(
|
||||
contents = verify_id_token(
|
||||
jwt, 'some_audience_address@testing.gserviceaccount.com', http=http)
|
||||
self.assertEqual('billy bob', contents['user'])
|
||||
self.assertEqual('data', contents['metadata']['meta'])
|
||||
self.assertEqual('billy bob', contents['user'])
|
||||
self.assertEqual('data', contents['metadata']['meta'])
|
||||
|
||||
def test_verify_id_token_with_certs_uri_fails(self):
|
||||
jwt = self._create_signed_jwt()
|
||||
def test_verify_id_token_with_certs_uri_fails(self):
|
||||
jwt = self._create_signed_jwt()
|
||||
|
||||
http = HttpMockSequence([
|
||||
http = HttpMockSequence([
|
||||
({'status': '404'}, datafile('certs.json')),
|
||||
])
|
||||
])
|
||||
|
||||
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
|
||||
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
|
||||
'some_audience_address@testing.gserviceaccount.com',
|
||||
http=http)
|
||||
|
||||
def test_verify_id_token_bad_tokens(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
def test_verify_id_token_bad_tokens(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
|
||||
# Wrong number of segments
|
||||
self._check_jwt_failure('foo', 'Wrong number of segments')
|
||||
# Wrong number of segments
|
||||
self._check_jwt_failure('foo', 'Wrong number of segments')
|
||||
|
||||
# Not json
|
||||
self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
|
||||
# Not json
|
||||
self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
|
||||
|
||||
# Bad signature
|
||||
jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
|
||||
self._check_jwt_failure(jwt, 'Invalid token signature')
|
||||
# Bad signature
|
||||
jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
|
||||
self._check_jwt_failure(jwt, 'Invalid token signature')
|
||||
|
||||
# No expiration
|
||||
signer = self.signer.from_string(private_key)
|
||||
audience = ('https:#www.googleapis.com/auth/id?client_id='
|
||||
# No expiration
|
||||
signer = self.signer.from_string(private_key)
|
||||
audience = ('https:#www.googleapis.com/auth/id?client_id='
|
||||
'external_public_key@testing.gserviceaccount.com')
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
'aud': audience,
|
||||
'iat': time.time(),
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'No exp field in token')
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'No exp field in token')
|
||||
|
||||
# No issued at
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
# No issued at
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
'aud': 'audience',
|
||||
'exp': time.time() + 400,
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'No iat field in token')
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'No iat field in token')
|
||||
|
||||
# Too early
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
# Too early
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
'aud': 'audience',
|
||||
'iat': time.time() + 301,
|
||||
'exp': time.time() + 400,
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'Token used too early')
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'Token used too early')
|
||||
|
||||
# Too late
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
# Too late
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
'aud': 'audience',
|
||||
'iat': time.time() - 500,
|
||||
'exp': time.time() - 301,
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'Token used too late')
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'Token used too late')
|
||||
|
||||
# Wrong target
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
# Wrong target
|
||||
jwt = crypt.make_signed_jwt(signer, {
|
||||
'aud': 'somebody else',
|
||||
'iat': time.time(),
|
||||
'exp': time.time() + 300,
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'Wrong recipient')
|
||||
})
|
||||
self._check_jwt_failure(jwt, 'Wrong recipient')
|
||||
|
||||
def test_from_string_non_509_cert(self):
|
||||
# Use a private key instead of a certificate to test the other branch
|
||||
# of from_string().
|
||||
public_key = datafile('privatekey.pem')
|
||||
verifier = self.verifier.from_string(public_key, is_x509_cert=False)
|
||||
self.assertTrue(isinstance(verifier, self.verifier))
|
||||
def test_from_string_non_509_cert(self):
|
||||
# Use a private key instead of a certificate to test the other branch
|
||||
# of from_string().
|
||||
public_key = datafile('privatekey.pem')
|
||||
verifier = self.verifier.from_string(public_key, is_x509_cert=False)
|
||||
self.assertTrue(isinstance(verifier, self.verifier))
|
||||
|
||||
|
||||
class PEMCryptTestsPyCrypto(CryptTests):
|
||||
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
self.signer = crypt.PyCryptoSigner
|
||||
self.verifier = crypt.PyCryptoVerifier
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
self.signer = crypt.PyCryptoSigner
|
||||
self.verifier = crypt.PyCryptoVerifier
|
||||
|
||||
|
||||
class PEMCryptTestsOpenSSL(CryptTests):
|
||||
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
self.signer = crypt.OpenSSLSigner
|
||||
self.verifier = crypt.OpenSSLVerifier
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
self.signer = crypt.OpenSSLSigner
|
||||
self.verifier = crypt.OpenSSLVerifier
|
||||
|
||||
|
||||
class SignedJwtAssertionCredentialsTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.format = 'p12'
|
||||
crypt.Signer = crypt.OpenSSLSigner
|
||||
def setUp(self):
|
||||
self.format = 'p12'
|
||||
crypt.Signer = crypt.OpenSSLSigner
|
||||
|
||||
def test_credentials_good(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
def test_credentials_good(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
'some_account@example.com',
|
||||
private_key,
|
||||
scope='read+write',
|
||||
sub='joe@example.org')
|
||||
http = HttpMockSequence([
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
|
||||
({'status': '200'}, 'echo_request_headers'),
|
||||
])
|
||||
http = credentials.authorize(http)
|
||||
resp, content = http.request('http://example.org')
|
||||
self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
|
||||
])
|
||||
http = credentials.authorize(http)
|
||||
resp, content = http.request('http://example.org')
|
||||
self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
|
||||
|
||||
def test_credentials_to_from_json(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
def test_credentials_to_from_json(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
'some_account@example.com',
|
||||
private_key,
|
||||
scope='read+write',
|
||||
sub='joe@example.org')
|
||||
json = credentials.to_json()
|
||||
restored = Credentials.new_from_json(json)
|
||||
self.assertEqual(credentials.private_key, restored.private_key)
|
||||
self.assertEqual(credentials.private_key_password,
|
||||
json = credentials.to_json()
|
||||
restored = Credentials.new_from_json(json)
|
||||
self.assertEqual(credentials.private_key, restored.private_key)
|
||||
self.assertEqual(credentials.private_key_password,
|
||||
restored.private_key_password)
|
||||
self.assertEqual(credentials.kwargs, restored.kwargs)
|
||||
self.assertEqual(credentials.kwargs, restored.kwargs)
|
||||
|
||||
def _credentials_refresh(self, credentials):
|
||||
http = HttpMockSequence([
|
||||
def _credentials_refresh(self, credentials):
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
|
||||
({'status': '401'}, b''),
|
||||
({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'),
|
||||
({'status': '200'}, 'echo_request_headers'),
|
||||
])
|
||||
http = credentials.authorize(http)
|
||||
_, content = http.request('http://example.org')
|
||||
return content
|
||||
])
|
||||
http = credentials.authorize(http)
|
||||
_, content = http.request('http://example.org')
|
||||
return content
|
||||
|
||||
def test_credentials_refresh_without_storage(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
def test_credentials_refresh_without_storage(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
'some_account@example.com',
|
||||
private_key,
|
||||
scope='read+write',
|
||||
sub='joe@example.org')
|
||||
|
||||
content = self._credentials_refresh(credentials)
|
||||
content = self._credentials_refresh(credentials)
|
||||
|
||||
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
|
||||
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
|
||||
|
||||
def test_credentials_refresh_with_storage(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
def test_credentials_refresh_with_storage(self):
|
||||
private_key = datafile('privatekey.%s' % self.format)
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
'some_account@example.com',
|
||||
private_key,
|
||||
scope='read+write',
|
||||
sub='joe@example.org')
|
||||
|
||||
(filehandle, filename) = tempfile.mkstemp()
|
||||
os.close(filehandle)
|
||||
store = Storage(filename)
|
||||
store.put(credentials)
|
||||
credentials.set_store(store)
|
||||
(filehandle, filename) = tempfile.mkstemp()
|
||||
os.close(filehandle)
|
||||
store = Storage(filename)
|
||||
store.put(credentials)
|
||||
credentials.set_store(store)
|
||||
|
||||
content = self._credentials_refresh(credentials)
|
||||
content = self._credentials_refresh(credentials)
|
||||
|
||||
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
|
||||
os.unlink(filename)
|
||||
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
|
||||
os.unlink(filename)
|
||||
|
||||
|
||||
class PEMSignedJwtAssertionCredentialsOpenSSLTests(
|
||||
SignedJwtAssertionCredentialsTests):
|
||||
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
crypt.Signer = crypt.OpenSSLSigner
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
crypt.Signer = crypt.OpenSSLSigner
|
||||
|
||||
|
||||
class PEMSignedJwtAssertionCredentialsPyCryptoTests(
|
||||
SignedJwtAssertionCredentialsTests):
|
||||
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
crypt.Signer = crypt.PyCryptoSigner
|
||||
def setUp(self):
|
||||
self.format = 'pem'
|
||||
crypt.Signer = crypt.PyCryptoSigner
|
||||
|
||||
|
||||
class PKCSSignedJwtAssertionCredentialsPyCryptoTests(unittest.TestCase):
|
||||
|
||||
def test_for_failure(self):
|
||||
crypt.Signer = crypt.PyCryptoSigner
|
||||
private_key = datafile('privatekey.p12')
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
def test_for_failure(self):
|
||||
crypt.Signer = crypt.PyCryptoSigner
|
||||
private_key = datafile('privatekey.p12')
|
||||
credentials = SignedJwtAssertionCredentials(
|
||||
'some_account@example.com',
|
||||
private_key,
|
||||
scope='read+write',
|
||||
sub='joe@example.org')
|
||||
try:
|
||||
credentials._generate_assertion()
|
||||
self.fail()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
try:
|
||||
credentials._generate_assertion()
|
||||
self.fail()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
class TestHasOpenSSLFlag(unittest.TestCase):
|
||||
def test_true(self):
|
||||
self.assertEqual(True, HAS_OPENSSL)
|
||||
self.assertEqual(True, HAS_CRYPTO)
|
||||
def test_true(self):
|
||||
self.assertEqual(True, HAS_OPENSSL)
|
||||
self.assertEqual(True, HAS_CRYPTO)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Tests for oauth2client.keyring_storage tests.
|
||||
|
||||
Unit tests for oauth2client.keyring_storage.
|
||||
@@ -33,59 +32,59 @@ from oauth2client.keyring_storage import Storage
|
||||
|
||||
class OAuth2ClientKeyringTests(unittest.TestCase):
|
||||
|
||||
def test_non_existent_credentials_storage(self):
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
def test_non_existent_credentials_storage(self):
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
return_value=None,
|
||||
autospec=True) as get_password:
|
||||
s = Storage('my_unit_test', 'me')
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
s = Storage('my_unit_test', 'me')
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
|
||||
def test_malformed_credentials_in_storage(self):
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
def test_malformed_credentials_in_storage(self):
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
return_value='{',
|
||||
autospec=True) as get_password:
|
||||
s = Storage('my_unit_test', 'me')
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
s = Storage('my_unit_test', 'me')
|
||||
credentials = s.get()
|
||||
self.assertEquals(None, credentials)
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
|
||||
def test_json_credentials_storage(self):
|
||||
access_token = 'foo'
|
||||
client_id = 'some_client_id'
|
||||
client_secret = 'cOuDdkfjxxnv+'
|
||||
refresh_token = '1/0/a.df219fjls0'
|
||||
token_expiry = datetime.datetime.utcnow()
|
||||
user_agent = 'refresh_checker/1.0'
|
||||
def test_json_credentials_storage(self):
|
||||
access_token = 'foo'
|
||||
client_id = 'some_client_id'
|
||||
client_secret = 'cOuDdkfjxxnv+'
|
||||
refresh_token = '1/0/a.df219fjls0'
|
||||
token_expiry = datetime.datetime.utcnow()
|
||||
user_agent = 'refresh_checker/1.0'
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
credentials = OAuth2Credentials(
|
||||
access_token, client_id, client_secret,
|
||||
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
|
||||
user_agent)
|
||||
|
||||
# Setting autospec on a mock with an iterable side_effect is
|
||||
# currently broken (http://bugs.python.org/issue17826), so instead
|
||||
# we patch twice.
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
# Setting autospec on a mock with an iterable side_effect is
|
||||
# currently broken (http://bugs.python.org/issue17826), so instead
|
||||
# we patch twice.
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
return_value=None,
|
||||
autospec=True) as get_password:
|
||||
with mock.patch.object(keyring, 'set_password',
|
||||
with mock.patch.object(keyring, 'set_password',
|
||||
return_value=None,
|
||||
autospec=True) as set_password:
|
||||
s = Storage('my_unit_test', 'me')
|
||||
self.assertEquals(None, s.get())
|
||||
s = Storage('my_unit_test', 'me')
|
||||
self.assertEquals(None, s.get())
|
||||
|
||||
s.put(credentials)
|
||||
s.put(credentials)
|
||||
|
||||
set_password.assert_called_once_with(
|
||||
set_password.assert_called_once_with(
|
||||
'my_unit_test', 'me', credentials.to_json())
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
with mock.patch.object(keyring, 'get_password',
|
||||
return_value=credentials.to_json(),
|
||||
autospec=True) as get_password:
|
||||
restored = s.get()
|
||||
self.assertEqual('foo', restored.access_token)
|
||||
self.assertEqual('some_client_id', restored.client_id)
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
restored = s.get()
|
||||
self.assertEqual('foo', restored.access_token)
|
||||
self.assertEqual('some_client_id', restored.client_id)
|
||||
get_password.assert_called_once_with('my_unit_test', 'me')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Oauth2client tests.
|
||||
|
||||
Unit tests for service account credentials implemented using RSA.
|
||||
@@ -31,94 +30,94 @@ from oauth2client.service_account import _ServiceAccountCredentials
|
||||
|
||||
|
||||
def datafile(filename):
|
||||
# TODO(orestica): Refactor this using pkgutil.get_data
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
|
||||
data = f.read()
|
||||
f.close()
|
||||
return data
|
||||
# TODO(orestica): Refactor this using pkgutil.get_data
|
||||
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
|
||||
data = f.read()
|
||||
f.close()
|
||||
return data
|
||||
|
||||
|
||||
class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.service_account_id = '123'
|
||||
self.service_account_email = 'dummy@google.com'
|
||||
self.private_key_id = 'ABCDEF'
|
||||
self.private_key = datafile('pem_from_pkcs12.pem')
|
||||
self.scopes = ['dummy_scope']
|
||||
self.credentials = _ServiceAccountCredentials(self.service_account_id,
|
||||
def setUp(self):
|
||||
self.service_account_id = '123'
|
||||
self.service_account_email = 'dummy@google.com'
|
||||
self.private_key_id = 'ABCDEF'
|
||||
self.private_key = datafile('pem_from_pkcs12.pem')
|
||||
self.scopes = ['dummy_scope']
|
||||
self.credentials = _ServiceAccountCredentials(self.service_account_id,
|
||||
self.service_account_email,
|
||||
self.private_key_id,
|
||||
self.private_key,
|
||||
[])
|
||||
|
||||
def test_sign_blob(self):
|
||||
private_key_id, signature = self.credentials.sign_blob('Google')
|
||||
self.assertEqual( self.private_key_id, private_key_id)
|
||||
def test_sign_blob(self):
|
||||
private_key_id, signature = self.credentials.sign_blob('Google')
|
||||
self.assertEqual(self.private_key_id, private_key_id)
|
||||
|
||||
pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
|
||||
pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
|
||||
datafile('publickey_openssl.pem'))
|
||||
|
||||
self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
|
||||
self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
|
||||
|
||||
try:
|
||||
rsa.pkcs1.verify(b'Orest', signature, pub_key)
|
||||
self.fail('Verification should have failed!')
|
||||
except rsa.pkcs1.VerificationError:
|
||||
pass # Expected
|
||||
try:
|
||||
rsa.pkcs1.verify(b'Orest', signature, pub_key)
|
||||
self.fail('Verification should have failed!')
|
||||
except rsa.pkcs1.VerificationError:
|
||||
pass # Expected
|
||||
|
||||
try:
|
||||
rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
|
||||
self.fail('Verification should have failed!')
|
||||
except rsa.pkcs1.VerificationError:
|
||||
pass # Expected
|
||||
try:
|
||||
rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
|
||||
self.fail('Verification should have failed!')
|
||||
except rsa.pkcs1.VerificationError:
|
||||
pass # Expected
|
||||
|
||||
def test_service_account_email(self):
|
||||
self.assertEqual(self.service_account_email,
|
||||
def test_service_account_email(self):
|
||||
self.assertEqual(self.service_account_email,
|
||||
self.credentials.service_account_email)
|
||||
|
||||
def test_create_scoped_required_without_scopes(self):
|
||||
self.assertTrue(self.credentials.create_scoped_required())
|
||||
def test_create_scoped_required_without_scopes(self):
|
||||
self.assertTrue(self.credentials.create_scoped_required())
|
||||
|
||||
def test_create_scoped_required_with_scopes(self):
|
||||
self.credentials = _ServiceAccountCredentials(self.service_account_id,
|
||||
def test_create_scoped_required_with_scopes(self):
|
||||
self.credentials = _ServiceAccountCredentials(self.service_account_id,
|
||||
self.service_account_email,
|
||||
self.private_key_id,
|
||||
self.private_key,
|
||||
self.scopes)
|
||||
self.assertFalse(self.credentials.create_scoped_required())
|
||||
self.assertFalse(self.credentials.create_scoped_required())
|
||||
|
||||
def test_create_scoped(self):
|
||||
new_credentials = self.credentials.create_scoped(self.scopes)
|
||||
self.assertNotEqual(self.credentials, new_credentials)
|
||||
self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
|
||||
self.assertEqual('dummy_scope', new_credentials._scopes)
|
||||
def test_create_scoped(self):
|
||||
new_credentials = self.credentials.create_scoped(self.scopes)
|
||||
self.assertNotEqual(self.credentials, new_credentials)
|
||||
self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
|
||||
self.assertEqual('dummy_scope', new_credentials._scopes)
|
||||
|
||||
def test_access_token(self):
|
||||
S = 2 # number of seconds in which the token expires
|
||||
token_response_first = {'access_token': 'first_token', 'expires_in': S}
|
||||
token_response_second = {'access_token': 'second_token', 'expires_in': S}
|
||||
http = HttpMockSequence([
|
||||
def test_access_token(self):
|
||||
S = 2 # number of seconds in which the token expires
|
||||
token_response_first = {'access_token': 'first_token', 'expires_in': S}
|
||||
token_response_second = {'access_token': 'second_token', 'expires_in': S}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, json.dumps(token_response_first).encode('utf-8')),
|
||||
({'status': '200'}, json.dumps(token_response_second).encode('utf-8')),
|
||||
])
|
||||
])
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('first_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
|
||||
time.sleep(S + 0.5) # some margin to avoid flakiness
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
time.sleep(S + 0.5) # some margin to avoid flakiness
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('second_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second, self.credentials.token_response)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual('second_token', token.access_token)
|
||||
self.assertEqual(S - 1, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second, self.credentials.token_response)
|
||||
|
||||
@@ -5,6 +5,7 @@ from oauth2client import tools
|
||||
from six.moves.urllib import request
|
||||
import threading
|
||||
|
||||
|
||||
class TestClientRedirectServer(unittest.TestCase):
|
||||
"""Test the ClientRedirectServer and ClientRedirectHandler classes."""
|
||||
|
||||
@@ -15,16 +16,15 @@ class TestClientRedirectServer(unittest.TestCase):
|
||||
httpd = tools.ClientRedirectServer(('localhost', 0), tools.ClientRedirectHandler)
|
||||
code = 'foo'
|
||||
url = 'http://localhost:%i?code=%s' % (httpd.server_address[1], code)
|
||||
t = threading.Thread(target = httpd.handle_request)
|
||||
t = threading.Thread(target=httpd.handle_request)
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
f = request.urlopen( url )
|
||||
f = request.urlopen(url)
|
||||
self.assertTrue(f.read())
|
||||
t.join()
|
||||
httpd.server_close()
|
||||
self.assertEqual(httpd.query_params.get('code'),code)
|
||||
self.assertEqual(httpd.query_params.get('code'), code)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from oauth2client import util
|
||||
|
||||
class ScopeToStringTests(unittest.TestCase):
|
||||
|
||||
def test_iterables(self):
|
||||
cases = [
|
||||
def test_iterables(self):
|
||||
cases = [
|
||||
('', ''),
|
||||
('', ()),
|
||||
('', []),
|
||||
@@ -22,36 +22,37 @@ class ScopeToStringTests(unittest.TestCase):
|
||||
('a b', ('a', 'b')),
|
||||
('a b', 'a b'),
|
||||
('a b', (s for s in ['a', 'b'])),
|
||||
]
|
||||
for expected, case in cases:
|
||||
self.assertEqual(expected, util.scopes_to_string(case))
|
||||
]
|
||||
for expected, case in cases:
|
||||
self.assertEqual(expected, util.scopes_to_string(case))
|
||||
|
||||
|
||||
class StringToScopeTests(unittest.TestCase):
|
||||
|
||||
def test_conversion(self):
|
||||
cases = [
|
||||
def test_conversion(self):
|
||||
cases = [
|
||||
(['a', 'b'], ['a', 'b']),
|
||||
('', []),
|
||||
('a', ['a']),
|
||||
('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']),
|
||||
]
|
||||
]
|
||||
|
||||
for case, expected in cases:
|
||||
self.assertEqual(expected, util.string_to_scopes(case))
|
||||
|
||||
for case, expected in cases:
|
||||
self.assertEqual(expected, util.string_to_scopes(case))
|
||||
|
||||
class KeyConversionTests(unittest.TestCase):
|
||||
|
||||
def test_key_conversions(self):
|
||||
d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
|
||||
tuple_key = util.dict_to_tuple_key(d)
|
||||
def test_key_conversions(self):
|
||||
d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
|
||||
tuple_key = util.dict_to_tuple_key(d)
|
||||
|
||||
# the resulting key should be naturally sorted
|
||||
self.assertEqual(
|
||||
# the resulting key should be naturally sorted
|
||||
self.assertEqual(
|
||||
(('another', 'something else'),
|
||||
('onemore', 'foo'),
|
||||
('somekey', 'some value')),
|
||||
tuple_key)
|
||||
|
||||
# check we get the original dictionary back
|
||||
self.assertEqual(d, dict(tuple_key))
|
||||
# check we get the original dictionary back
|
||||
self.assertEqual(d, dict(tuple_key))
|
||||
|
||||
@@ -34,78 +34,78 @@ TEST_EXTRA_INFO_2 = 'more_extra_info'
|
||||
|
||||
|
||||
class XsrfUtilTests(unittest.TestCase):
|
||||
"""Test xsrfutil functions."""
|
||||
"""Test xsrfutil functions."""
|
||||
|
||||
def testGenerateAndValidateToken(self):
|
||||
"""Test generating and validating a token."""
|
||||
token = xsrfutil.generate_token(TEST_KEY,
|
||||
def testGenerateAndValidateToken(self):
|
||||
"""Test generating and validating a token."""
|
||||
token = xsrfutil.generate_token(TEST_KEY,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
when=TEST_TIME)
|
||||
|
||||
# Check that the token is considered valid when it should be.
|
||||
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
|
||||
# Check that the token is considered valid when it should be.
|
||||
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
|
||||
token,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=TEST_TIME))
|
||||
|
||||
# Should still be valid 15 minutes later.
|
||||
later15mins = TEST_TIME + 15*60
|
||||
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
|
||||
# Should still be valid 15 minutes later.
|
||||
later15mins = TEST_TIME + 15 * 60
|
||||
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
|
||||
token,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later15mins))
|
||||
|
||||
# But not if beyond the timeout.
|
||||
later2hours = TEST_TIME + 2*60*60
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
# But not if beyond the timeout.
|
||||
later2hours = TEST_TIME + 2 * 60 * 60
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
token,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later2hours))
|
||||
|
||||
# Or if the key is different.
|
||||
self.assertFalse(xsrfutil.validate_token('another key',
|
||||
# Or if the key is different.
|
||||
self.assertFalse(xsrfutil.validate_token('another key',
|
||||
token,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later15mins))
|
||||
|
||||
# Or the user ID....
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
# Or the user ID....
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
token,
|
||||
TEST_USER_ID_2,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later15mins))
|
||||
|
||||
# Or the action ID...
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
# Or the action ID...
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
token,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_2,
|
||||
current_time=later15mins))
|
||||
|
||||
# Invalid when truncated
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
# Invalid when truncated
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
token[:-1],
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later15mins))
|
||||
|
||||
# Invalid with extra garbage
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
# Invalid with extra garbage
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
token + b'x',
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later15mins))
|
||||
|
||||
# Invalid with token of None
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
# Invalid with token of None
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
None,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user