Raw pep8ify changes.

Simply ran

  pep8ify -w oauth2client/
  pep8ify -w tests/
This commit is contained in:
Danny Hermes
2015-08-19 22:03:50 -07:00
parent 043e066e54
commit 34c1ff543d
40 changed files with 4635 additions and 4654 deletions

View File

@@ -19,7 +19,7 @@ import six
def _parse_pem_key(raw_key_input): def _parse_pem_key(raw_key_input):
"""Identify and extract PEM keys. """Identify and extract PEM keys.
Determines whether the given key is in the format of PEM key, and extracts Determines whether the given key is in the format of PEM key, and extracts
the relevant part of the key if it is. the relevant part of the key if it is.
@@ -30,17 +30,17 @@ def _parse_pem_key(raw_key_input):
Returns: Returns:
string, The actual key if the contents are from a PEM file, or else None. string, The actual key if the contents are from a PEM file, or else None.
""" """
offset = raw_key_input.find(b'-----BEGIN ') offset = raw_key_input.find(b'-----BEGIN ')
if offset != -1: if offset != -1:
return raw_key_input[offset:] return raw_key_input[offset:]
def _json_encode(data): def _json_encode(data):
return json.dumps(data, separators=(',', ':')) return json.dumps(data, separators=(',', ':'))
def _to_bytes(value, encoding='ascii'): def _to_bytes(value, encoding='ascii'):
"""Converts a string value to bytes, if necessary. """Converts a string value to bytes, if necessary.
Unfortunately, ``six.b`` is insufficient for this task since in Unfortunately, ``six.b`` is insufficient for this task since in
Python2 it does not modify ``unicode`` objects. Python2 it does not modify ``unicode`` objects.
@@ -60,16 +60,16 @@ def _to_bytes(value, encoding='ascii'):
Raises: Raises:
ValueError if the value could not be converted to bytes. ValueError if the value could not be converted to bytes.
""" """
result = (value.encode(encoding) result = (value.encode(encoding)
if isinstance(value, six.text_type) else value) if isinstance(value, six.text_type) else value)
if isinstance(result, six.binary_type): if isinstance(result, six.binary_type):
return result return result
else: else:
raise ValueError('%r could not be converted to bytes' % (value,)) raise ValueError('%r could not be converted to bytes' % (value, ))
def _from_bytes(value): def _from_bytes(value):
"""Converts bytes to a string value, if necessary. """Converts bytes to a string value, if necessary.
Args: Args:
value: The string/bytes value to be converted. value: The string/bytes value to be converted.
@@ -81,21 +81,21 @@ def _from_bytes(value):
Raises: Raises:
ValueError if the value could not be converted to unicode. ValueError if the value could not be converted to unicode.
""" """
result = (value.decode('utf-8') result = (value.decode('utf-8')
if isinstance(value, six.binary_type) else value) if isinstance(value, six.binary_type) else value)
if isinstance(result, six.text_type): if isinstance(result, six.text_type):
return result return result
else: else:
raise ValueError('%r could not be converted to unicode' % (value,)) raise ValueError('%r could not be converted to unicode' % (value, ))
def _urlsafe_b64encode(raw_bytes): def _urlsafe_b64encode(raw_bytes):
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8') raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=') return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
def _urlsafe_b64decode(b64string): def _urlsafe_b64decode(b64string):
# Guard against unicode strings, which base64 can't handle. # Guard against unicode strings, which base64 can't handle.
b64string = _to_bytes(b64string) b64string = _to_bytes(b64string)
padded = b64string + b'=' * (4 - len(b64string) % 4) padded = b64string + b'=' * (4 - len(b64string) % 4)
return base64.urlsafe_b64decode(padded) return base64.urlsafe_b64decode(padded)

View File

@@ -22,18 +22,18 @@ from oauth2client._helpers import _to_bytes
class OpenSSLVerifier(object): class OpenSSLVerifier(object):
"""Verifies the signature on a message.""" """Verifies the signature on a message."""
def __init__(self, pubkey): def __init__(self, pubkey):
"""Constructor. """Constructor.
Args: Args:
pubkey, OpenSSL.crypto.PKey, The public key to verify with. pubkey, OpenSSL.crypto.PKey, The public key to verify with.
""" """
self._pubkey = pubkey self._pubkey = pubkey
def verify(self, message, signature): def verify(self, message, signature):
"""Verifies a message against a signature. """Verifies a message against a signature.
Args: Args:
message: string or bytes, The message to verify. If string, will be message: string or bytes, The message to verify. If string, will be
@@ -45,17 +45,17 @@ class OpenSSLVerifier(object):
True if message was signed by the private key associated with the public True if message was signed by the private key associated with the public
key that this object was constructed with. key that this object was constructed with.
""" """
message = _to_bytes(message, encoding='utf-8') message = _to_bytes(message, encoding='utf-8')
signature = _to_bytes(signature, encoding='utf-8') signature = _to_bytes(signature, encoding='utf-8')
try: try:
crypto.verify(self._pubkey, signature, message, 'sha256') crypto.verify(self._pubkey, signature, message, 'sha256')
return True return True
except crypto.Error: except crypto.Error:
return False return False
@staticmethod @staticmethod
def from_string(key_pem, is_x509_cert): def from_string(key_pem, is_x509_cert):
"""Construct a Verified instance from a string. """Construct a Verified instance from a string.
Args: Args:
key_pem: string, public key in PEM format. key_pem: string, public key in PEM format.
@@ -68,26 +68,26 @@ class OpenSSLVerifier(object):
Raises: Raises:
OpenSSL.crypto.Error if the key_pem can't be parsed. OpenSSL.crypto.Error if the key_pem can't be parsed.
""" """
if is_x509_cert: if is_x509_cert:
pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem) pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
else: else:
pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem) pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
return OpenSSLVerifier(pubkey) return OpenSSLVerifier(pubkey)
class OpenSSLSigner(object): class OpenSSLSigner(object):
"""Signs messages with a private key.""" """Signs messages with a private key."""
def __init__(self, pkey): def __init__(self, pkey):
"""Constructor. """Constructor.
Args: Args:
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with. pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
""" """
self._key = pkey self._key = pkey
def sign(self, message): def sign(self, message):
"""Signs a message. """Signs a message.
Args: Args:
message: bytes, Message to be signed. message: bytes, Message to be signed.
@@ -95,12 +95,12 @@ class OpenSSLSigner(object):
Returns: Returns:
string, The signature of the message for the given key. string, The signature of the message for the given key.
""" """
message = _to_bytes(message, encoding='utf-8') message = _to_bytes(message, encoding='utf-8')
return crypto.sign(self._key, message, 'sha256') return crypto.sign(self._key, message, 'sha256')
@staticmethod @staticmethod
def from_string(key, password=b'notasecret'): def from_string(key, password=b'notasecret'):
"""Construct a Signer instance from a string. """Construct a Signer instance from a string.
Args: Args:
key: string, private key in PKCS12 or PEM format. key: string, private key in PKCS12 or PEM format.
@@ -112,17 +112,17 @@ class OpenSSLSigner(object):
Raises: Raises:
OpenSSL.crypto.Error if the key can't be parsed. OpenSSL.crypto.Error if the key can't be parsed.
""" """
parsed_pem_key = _parse_pem_key(key) parsed_pem_key = _parse_pem_key(key)
if parsed_pem_key: if parsed_pem_key:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
else: else:
password = _to_bytes(password, encoding='utf-8') password = _to_bytes(password, encoding='utf-8')
pkey = crypto.load_pkcs12(key, password).get_privatekey() pkey = crypto.load_pkcs12(key, password).get_privatekey()
return OpenSSLSigner(pkey) return OpenSSLSigner(pkey)
def pkcs12_key_as_pem(private_key_text, private_key_password): def pkcs12_key_as_pem(private_key_text, private_key_password):
"""Convert the contents of a PKCS12 key to PEM using OpenSSL. """Convert the contents of a PKCS12 key to PEM using OpenSSL.
Args: Args:
private_key_text: String. Private key. private_key_text: String. Private key.
@@ -131,9 +131,9 @@ def pkcs12_key_as_pem(private_key_text, private_key_password):
Returns: Returns:
String. PEM contents of ``private_key_text``. String. PEM contents of ``private_key_text``.
""" """
decoded_body = base64.b64decode(private_key_text) decoded_body = base64.b64decode(private_key_text)
private_key_password = _to_bytes(private_key_password) private_key_password = _to_bytes(private_key_password)
pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password) pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
return crypto.dump_privatekey(crypto.FILETYPE_PEM, return crypto.dump_privatekey(crypto.FILETYPE_PEM,
pkcs12.get_privatekey()) pkcs12.get_privatekey())

View File

@@ -25,18 +25,18 @@ from oauth2client._helpers import _urlsafe_b64decode
class PyCryptoVerifier(object): class PyCryptoVerifier(object):
"""Verifies the signature on a message.""" """Verifies the signature on a message."""
def __init__(self, pubkey): def __init__(self, pubkey):
"""Constructor. """Constructor.
Args: Args:
pubkey, OpenSSL.crypto.PKey (or equiv), The public key to verify with. pubkey, OpenSSL.crypto.PKey (or equiv), The public key to verify with.
""" """
self._pubkey = pubkey self._pubkey = pubkey
def verify(self, message, signature): def verify(self, message, signature):
"""Verifies a message against a signature. """Verifies a message against a signature.
Args: Args:
message: string or bytes, The message to verify. If string, will be message: string or bytes, The message to verify. If string, will be
@@ -47,13 +47,13 @@ class PyCryptoVerifier(object):
True if message was signed by the private key associated with the public True if message was signed by the private key associated with the public
key that this object was constructed with. key that this object was constructed with.
""" """
message = _to_bytes(message, encoding='utf-8') message = _to_bytes(message, encoding='utf-8')
return PKCS1_v1_5.new(self._pubkey).verify( return PKCS1_v1_5.new(self._pubkey).verify(
SHA256.new(message), signature) SHA256.new(message), signature)
@staticmethod @staticmethod
def from_string(key_pem, is_x509_cert): def from_string(key_pem, is_x509_cert):
"""Construct a Verified instance from a string. """Construct a Verified instance from a string.
Args: Args:
key_pem: string, public key in PEM format. key_pem: string, public key in PEM format.
@@ -63,33 +63,33 @@ class PyCryptoVerifier(object):
Returns: Returns:
Verifier instance. Verifier instance.
""" """
if is_x509_cert: if is_x509_cert:
key_pem = _to_bytes(key_pem) key_pem = _to_bytes(key_pem)
pemLines = key_pem.replace(b' ', b'').split() pemLines = key_pem.replace(b' ', b'').split()
certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1])) certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
certSeq = DerSequence() certSeq = DerSequence()
certSeq.decode(certDer) certSeq.decode(certDer)
tbsSeq = DerSequence() tbsSeq = DerSequence()
tbsSeq.decode(certSeq[0]) tbsSeq.decode(certSeq[0])
pubkey = RSA.importKey(tbsSeq[6]) pubkey = RSA.importKey(tbsSeq[6])
else: else:
pubkey = RSA.importKey(key_pem) pubkey = RSA.importKey(key_pem)
return PyCryptoVerifier(pubkey) return PyCryptoVerifier(pubkey)
class PyCryptoSigner(object): class PyCryptoSigner(object):
"""Signs messages with a private key.""" """Signs messages with a private key."""
def __init__(self, pkey): def __init__(self, pkey):
"""Constructor. """Constructor.
Args: Args:
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with. pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
""" """
self._key = pkey self._key = pkey
def sign(self, message): def sign(self, message):
"""Signs a message. """Signs a message.
Args: Args:
message: string, Message to be signed. message: string, Message to be signed.
@@ -97,12 +97,12 @@ class PyCryptoSigner(object):
Returns: Returns:
string, The signature of the message for the given key. string, The signature of the message for the given key.
""" """
message = _to_bytes(message, encoding='utf-8') message = _to_bytes(message, encoding='utf-8')
return PKCS1_v1_5.new(self._key).sign(SHA256.new(message)) return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
@staticmethod @staticmethod
def from_string(key, password='notasecret'): def from_string(key, password='notasecret'):
"""Construct a Signer instance from a string. """Construct a Signer instance from a string.
Args: Args:
key: string, private key in PEM format. key: string, private key in PEM format.
@@ -114,13 +114,13 @@ class PyCryptoSigner(object):
Raises: Raises:
NotImplementedError if the key isn't in PEM format. NotImplementedError if the key isn't in PEM format.
""" """
parsed_pem_key = _parse_pem_key(key) parsed_pem_key = _parse_pem_key(key)
if parsed_pem_key: if parsed_pem_key:
pkey = RSA.importKey(parsed_pem_key) pkey = RSA.importKey(parsed_pem_key)
else: else:
raise NotImplementedError( raise NotImplementedError(
'PKCS12 format is not supported by the PyCrypto library. ' 'PKCS12 format is not supported by the PyCrypto library. '
'Try converting to a "PEM" ' 'Try converting to a "PEM" '
'(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) ' '(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) '
'or using PyOpenSSL if native code is an option.') 'or using PyOpenSSL if native code is an option.')
return PyCryptoSigner(pkey) return PyCryptoSigner(pkey)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import json import json
import six import six
# Properties that make a client_secrets.json file valid. # Properties that make a client_secrets.json file valid.
TYPE_WEB = 'web' TYPE_WEB = 'web'
TYPE_INSTALLED = 'installed' TYPE_INSTALLED = 'installed'
@@ -59,65 +58,65 @@ VALID_CLIENT = {
class Error(Exception): class Error(Exception):
"""Base error for this module.""" """Base error for this module."""
pass pass
class InvalidClientSecretsError(Error): class InvalidClientSecretsError(Error):
"""Format of ClientSecrets file is invalid.""" """Format of ClientSecrets file is invalid."""
pass pass
def _validate_clientsecrets(obj): def _validate_clientsecrets(obj):
_INVALID_FILE_FORMAT_MSG = ( _INVALID_FILE_FORMAT_MSG = (
'Invalid file format. See ' 'Invalid file format. See '
'https://developers.google.com/api-client-library/' 'https://developers.google.com/api-client-library/'
'python/guide/aaa_client_secrets') 'python/guide/aaa_client_secrets')
if obj is None: if obj is None:
raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG) raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
if len(obj) != 1: if len(obj) != 1:
raise InvalidClientSecretsError( raise InvalidClientSecretsError(
_INVALID_FILE_FORMAT_MSG + ' ' _INVALID_FILE_FORMAT_MSG + ' '
'Expected a JSON object with a single property for a "web" or ' 'Expected a JSON object with a single property for a "web" or '
'"installed" application') '"installed" application')
client_type = tuple(obj)[0] client_type = tuple(obj)[0]
if client_type not in VALID_CLIENT: if client_type not in VALID_CLIENT:
raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type,)) raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type, ))
client_info = obj[client_type] client_info = obj[client_type]
for prop_name in VALID_CLIENT[client_type]['required']: for prop_name in VALID_CLIENT[client_type]['required']:
if prop_name not in client_info: if prop_name not in client_info:
raise InvalidClientSecretsError( raise InvalidClientSecretsError(
'Missing property "%s" in a client type of "%s".' % (prop_name, 'Missing property "%s" in a client type of "%s".' % (prop_name,
client_type)) client_type))
for prop_name in VALID_CLIENT[client_type]['string']: for prop_name in VALID_CLIENT[client_type]['string']:
if client_info[prop_name].startswith('[['): if client_info[prop_name].startswith('[['):
raise InvalidClientSecretsError( raise InvalidClientSecretsError(
'Property "%s" is not configured.' % prop_name) 'Property "%s" is not configured.' % prop_name)
return client_type, client_info return client_type, client_info
def load(fp): def load(fp):
obj = json.load(fp) obj = json.load(fp)
return _validate_clientsecrets(obj) return _validate_clientsecrets(obj)
def loads(s): def loads(s):
obj = json.loads(s) obj = json.loads(s)
return _validate_clientsecrets(obj) return _validate_clientsecrets(obj)
def _loadfile(filename): def _loadfile(filename):
try: try:
with open(filename, 'r') as fp: with open(filename, 'r') as fp:
obj = json.load(fp) obj = json.load(fp)
except IOError: except IOError:
raise InvalidClientSecretsError('File not found: "%s"' % filename) raise InvalidClientSecretsError('File not found: "%s"' % filename)
return _validate_clientsecrets(obj) return _validate_clientsecrets(obj)
def loadfile(filename, cache=None): def loadfile(filename, cache=None):
"""Loading of client_secrets JSON file, optionally backed by a cache. """Loading of client_secrets JSON file, optionally backed by a cache.
Typical cache storage would be App Engine memcache service, Typical cache storage would be App Engine memcache service,
but you can pass in any other cache client that implements but you can pass in any other cache client that implements
@@ -149,15 +148,15 @@ def loadfile(filename, cache=None):
JSON contents is validated only during first load. Cache hits are not JSON contents is validated only during first load. Cache hits are not
validated. validated.
""" """
_SECRET_NAMESPACE = 'oauth2client:secrets#ns' _SECRET_NAMESPACE = 'oauth2client:secrets#ns'
if not cache: if not cache:
return _loadfile(filename) return _loadfile(filename)
obj = cache.get(filename, namespace=_SECRET_NAMESPACE) obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
if obj is None: if obj is None:
client_type, client_info = _loadfile(filename) client_type, client_info = _loadfile(filename)
obj = {client_type: client_info} obj = {client_type: client_info}
cache.set(filename, obj, namespace=_SECRET_NAMESPACE) cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
return next(six.iteritems(obj)) return next(six.iteritems(obj))

View File

@@ -25,51 +25,51 @@ from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode from oauth2client._helpers import _urlsafe_b64decode
from oauth2client._helpers import _urlsafe_b64encode from oauth2client._helpers import _urlsafe_b64encode
CLOCK_SKEW_SECS = 300 # 5 minutes in seconds CLOCK_SKEW_SECS = 300 # 5 minutes in seconds
AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds
MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AppIdentityError(Exception): class AppIdentityError(Exception):
pass pass
try: try:
from oauth2client._openssl_crypt import OpenSSLVerifier from oauth2client._openssl_crypt import OpenSSLVerifier
from oauth2client._openssl_crypt import OpenSSLSigner from oauth2client._openssl_crypt import OpenSSLSigner
from oauth2client._openssl_crypt import pkcs12_key_as_pem from oauth2client._openssl_crypt import pkcs12_key_as_pem
except ImportError: except ImportError:
OpenSSLVerifier = None OpenSSLVerifier = None
OpenSSLSigner = None OpenSSLSigner = None
def pkcs12_key_as_pem(*args, **kwargs):
raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
def pkcs12_key_as_pem(*args, **kwargs):
raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
try: try:
from oauth2client._pycrypto_crypt import PyCryptoVerifier from oauth2client._pycrypto_crypt import PyCryptoVerifier
from oauth2client._pycrypto_crypt import PyCryptoSigner from oauth2client._pycrypto_crypt import PyCryptoSigner
except ImportError: except ImportError:
PyCryptoVerifier = None PyCryptoVerifier = None
PyCryptoSigner = None PyCryptoSigner = None
if OpenSSLSigner: if OpenSSLSigner:
Signer = OpenSSLSigner Signer = OpenSSLSigner
Verifier = OpenSSLVerifier Verifier = OpenSSLVerifier
elif PyCryptoSigner: elif PyCryptoSigner:
Signer = PyCryptoSigner Signer = PyCryptoSigner
Verifier = PyCryptoVerifier Verifier = PyCryptoVerifier
else: else:
raise ImportError('No encryption library found. Please install either ' raise ImportError('No encryption library found. Please install either '
'PyOpenSSL, or PyCrypto 2.6 or later') 'PyOpenSSL, or PyCrypto 2.6 or later')
def make_signed_jwt(signer, payload): def make_signed_jwt(signer, payload):
"""Make a signed JWT. """Make a signed JWT.
See http://self-issued.info/docs/draft-jones-json-web-token.html. See http://self-issued.info/docs/draft-jones-json-web-token.html.
@@ -80,24 +80,24 @@ def make_signed_jwt(signer, payload):
Returns: Returns:
string, The JWT for the payload. string, The JWT for the payload.
""" """
header = {'typ': 'JWT', 'alg': 'RS256'} header = {'typ': 'JWT', 'alg': 'RS256'}
segments = [ segments = [
_urlsafe_b64encode(_json_encode(header)), _urlsafe_b64encode(_json_encode(header)),
_urlsafe_b64encode(_json_encode(payload)), _urlsafe_b64encode(_json_encode(payload)),
] ]
signing_input = b'.'.join(segments) signing_input = b'.'.join(segments)
signature = signer.sign(signing_input) signature = signer.sign(signing_input)
segments.append(_urlsafe_b64encode(signature)) segments.append(_urlsafe_b64encode(signature))
logger.debug(str(segments)) logger.debug(str(segments))
return b'.'.join(segments) return b'.'.join(segments)
def verify_signed_jwt_with_certs(jwt, certs, audience): def verify_signed_jwt_with_certs(jwt, certs, audience):
"""Verify a JWT against public certs. """Verify a JWT against public certs.
See http://self-issued.info/docs/draft-jones-json-web-token.html. See http://self-issued.info/docs/draft-jones-json-web-token.html.
@@ -113,61 +113,61 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
Raises: Raises:
AppIdentityError if any checks are failed. AppIdentityError if any checks are failed.
""" """
jwt = _to_bytes(jwt) jwt = _to_bytes(jwt)
segments = jwt.split(b'.') segments = jwt.split(b'.')
if len(segments) != 3: if len(segments) != 3:
raise AppIdentityError('Wrong number of segments in token: %s' % jwt) raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
signed = segments[0] + b'.' + segments[1] signed = segments[0] + b'.' + segments[1]
signature = _urlsafe_b64decode(segments[2]) signature = _urlsafe_b64decode(segments[2])
# Parse token. # Parse token.
json_body = _urlsafe_b64decode(segments[1]) json_body = _urlsafe_b64decode(segments[1])
try: try:
parsed = json.loads(_from_bytes(json_body)) parsed = json.loads(_from_bytes(json_body))
except: except:
raise AppIdentityError('Can\'t parse token: %s' % json_body) raise AppIdentityError('Can\'t parse token: %s' % json_body)
# Check signature. # Check signature.
verified = False verified = False
for pem in certs.values(): for pem in certs.values():
verifier = Verifier.from_string(pem, True) verifier = Verifier.from_string(pem, True)
if verifier.verify(signed, signature): if verifier.verify(signed, signature):
verified = True verified = True
break break
if not verified: if not verified:
raise AppIdentityError('Invalid token signature: %s' % jwt) raise AppIdentityError('Invalid token signature: %s' % jwt)
# Check creation timestamp. # Check creation timestamp.
iat = parsed.get('iat') iat = parsed.get('iat')
if iat is None: if iat is None:
raise AppIdentityError('No iat field in token: %s' % json_body) raise AppIdentityError('No iat field in token: %s' % json_body)
earliest = iat - CLOCK_SKEW_SECS earliest = iat - CLOCK_SKEW_SECS
# Check expiration timestamp. # Check expiration timestamp.
now = int(time.time()) now = int(time.time())
exp = parsed.get('exp') exp = parsed.get('exp')
if exp is None: if exp is None:
raise AppIdentityError('No exp field in token: %s' % json_body) raise AppIdentityError('No exp field in token: %s' % json_body)
if exp >= now + MAX_TOKEN_LIFETIME_SECS: if exp >= now + MAX_TOKEN_LIFETIME_SECS:
raise AppIdentityError('exp field too far in future: %s' % json_body) raise AppIdentityError('exp field too far in future: %s' % json_body)
latest = exp + CLOCK_SKEW_SECS latest = exp + CLOCK_SKEW_SECS
if now < earliest: if now < earliest:
raise AppIdentityError('Token used too early, %d < %d: %s' % raise AppIdentityError('Token used too early, %d < %d: %s' %
(now, earliest, json_body)) (now, earliest, json_body))
if now > latest: if now > latest:
raise AppIdentityError('Token used too late, %d > %d: %s' % raise AppIdentityError('Token used too late, %d > %d: %s' %
(now, latest, json_body)) (now, latest, json_body))
# Check audience. # Check audience.
if audience is not None: if audience is not None:
aud = parsed.get('aud') aud = parsed.get('aud')
if aud is None: if aud is None:
raise AppIdentityError('No aud field in token: %s' % json_body) raise AppIdentityError('No aud field in token: %s' % json_body)
if aud != audience: if aud != audience:
raise AppIdentityError('Wrong recipient, %s != %s: %s' % raise AppIdentityError('Wrong recipient, %s != %s: %s' %
(aud, audience, json_body)) (aud, audience, json_body))
return parsed return parsed

View File

@@ -20,22 +20,20 @@ import os
from oauth2client._helpers import _to_bytes from oauth2client._helpers import _to_bytes
from oauth2client import client from oauth2client import client
DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT' DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT'
class Error(Exception): class Error(Exception):
"""Errors for this module.""" """Errors for this module."""
pass pass
class CommunicationError(Error): class CommunicationError(Error):
"""Errors for communication with the Developer Shell server.""" """Errors for communication with the Developer Shell server."""
class NoDevshellServer(Error): class NoDevshellServer(Error):
"""Error when no Developer Shell server can be contacted.""" """Error when no Developer Shell server can be contacted."""
# The request for credential information to the Developer Shell client socket is # The request for credential information to the Developer Shell client socket is
# always an empty PBLite-formatted JSON object, so just define it as a constant. # always an empty PBLite-formatted JSON object, so just define it as a constant.
@@ -43,7 +41,7 @@ CREDENTIAL_INFO_REQUEST_JSON = '[]'
class CredentialInfoResponse(object): class CredentialInfoResponse(object):
"""Credential information response from Developer Shell server. """Credential information response from Developer Shell server.
The credential information response from Developer Shell socket is a The credential information response from Developer Shell socket is a
PBLite-formatted JSON array with fields encoded by their index in the array: PBLite-formatted JSON array with fields encoded by their index in the array:
@@ -52,46 +50,46 @@ class CredentialInfoResponse(object):
* Index 2 - OAuth2 access token. None if there is no valid auth context. * Index 2 - OAuth2 access token. None if there is no valid auth context.
""" """
def __init__(self, json_string): def __init__(self, json_string):
"""Initialize the response data from JSON PBLite array.""" """Initialize the response data from JSON PBLite array."""
pbl = json.loads(json_string) pbl = json.loads(json_string)
if not isinstance(pbl, list): if not isinstance(pbl, list):
raise ValueError('Not a list: ' + str(pbl)) raise ValueError('Not a list: ' + str(pbl))
pbl_len = len(pbl) pbl_len = len(pbl)
self.user_email = pbl[0] if pbl_len > 0 else None self.user_email = pbl[0] if pbl_len > 0 else None
self.project_id = pbl[1] if pbl_len > 1 else None self.project_id = pbl[1] if pbl_len > 1 else None
self.access_token = pbl[2] if pbl_len > 2 else None self.access_token = pbl[2] if pbl_len > 2 else None
def _SendRecv(): def _SendRecv():
"""Communicate with the Developer Shell server socket.""" """Communicate with the Developer Shell server socket."""
port = int(os.getenv(DEVSHELL_ENV, 0)) port = int(os.getenv(DEVSHELL_ENV, 0))
if port == 0: if port == 0:
raise NoDevshellServer() raise NoDevshellServer()
import socket import socket
sock = socket.socket() sock = socket.socket()
sock.connect(('localhost', port)) sock.connect(('localhost', port))
data = CREDENTIAL_INFO_REQUEST_JSON data = CREDENTIAL_INFO_REQUEST_JSON
msg = '%s\n%s' % (len(data), data) msg = '%s\n%s' % (len(data), data)
sock.sendall(_to_bytes(msg, encoding='utf-8')) sock.sendall(_to_bytes(msg, encoding='utf-8'))
header = sock.recv(6).decode() header = sock.recv(6).decode()
if '\n' not in header: if '\n' not in header:
raise CommunicationError('saw no newline in the first 6 bytes') raise CommunicationError('saw no newline in the first 6 bytes')
len_str, json_str = header.split('\n', 1) len_str, json_str = header.split('\n', 1)
to_read = int(len_str) - len(json_str) to_read = int(len_str) - len(json_str)
if to_read > 0: if to_read > 0:
json_str += sock.recv(to_read, socket.MSG_WAITALL).decode() json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
return CredentialInfoResponse(json_str) return CredentialInfoResponse(json_str)
class DevshellCredentials(client.GoogleCredentials): class DevshellCredentials(client.GoogleCredentials):
"""Credentials object for Google Developer Shell environment. """Credentials object for Google Developer Shell environment.
This object will allow a Google Developer Shell session to identify its user This object will allow a Google Developer Shell session to identify its user
to Google and other OAuth 2.0 servers that can verify assertions. It can be to Google and other OAuth 2.0 servers that can verify assertions. It can be
@@ -102,8 +100,8 @@ class DevshellCredentials(client.GoogleCredentials):
generate and refresh its own access tokens. generate and refresh its own access tokens.
""" """
def __init__(self, user_agent=None): def __init__(self, user_agent=None):
super(DevshellCredentials, self).__init__( super(DevshellCredentials, self).__init__(
None, # access_token, initialized below None, # access_token, initialized below
None, # client_id None, # client_id
None, # client_secret None, # client_secret
@@ -111,26 +109,26 @@ class DevshellCredentials(client.GoogleCredentials):
None, # token_expiry None, # token_expiry
None, # token_uri None, # token_uri
user_agent) user_agent)
self._refresh(None) self._refresh(None)
def _refresh(self, http_request): def _refresh(self, http_request):
self.devshell_response = _SendRecv() self.devshell_response = _SendRecv()
self.access_token = self.devshell_response.access_token self.access_token = self.devshell_response.access_token
@property @property
def user_email(self): def user_email(self):
return self.devshell_response.user_email return self.devshell_response.user_email
@property @property
def project_id(self): def project_id(self):
return self.devshell_response.project_id return self.devshell_response.project_id
@classmethod @classmethod
def from_json(cls, json_data): def from_json(cls, json_data):
raise NotImplementedError( raise NotImplementedError(
'Cannot load Developer Shell credentials from JSON.') 'Cannot load Developer Shell credentials from JSON.')
@property @property
def serialization_data(self): def serialization_data(self):
raise NotImplementedError( raise NotImplementedError(
'Cannot serialize Developer Shell credentials.') 'Cannot serialize Developer Shell credentials.')

View File

@@ -27,58 +27,59 @@ import pickle
from django.db import models from django.db import models
from oauth2client.client import Storage as BaseStorage from oauth2client.client import Storage as BaseStorage
class CredentialsField(models.Field): class CredentialsField(models.Field):
__metaclass__ = models.SubfieldBase __metaclass__ = models.SubfieldBase
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if 'null' not in kwargs: if 'null' not in kwargs:
kwargs['null'] = True kwargs['null'] = True
super(CredentialsField, self).__init__(*args, **kwargs) super(CredentialsField, self).__init__(*args, **kwargs)
def get_internal_type(self): def get_internal_type(self):
return "TextField" return "TextField"
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
return None return None
if isinstance(value, oauth2client.client.Credentials): if isinstance(value, oauth2client.client.Credentials):
return value return value
return pickle.loads(base64.b64decode(value)) return pickle.loads(base64.b64decode(value))
def get_db_prep_value(self, value, connection, prepared=False): def get_db_prep_value(self, value, connection, prepared=False):
if value is None: if value is None:
return None return None
return base64.b64encode(pickle.dumps(value)) return base64.b64encode(pickle.dumps(value))
class FlowField(models.Field): class FlowField(models.Field):
__metaclass__ = models.SubfieldBase __metaclass__ = models.SubfieldBase
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if 'null' not in kwargs: if 'null' not in kwargs:
kwargs['null'] = True kwargs['null'] = True
super(FlowField, self).__init__(*args, **kwargs) super(FlowField, self).__init__(*args, **kwargs)
def get_internal_type(self): def get_internal_type(self):
return "TextField" return "TextField"
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
return None return None
if isinstance(value, oauth2client.client.Flow): if isinstance(value, oauth2client.client.Flow):
return value return value
return pickle.loads(base64.b64decode(value)) return pickle.loads(base64.b64decode(value))
def get_db_prep_value(self, value, connection, prepared=False): def get_db_prep_value(self, value, connection, prepared=False):
if value is None: if value is None:
return None return None
return base64.b64encode(pickle.dumps(value)) return base64.b64encode(pickle.dumps(value))
class Storage(BaseStorage): class Storage(BaseStorage):
"""Store and retrieve a single credential to and from """Store and retrieve a single credential to and from
the datastore. the datastore.
This Storage helper presumes the Credentials This Storage helper presumes the Credentials
@@ -86,8 +87,8 @@ class Storage(BaseStorage):
on a db model class. on a db model class.
""" """
def __init__(self, model_class, key_name, key_value, property_name): def __init__(self, model_class, key_name, key_value, property_name):
"""Constructor for Storage. """Constructor for Storage.
Args: Args:
model: db.Model, model class model: db.Model, model class
@@ -95,47 +96,47 @@ class Storage(BaseStorage):
key_value: string, key value for the entity that has the credentials key_value: string, key value for the entity that has the credentials
property_name: string, name of the property that is an CredentialsProperty property_name: string, name of the property that is an CredentialsProperty
""" """
self.model_class = model_class self.model_class = model_class
self.key_name = key_name self.key_name = key_name
self.key_value = key_value self.key_value = key_value
self.property_name = property_name self.property_name = property_name
def locked_get(self): def locked_get(self):
"""Retrieve Credential from datastore. """Retrieve Credential from datastore.
Returns: Returns:
oauth2client.Credentials oauth2client.Credentials
""" """
credential = None credential = None
query = {self.key_name: self.key_value} query = {self.key_name: self.key_value}
entities = self.model_class.objects.filter(**query) entities = self.model_class.objects.filter(**query)
if len(entities) > 0: if len(entities) > 0:
credential = getattr(entities[0], self.property_name) credential = getattr(entities[0], self.property_name)
if credential and hasattr(credential, 'set_store'): if credential and hasattr(credential, 'set_store'):
credential.set_store(self) credential.set_store(self)
return credential return credential
def locked_put(self, credentials, overwrite=False): def locked_put(self, credentials, overwrite=False):
"""Write a Credentials to the datastore. """Write a Credentials to the datastore.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
overwrite: Boolean, indicates whether you would like these credentials to overwrite: Boolean, indicates whether you would like these credentials to
overwrite any existing stored credentials. overwrite any existing stored credentials.
""" """
args = {self.key_name: self.key_value} args = {self.key_name: self.key_value}
if overwrite: if overwrite:
entity, unused_is_new = self.model_class.objects.get_or_create(**args) entity, unused_is_new = self.model_class.objects.get_or_create(**args)
else: else:
entity = self.model_class(**args) entity = self.model_class(**args)
setattr(entity, self.property_name, credentials) setattr(entity, self.property_name, credentials)
entity.save() entity.save()
def locked_delete(self): def locked_delete(self):
"""Delete Credentials from the datastore.""" """Delete Credentials from the datastore."""
query = {self.key_name: self.key_value} query = {self.key_name: self.key_value}
entities = self.model_class.objects.filter(**query).delete() entities = self.model_class.objects.filter(**query).delete()

View File

@@ -28,37 +28,37 @@ from oauth2client.client import Storage as BaseStorage
class CredentialsFileSymbolicLinkError(Exception): class CredentialsFileSymbolicLinkError(Exception):
"""Credentials files must not be symbolic links.""" """Credentials files must not be symbolic links."""
class Storage(BaseStorage): class Storage(BaseStorage):
"""Store and retrieve a single credential to and from a file.""" """Store and retrieve a single credential to and from a file."""
def __init__(self, filename): def __init__(self, filename):
self._filename = filename self._filename = filename
self._lock = threading.Lock() self._lock = threading.Lock()
def _validate_file(self): def _validate_file(self):
if os.path.islink(self._filename): if os.path.islink(self._filename):
raise CredentialsFileSymbolicLinkError( raise CredentialsFileSymbolicLinkError(
'File: %s is a symbolic link.' % self._filename) 'File: %s is a symbolic link.' % self._filename)
def acquire_lock(self): def acquire_lock(self):
"""Acquires any lock necessary to access this Storage. """Acquires any lock necessary to access this Storage.
This lock is not reentrant.""" This lock is not reentrant."""
self._lock.acquire() self._lock.acquire()
def release_lock(self): def release_lock(self):
"""Release the Storage lock. """Release the Storage lock.
Trying to release a lock that isn't held will result in a Trying to release a lock that isn't held will result in a
RuntimeError. RuntimeError.
""" """
self._lock.release() self._lock.release()
def locked_get(self): def locked_get(self):
"""Retrieve Credential from file. """Retrieve Credential from file.
Returns: Returns:
oauth2client.client.Credentials oauth2client.client.Credentials
@@ -66,38 +66,38 @@ class Storage(BaseStorage):
Raises: Raises:
CredentialsFileSymbolicLinkError if the file is a symbolic link. CredentialsFileSymbolicLinkError if the file is a symbolic link.
""" """
credentials = None credentials = None
self._validate_file() self._validate_file()
try: try:
f = open(self._filename, 'rb') f = open(self._filename, 'rb')
content = f.read() content = f.read()
f.close() f.close()
except IOError: except IOError:
return credentials return credentials
try: try:
credentials = Credentials.new_from_json(content) credentials = Credentials.new_from_json(content)
credentials.set_store(self) credentials.set_store(self)
except ValueError: except ValueError:
pass pass
return credentials return credentials
def _create_file_if_needed(self): def _create_file_if_needed(self):
"""Create an empty file if necessary. """Create an empty file if necessary.
This method will not initialize the file. Instead it implements a This method will not initialize the file. Instead it implements a
simple version of "touch" to ensure the file has been created. simple version of "touch" to ensure the file has been created.
""" """
if not os.path.exists(self._filename): if not os.path.exists(self._filename):
old_umask = os.umask(0o177) old_umask = os.umask(0o177)
try: try:
open(self._filename, 'a+b').close() open(self._filename, 'a+b').close()
finally: finally:
os.umask(old_umask) os.umask(old_umask)
def locked_put(self, credentials): def locked_put(self, credentials):
"""Write Credentials to file. """Write Credentials to file.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
@@ -106,17 +106,17 @@ class Storage(BaseStorage):
CredentialsFileSymbolicLinkError if the file is a symbolic link. CredentialsFileSymbolicLinkError if the file is a symbolic link.
""" """
self._create_file_if_needed() self._create_file_if_needed()
self._validate_file() self._validate_file()
f = open(self._filename, 'w') f = open(self._filename, 'w')
f.write(credentials.to_json()) f.write(credentials.to_json())
f.close() f.close()
def locked_delete(self): def locked_delete(self):
"""Delete Credentials file. """Delete Credentials file.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
""" """
os.unlink(self._filename) os.unlink(self._filename)

View File

@@ -189,8 +189,7 @@ from oauth2client.client import Storage
from oauth2client import clientsecrets from oauth2client import clientsecrets
from oauth2client import util from oauth2client import util
DEFAULT_SCOPES = ('email', )
DEFAULT_SCOPES = ('email',)
class UserOAuth2(object): class UserOAuth2(object):
@@ -461,6 +460,7 @@ class UserOAuth2(object):
be redirected to the authorization flow. Once complete, the user will be redirected to the authorization flow. Once complete, the user will
be redirected back to the original page. be redirected back to the original page.
""" """
def curry_wrapper(wrapped_function): def curry_wrapper(wrapped_function):
@wraps(wrapped_function) @wraps(wrapped_function)
def required_wrapper(*args, **kwargs): def required_wrapper(*args, **kwargs):
@@ -519,6 +519,7 @@ class FlaskSessionStorage(Storage):
credentials. We strongly recommend using a server-side session credentials. We strongly recommend using a server-side session
implementation. implementation.
""" """
def locked_get(self): def locked_get(self):
serialized = session.get('google_oauth2_credentials') serialized = session.get('google_oauth2_credentials')

View File

@@ -36,7 +36,7 @@ META = ('http://metadata.google.internal/0.1/meta-data/service-accounts/'
class AppAssertionCredentials(AssertionCredentials): class AppAssertionCredentials(AssertionCredentials):
"""Credentials object for Compute Engine Assertion Grants """Credentials object for Compute Engine Assertion Grants
This object will allow a Compute Engine instance to identify itself to This object will allow a Compute Engine instance to identify itself to
Google and other OAuth 2.0 servers that can verify assertions. It can be used Google and other OAuth 2.0 servers that can verify assertions. It can be used
@@ -48,27 +48,27 @@ class AppAssertionCredentials(AssertionCredentials):
generate and refresh its own access tokens. generate and refresh its own access tokens.
""" """
@util.positional(2) @util.positional(2)
def __init__(self, scope, **kwargs): def __init__(self, scope, **kwargs):
"""Constructor for AppAssertionCredentials """Constructor for AppAssertionCredentials
Args: Args:
scope: string or iterable of strings, scope(s) of the credentials being scope: string or iterable of strings, scope(s) of the credentials being
requested. requested.
""" """
self.scope = util.scopes_to_string(scope) self.scope = util.scopes_to_string(scope)
self.kwargs = kwargs self.kwargs = kwargs
# Assertion type is no longer used, but still in the parent class signature. # Assertion type is no longer used, but still in the parent class signature.
super(AppAssertionCredentials, self).__init__(None) super(AppAssertionCredentials, self).__init__(None)
@classmethod @classmethod
def from_json(cls, json_data): def from_json(cls, json_data):
data = json.loads(_from_bytes(json_data)) data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope']) return AppAssertionCredentials(data['scope'])
def _refresh(self, http_request): def _refresh(self, http_request):
"""Refreshes the access_token. """Refreshes the access_token.
Skip all the storage hoops and just refresh using the API. Skip all the storage hoops and just refresh using the API.
@@ -79,29 +79,29 @@ class AppAssertionCredentials(AssertionCredentials):
Raises: Raises:
AccessTokenRefreshError: When the refresh fails. AccessTokenRefreshError: When the refresh fails.
""" """
query = '?scope=%s' % urllib.parse.quote(self.scope, '') query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query) uri = META.replace('{?scope}', query)
response, content = http_request(uri) response, content = http_request(uri)
content = _from_bytes(content) content = _from_bytes(content)
if response.status == 200: if response.status == 200:
try: try:
d = json.loads(content) d = json.loads(content)
except Exception as e: except Exception as e:
raise AccessTokenRefreshError(str(e)) raise AccessTokenRefreshError(str(e))
self.access_token = d['accessToken'] self.access_token = d['accessToken']
else: else:
if response.status == 404: if response.status == 404:
content += (' This can occur if a VM was created' content += (' This can occur if a VM was created'
' with no service account or scopes.') ' with no service account or scopes.')
raise AccessTokenRefreshError(content) raise AccessTokenRefreshError(content)
@property @property
def serialization_data(self): def serialization_data(self):
raise NotImplementedError( raise NotImplementedError(
'Cannot serialize credentials for GCE service accounts.') 'Cannot serialize credentials for GCE service accounts.')
def create_scoped_required(self): def create_scoped_required(self):
return not self.scope return not self.scope
def create_scoped(self, scopes): def create_scoped(self, scopes):
return AppAssertionCredentials(scopes, **self.kwargs) return AppAssertionCredentials(scopes, **self.kwargs)

View File

@@ -28,7 +28,7 @@ from oauth2client.client import Storage as BaseStorage
class Storage(BaseStorage): class Storage(BaseStorage):
"""Store and retrieve a single credential to and from the keyring. """Store and retrieve a single credential to and from the keyring.
To use this module you must have the keyring module installed. See To use this module you must have the keyring module installed. See
<http://pypi.python.org/pypi/keyring/>. This is an optional module and is not <http://pypi.python.org/pypi/keyring/>. This is an optional module and is not
@@ -48,63 +48,63 @@ class Storage(BaseStorage):
""" """
def __init__(self, service_name, user_name): def __init__(self, service_name, user_name):
"""Constructor. """Constructor.
Args: Args:
service_name: string, The name of the service under which the credentials service_name: string, The name of the service under which the credentials
are stored. are stored.
user_name: string, The name of the user to store credentials for. user_name: string, The name of the user to store credentials for.
""" """
self._service_name = service_name self._service_name = service_name
self._user_name = user_name self._user_name = user_name
self._lock = threading.Lock() self._lock = threading.Lock()
def acquire_lock(self): def acquire_lock(self):
"""Acquires any lock necessary to access this Storage. """Acquires any lock necessary to access this Storage.
This lock is not reentrant.""" This lock is not reentrant."""
self._lock.acquire() self._lock.acquire()
def release_lock(self): def release_lock(self):
"""Release the Storage lock. """Release the Storage lock.
Trying to release a lock that isn't held will result in a Trying to release a lock that isn't held will result in a
RuntimeError. RuntimeError.
""" """
self._lock.release() self._lock.release()
def locked_get(self): def locked_get(self):
"""Retrieve Credential from file. """Retrieve Credential from file.
Returns: Returns:
oauth2client.client.Credentials oauth2client.client.Credentials
""" """
credentials = None credentials = None
content = keyring.get_password(self._service_name, self._user_name) content = keyring.get_password(self._service_name, self._user_name)
if content is not None: if content is not None:
try: try:
credentials = Credentials.new_from_json(content) credentials = Credentials.new_from_json(content)
credentials.set_store(self) credentials.set_store(self)
except ValueError: except ValueError:
pass pass
return credentials return credentials
def locked_put(self, credentials): def locked_put(self, credentials):
"""Write Credentials to file. """Write Credentials to file.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
""" """
keyring.set_password(self._service_name, self._user_name, keyring.set_password(self._service_name, self._user_name,
credentials.to_json()) credentials.to_json())
def locked_delete(self): def locked_delete(self):
"""Delete Credentials file. """Delete Credentials file.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
""" """
keyring.set_password(self._service_name, self._user_name, '') keyring.set_password(self._service_name, self._user_name, '')

View File

@@ -45,68 +45,69 @@ logger = logging.getLogger(__name__)
class CredentialsFileSymbolicLinkError(Exception): class CredentialsFileSymbolicLinkError(Exception):
"""Credentials files must not be symbolic links.""" """Credentials files must not be symbolic links."""
class AlreadyLockedException(Exception): class AlreadyLockedException(Exception):
"""Trying to lock a file that has already been locked by the LockedFile.""" """Trying to lock a file that has already been locked by the LockedFile."""
pass pass
def validate_file(filename): def validate_file(filename):
if os.path.islink(filename): if os.path.islink(filename):
raise CredentialsFileSymbolicLinkError( raise CredentialsFileSymbolicLinkError(
'File: %s is a symbolic link.' % filename) 'File: %s is a symbolic link.' % filename)
class _Opener(object):
"""Base class for different locking primitives."""
def __init__(self, filename, mode, fallback_mode): class _Opener(object):
"""Create an Opener. """Base class for different locking primitives."""
def __init__(self, filename, mode, fallback_mode):
"""Create an Opener.
Args: Args:
filename: string, The pathname of the file. filename: string, The pathname of the file.
mode: string, The preferred mode to access the file with. mode: string, The preferred mode to access the file with.
fallback_mode: string, The mode to use if locking fails. fallback_mode: string, The mode to use if locking fails.
""" """
self._locked = False self._locked = False
self._filename = filename self._filename = filename
self._mode = mode self._mode = mode
self._fallback_mode = fallback_mode self._fallback_mode = fallback_mode
self._fh = None self._fh = None
self._lock_fd = None self._lock_fd = None
def is_locked(self): def is_locked(self):
"""Was the file locked.""" """Was the file locked."""
return self._locked return self._locked
def file_handle(self): def file_handle(self):
"""The file handle to the file. Valid only after opened.""" """The file handle to the file. Valid only after opened."""
return self._fh return self._fh
def filename(self): def filename(self):
"""The filename that is being locked.""" """The filename that is being locked."""
return self._filename return self._filename
def open_and_lock(self, timeout, delay): def open_and_lock(self, timeout, delay):
"""Open the file and lock it. """Open the file and lock it.
Args: Args:
timeout: float, How long to try to lock for. timeout: float, How long to try to lock for.
delay: float, How long to wait between retries. delay: float, How long to wait between retries.
""" """
pass pass
def unlock_and_close(self): def unlock_and_close(self):
"""Unlock and close the file.""" """Unlock and close the file."""
pass pass
class _PosixOpener(_Opener): class _PosixOpener(_Opener):
"""Lock files using Posix advisory lock files.""" """Lock files using Posix advisory lock files."""
def open_and_lock(self, timeout, delay): def open_and_lock(self, timeout, delay):
"""Open the file and lock it. """Open the file and lock it.
Tries to create a .lock file next to the file we're trying to open. Tries to create a .lock file next to the file we're trying to open.
@@ -119,141 +120,67 @@ class _PosixOpener(_Opener):
IOError: if the open fails. IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link. CredentialsFileSymbolicLinkError if the file is a symbolic link.
""" """
if self._locked: if self._locked:
raise AlreadyLockedException('File %s is already locked' % raise AlreadyLockedException('File %s is already locked' %
self._filename) self._filename)
self._locked = False self._locked = False
validate_file(self._filename) validate_file(self._filename)
try: try:
self._fh = open(self._filename, self._mode) self._fh = open(self._filename, self._mode)
except IOError as e: except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock. # If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES: if e.errno == errno.EACCES:
self._fh = open(self._filename, self._fallback_mode) self._fh = open(self._filename, self._fallback_mode)
return return
lock_filename = self._posix_lockfile(self._filename) lock_filename = self._posix_lockfile(self._filename)
start_time = time.time() start_time = time.time()
while True: while True:
try: try:
self._lock_fd = os.open(lock_filename, self._lock_fd = os.open(lock_filename,
os.O_CREAT|os.O_EXCL|os.O_RDWR) os.O_CREAT | os.O_EXCL | os.O_RDWR)
self._locked = True self._locked = True
break break
except OSError as e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if (time.time() - start_time) >= timeout: if (time.time() - start_time) >= timeout:
logger.warn('Could not acquire lock %s in %s seconds', logger.warn('Could not acquire lock %s in %s seconds',
lock_filename, timeout) lock_filename, timeout)
# Close the file and open in fallback_mode. # Close the file and open in fallback_mode.
if self._fh: if self._fh:
self._fh.close() self._fh.close()
self._fh = open(self._filename, self._fallback_mode) self._fh = open(self._filename, self._fallback_mode)
return return
time.sleep(delay) time.sleep(delay)
def unlock_and_close(self):
"""Unlock a file by removing the .lock file, and close the handle."""
if self._locked:
lock_filename = self._posix_lockfile(self._filename)
os.close(self._lock_fd)
os.unlink(lock_filename)
self._locked = False
self._lock_fd = None
if self._fh:
self._fh.close()
def _posix_lockfile(self, filename):
"""The name of the lock file to use for posix locking."""
return '%s.lock' % filename
try:
import fcntl
class _FcntlOpener(_Opener):
"""Open, lock, and unlock a file using fcntl.lockf."""
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
delay: float, How long to wait between retries
Raises:
AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
self._filename)
start_time = time.time()
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno in (errno.EPERM, errno.EACCES):
self._fh = open(self._filename, self._fallback_mode)
return
# We opened in _mode, try to lock the file.
while True:
try:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
self._locked = True
return
except IOError as e:
# If not retrying, then just pass on the error.
if timeout == 0:
raise
if e.errno != errno.EACCES:
raise
# We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds',
self._filename, timeout)
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
def unlock_and_close(self): def unlock_and_close(self):
"""Close and unlock the file using the fcntl.lockf primitive.""" """Unlock a file by removing the .lock file, and close the handle."""
if self._locked: if self._locked:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN) lock_filename = self._posix_lockfile(self._filename)
self._locked = False os.close(self._lock_fd)
if self._fh: os.unlink(lock_filename)
self._fh.close() self._locked = False
except ImportError: self._lock_fd = None
_FcntlOpener = None if self._fh:
self._fh.close()
def _posix_lockfile(self, filename):
"""The name of the lock file to use for posix locking."""
return '%s.lock' % filename
try: try:
import pywintypes import fcntl
import win32con
import win32file
class _Win32Opener(_Opener):
"""Open, lock, and unlock a file using windows primitives."""
# Error #33: class _FcntlOpener(_Opener):
# 'The process cannot access the file because another process' """Open, lock, and unlock a file using fcntl.lockf."""
FILE_IN_USE_ERROR = 33
# Error #158: def open_and_lock(self, timeout, delay):
# 'The segment is already unlocked.' """Open the file and lock it.
FILE_ALREADY_UNLOCKED_ERROR = 158
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args: Args:
timeout: float, How long to try to lock for. timeout: float, How long to try to lock for.
@@ -264,71 +191,147 @@ try:
IOError: if the open fails. IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link. CredentialsFileSymbolicLinkError if the file is a symbolic link.
""" """
if self._locked: if self._locked:
raise AlreadyLockedException('File %s is already locked' % raise AlreadyLockedException('File %s is already locked' %
self._filename) self._filename)
start_time = time.time() start_time = time.time()
validate_file(self._filename) validate_file(self._filename)
try: try:
self._fh = open(self._filename, self._mode) self._fh = open(self._filename, self._mode)
except IOError as e: except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock. # If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES: if e.errno in (errno.EPERM, errno.EACCES):
self._fh = open(self._filename, self._fallback_mode) self._fh = open(self._filename, self._fallback_mode)
return return
# We opened in _mode, try to lock the file. # We opened in _mode, try to lock the file.
while True: while True:
try: try:
hfile = win32file._get_osfhandle(self._fh.fileno()) fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
win32file.LockFileEx( self._locked = True
return
except IOError as e:
# If not retrying, then just pass on the error.
if timeout == 0:
raise
if e.errno != errno.EACCES:
raise
# We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds',
self._filename, timeout)
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
def unlock_and_close(self):
"""Close and unlock the file using the fcntl.lockf primitive."""
if self._locked:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
self._locked = False
if self._fh:
self._fh.close()
except ImportError:
_FcntlOpener = None
try:
import pywintypes
import win32con
import win32file
class _Win32Opener(_Opener):
"""Open, lock, and unlock a file using windows primitives."""
# Error #33:
# 'The process cannot access the file because another process'
FILE_IN_USE_ERROR = 33
# Error #158:
# 'The segment is already unlocked.'
FILE_ALREADY_UNLOCKED_ERROR = 158
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
delay: float, How long to wait between retries
Raises:
AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
self._filename)
start_time = time.time()
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES:
self._fh = open(self._filename, self._fallback_mode)
return
# We opened in _mode, try to lock the file.
while True:
try:
hfile = win32file._get_osfhandle(self._fh.fileno())
win32file.LockFileEx(
hfile, hfile,
(win32con.LOCKFILE_FAIL_IMMEDIATELY| (win32con.LOCKFILE_FAIL_IMMEDIATELY |
win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000, win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000,
pywintypes.OVERLAPPED()) pywintypes.OVERLAPPED())
self._locked = True self._locked = True
return return
except pywintypes.error as e: except pywintypes.error as e:
if timeout == 0: if timeout == 0:
raise raise
# If the error is not that the file is already in use, raise. # If the error is not that the file is already in use, raise.
if e[0] != _Win32Opener.FILE_IN_USE_ERROR: if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
raise raise
# We could not acquire the lock. Try again. # We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout: if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds' % ( logger.warn('Could not lock %s in %s seconds' % (
self._filename, timeout)) self._filename, timeout))
if self._fh: if self._fh:
self._fh.close() self._fh.close()
self._fh = open(self._filename, self._fallback_mode) self._fh = open(self._filename, self._fallback_mode)
return return
time.sleep(delay) time.sleep(delay)
def unlock_and_close(self): def unlock_and_close(self):
"""Close and unlock the file using the win32 primitive.""" """Close and unlock the file using the win32 primitive."""
if self._locked: if self._locked:
try: try:
hfile = win32file._get_osfhandle(self._fh.fileno()) hfile = win32file._get_osfhandle(self._fh.fileno())
win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED()) win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
except pywintypes.error as e: except pywintypes.error as e:
if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR: if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
raise raise
self._locked = False self._locked = False
if self._fh: if self._fh:
self._fh.close() self._fh.close()
except ImportError: except ImportError:
_Win32Opener = None _Win32Opener = None
class LockedFile(object): class LockedFile(object):
"""Represent a file that has exclusive access.""" """Represent a file that has exclusive access."""
@util.positional(4) @util.positional(4)
def __init__(self, filename, mode, fallback_mode, use_native_locking=True): def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
"""Construct a LockedFile. """Construct a LockedFile.
Args: Args:
filename: string, The path of the file to open. filename: string, The path of the file to open.
@@ -336,32 +339,32 @@ class LockedFile(object):
fallback_mode: string, The mode to use if locking fails. fallback_mode: string, The mode to use if locking fails.
use_native_locking: bool, Whether or not fcntl/win32 locking is used. use_native_locking: bool, Whether or not fcntl/win32 locking is used.
""" """
opener = None opener = None
if not opener and use_native_locking: if not opener and use_native_locking:
if _Win32Opener: if _Win32Opener:
opener = _Win32Opener(filename, mode, fallback_mode) opener = _Win32Opener(filename, mode, fallback_mode)
if _FcntlOpener: if _FcntlOpener:
opener = _FcntlOpener(filename, mode, fallback_mode) opener = _FcntlOpener(filename, mode, fallback_mode)
if not opener: if not opener:
opener = _PosixOpener(filename, mode, fallback_mode) opener = _PosixOpener(filename, mode, fallback_mode)
self._opener = opener self._opener = opener
def filename(self): def filename(self):
"""Return the filename we were constructed with.""" """Return the filename we were constructed with."""
return self._opener._filename return self._opener._filename
def file_handle(self): def file_handle(self):
"""Return the file_handle to the opened file.""" """Return the file_handle to the opened file."""
return self._opener.file_handle() return self._opener.file_handle()
def is_locked(self): def is_locked(self):
"""Return whether we successfully locked the file.""" """Return whether we successfully locked the file."""
return self._opener.is_locked() return self._opener.is_locked()
def open_and_lock(self, timeout=0, delay=0.05): def open_and_lock(self, timeout=0, delay=0.05):
"""Open the file, trying to lock it. """Open the file, trying to lock it.
Args: Args:
timeout: float, The number of seconds to try to acquire the lock. timeout: float, The number of seconds to try to acquire the lock.
@@ -371,8 +374,8 @@ class LockedFile(object):
AlreadyLockedException: if the lock is already acquired. AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails. IOError: if the open fails.
""" """
self._opener.open_and_lock(timeout, delay) self._opener.open_and_lock(timeout, delay)
def unlock_and_close(self): def unlock_and_close(self):
"""Unlock and close a file.""" """Unlock and close a file."""
self._opener.unlock_and_close() self._opener.unlock_and_close()

View File

@@ -65,17 +65,17 @@ _multistores_lock = threading.Lock()
class Error(Exception): class Error(Exception):
"""Base error for this module.""" """Base error for this module."""
class NewerCredentialStoreError(Error): class NewerCredentialStoreError(Error):
"""The credential store is a newer version than supported.""" """The credential store is a newer version than supported."""
@util.positional(4) @util.positional(4)
def get_credential_storage(filename, client_id, user_agent, scope, def get_credential_storage(filename, client_id, user_agent, scope,
warn_on_readonly=True): warn_on_readonly=True):
"""Get a Storage instance for a credential. """Get a Storage instance for a credential.
Args: Args:
filename: The JSON file storing a set of credentials filename: The JSON file storing a set of credentials
@@ -88,17 +88,17 @@ def get_credential_storage(filename, client_id, user_agent, scope,
An object derived from client.Storage for getting/setting the An object derived from client.Storage for getting/setting the
credential. credential.
""" """
# Recreate the legacy key with these specific parameters # Recreate the legacy key with these specific parameters
key = {'clientId': client_id, 'userAgent': user_agent, key = {'clientId': client_id, 'userAgent': user_agent,
'scope': util.scopes_to_string(scope)} 'scope': util.scopes_to_string(scope)}
return get_credential_storage_custom_key( return get_credential_storage_custom_key(
filename, key, warn_on_readonly=warn_on_readonly) filename, key, warn_on_readonly=warn_on_readonly)
@util.positional(2) @util.positional(2)
def get_credential_storage_custom_string_key( def get_credential_storage_custom_string_key(
filename, key_string, warn_on_readonly=True): filename, key_string, warn_on_readonly=True):
"""Get a Storage instance for a credential using a single string as a key. """Get a Storage instance for a credential using a single string as a key.
Allows you to provide a string as a custom key that will be used for Allows you to provide a string as a custom key that will be used for
credential storage and retrieval. credential storage and retrieval.
@@ -112,16 +112,16 @@ def get_credential_storage_custom_string_key(
An object derived from client.Storage for getting/setting the An object derived from client.Storage for getting/setting the
credential. credential.
""" """
# Create a key dictionary that can be used # Create a key dictionary that can be used
key_dict = {'key': key_string} key_dict = {'key': key_string}
return get_credential_storage_custom_key( return get_credential_storage_custom_key(
filename, key_dict, warn_on_readonly=warn_on_readonly) filename, key_dict, warn_on_readonly=warn_on_readonly)
@util.positional(2) @util.positional(2)
def get_credential_storage_custom_key( def get_credential_storage_custom_key(
filename, key_dict, warn_on_readonly=True): filename, key_dict, warn_on_readonly=True):
"""Get a Storage instance for a credential using a dictionary as a key. """Get a Storage instance for a credential using a dictionary as a key.
Allows you to provide a dictionary as a custom key that will be used for Allows you to provide a dictionary as a custom key that will be used for
credential storage and retrieval. credential storage and retrieval.
@@ -137,14 +137,14 @@ def get_credential_storage_custom_key(
An object derived from client.Storage for getting/setting the An object derived from client.Storage for getting/setting the
credential. credential.
""" """
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
key = util.dict_to_tuple_key(key_dict) key = util.dict_to_tuple_key(key_dict)
return multistore._get_storage(key) return multistore._get_storage(key)
@util.positional(1) @util.positional(1)
def get_all_credential_keys(filename, warn_on_readonly=True): def get_all_credential_keys(filename, warn_on_readonly=True):
"""Gets all the registered credential keys in the given Multistore. """Gets all the registered credential keys in the given Multistore.
Args: Args:
filename: The JSON file storing a set of credentials filename: The JSON file storing a set of credentials
@@ -155,17 +155,17 @@ def get_all_credential_keys(filename, warn_on_readonly=True):
dictionaries that can be passed into get_credential_storage_custom_key to dictionaries that can be passed into get_credential_storage_custom_key to
get the actual credentials. get the actual credentials.
""" """
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
multistore._lock() multistore._lock()
try: try:
return multistore._get_all_credential_keys() return multistore._get_all_credential_keys()
finally: finally:
multistore._unlock() multistore._unlock()
@util.positional(1) @util.positional(1)
def _get_multistore(filename, warn_on_readonly=True): def _get_multistore(filename, warn_on_readonly=True):
"""A helper method to initialize the multistore with proper locking. """A helper method to initialize the multistore with proper locking.
Args: Args:
filename: The JSON file storing a set of credentials filename: The JSON file storing a set of credentials
@@ -174,176 +174,176 @@ def _get_multistore(filename, warn_on_readonly=True):
Returns: Returns:
A multistore object A multistore object
""" """
filename = os.path.expanduser(filename) filename = os.path.expanduser(filename)
_multistores_lock.acquire() _multistores_lock.acquire()
try: try:
multistore = _multistores.setdefault( multistore = _multistores.setdefault(
filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly)) filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly))
finally: finally:
_multistores_lock.release() _multistores_lock.release()
return multistore return multistore
class _MultiStore(object): class _MultiStore(object):
"""A file backed store for multiple credentials.""" """A file backed store for multiple credentials."""
@util.positional(2) @util.positional(2)
def __init__(self, filename, warn_on_readonly=True): def __init__(self, filename, warn_on_readonly=True):
"""Initialize the class. """Initialize the class.
This will create the file if necessary. This will create the file if necessary.
""" """
self._file = LockedFile(filename, 'r+', 'r') self._file = LockedFile(filename, 'r+', 'r')
self._thread_lock = threading.Lock() self._thread_lock = threading.Lock()
self._read_only = False self._read_only = False
self._warn_on_readonly = warn_on_readonly self._warn_on_readonly = warn_on_readonly
self._create_file_if_needed() self._create_file_if_needed()
# Cache of deserialized store. This is only valid after the # Cache of deserialized store. This is only valid after the
# _MultiStore is locked or _refresh_data_cache is called. This is # _MultiStore is locked or _refresh_data_cache is called. This is
# of the form of: # of the form of:
# #
# ((key, value), (key, value)...) -> OAuth2Credential # ((key, value), (key, value)...) -> OAuth2Credential
# #
# If this is None, then the store hasn't been read yet. # If this is None, then the store hasn't been read yet.
self._data = None self._data = None
class _Storage(BaseStorage): class _Storage(BaseStorage):
"""A Storage object that knows how to read/write a single credential.""" """A Storage object that knows how to read/write a single credential."""
def __init__(self, multistore, key): def __init__(self, multistore, key):
self._multistore = multistore self._multistore = multistore
self._key = key self._key = key
def acquire_lock(self): def acquire_lock(self):
"""Acquires any lock necessary to access this Storage. """Acquires any lock necessary to access this Storage.
This lock is not reentrant. This lock is not reentrant.
""" """
self._multistore._lock() self._multistore._lock()
def release_lock(self): def release_lock(self):
"""Release the Storage lock. """Release the Storage lock.
Trying to release a lock that isn't held will result in a Trying to release a lock that isn't held will result in a
RuntimeError. RuntimeError.
""" """
self._multistore._unlock() self._multistore._unlock()
def locked_get(self): def locked_get(self):
"""Retrieve credential. """Retrieve credential.
The Storage lock must be held when this is called. The Storage lock must be held when this is called.
Returns: Returns:
oauth2client.client.Credentials oauth2client.client.Credentials
""" """
credential = self._multistore._get_credential(self._key) credential = self._multistore._get_credential(self._key)
if credential: if credential:
credential.set_store(self) credential.set_store(self)
return credential return credential
def locked_put(self, credentials): def locked_put(self, credentials):
"""Write a credential. """Write a credential.
The Storage lock must be held when this is called. The Storage lock must be held when this is called.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
""" """
self._multistore._update_credential(self._key, credentials) self._multistore._update_credential(self._key, credentials)
def locked_delete(self): def locked_delete(self):
"""Delete a credential. """Delete a credential.
The Storage lock must be held when this is called. The Storage lock must be held when this is called.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
""" """
self._multistore._delete_credential(self._key) self._multistore._delete_credential(self._key)
def _create_file_if_needed(self): def _create_file_if_needed(self):
"""Create an empty file if necessary. """Create an empty file if necessary.
This method will not initialize the file. Instead it implements a This method will not initialize the file. Instead it implements a
simple version of "touch" to ensure the file has been created. simple version of "touch" to ensure the file has been created.
""" """
if not os.path.exists(self._file.filename()): if not os.path.exists(self._file.filename()):
old_umask = os.umask(0o177) old_umask = os.umask(0o177)
try: try:
open(self._file.filename(), 'a+b').close() open(self._file.filename(), 'a+b').close()
finally: finally:
os.umask(old_umask) os.umask(old_umask)
def _lock(self): def _lock(self):
"""Lock the entire multistore.""" """Lock the entire multistore."""
self._thread_lock.acquire() self._thread_lock.acquire()
try: try:
self._file.open_and_lock() self._file.open_and_lock()
except IOError as e: except IOError as e:
if e.errno == errno.ENOSYS: if e.errno == errno.ENOSYS:
logger.warn('File system does not support locking the credentials ' logger.warn('File system does not support locking the credentials '
'file.') 'file.')
elif e.errno == errno.ENOLCK: elif e.errno == errno.ENOLCK:
logger.warn('File system is out of resources for writing the ' logger.warn('File system is out of resources for writing the '
'credentials file (is your disk full?).') 'credentials file (is your disk full?).')
else: else:
raise raise
if not self._file.is_locked(): if not self._file.is_locked():
self._read_only = True self._read_only = True
if self._warn_on_readonly: if self._warn_on_readonly:
logger.warn('The credentials file (%s) is not writable. Opening in ' logger.warn('The credentials file (%s) is not writable. Opening in '
'read-only mode. Any refreshed credentials will only be ' 'read-only mode. Any refreshed credentials will only be '
'valid for this run.', self._file.filename()) 'valid for this run.', self._file.filename())
if os.path.getsize(self._file.filename()) == 0: if os.path.getsize(self._file.filename()) == 0:
logger.debug('Initializing empty multistore file') logger.debug('Initializing empty multistore file')
# The multistore is empty so write out an empty file. # The multistore is empty so write out an empty file.
self._data = {} self._data = {}
self._write() self._write()
elif not self._read_only or self._data is None: elif not self._read_only or self._data is None:
# Only refresh the data if we are read/write or we haven't # Only refresh the data if we are read/write or we haven't
# cached the data yet. If we are readonly, we assume is isn't # cached the data yet. If we are readonly, we assume is isn't
# changing out from under us and that we only have to read it # changing out from under us and that we only have to read it
# once. This prevents us from whacking any new access keys that # once. This prevents us from whacking any new access keys that
# we have cached in memory but were unable to write out. # we have cached in memory but were unable to write out.
self._refresh_data_cache() self._refresh_data_cache()
def _unlock(self): def _unlock(self):
"""Release the lock on the multistore.""" """Release the lock on the multistore."""
self._file.unlock_and_close() self._file.unlock_and_close()
self._thread_lock.release() self._thread_lock.release()
def _locked_json_read(self): def _locked_json_read(self):
"""Get the raw content of the multistore file. """Get the raw content of the multistore file.
The multistore must be locked when this is called. The multistore must be locked when this is called.
Returns: Returns:
The contents of the multistore decoded as JSON. The contents of the multistore decoded as JSON.
""" """
assert self._thread_lock.locked() assert self._thread_lock.locked()
self._file.file_handle().seek(0) self._file.file_handle().seek(0)
return json.load(self._file.file_handle()) return json.load(self._file.file_handle())
def _locked_json_write(self, data): def _locked_json_write(self, data):
"""Write a JSON serializable data structure to the multistore. """Write a JSON serializable data structure to the multistore.
The multistore must be locked when this is called. The multistore must be locked when this is called.
Args: Args:
data: The data to be serialized and written. data: The data to be serialized and written.
""" """
assert self._thread_lock.locked() assert self._thread_lock.locked()
if self._read_only: if self._read_only:
return return
self._file.file_handle().seek(0) self._file.file_handle().seek(0)
json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': ')) json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
self._file.file_handle().truncate() self._file.file_handle().truncate()
def _refresh_data_cache(self): def _refresh_data_cache(self):
"""Refresh the contents of the multistore. """Refresh the contents of the multistore.
The multistore must be locked when this is called. The multistore must be locked when this is called.
@@ -351,41 +351,41 @@ class _MultiStore(object):
NewerCredentialStoreError: Raised when a newer client has written the NewerCredentialStoreError: Raised when a newer client has written the
store. store.
""" """
self._data = {} self._data = {}
try: try:
raw_data = self._locked_json_read() raw_data = self._locked_json_read()
except Exception: except Exception:
logger.warn('Credential data store could not be loaded. ' logger.warn('Credential data store could not be loaded. '
'Will ignore and overwrite.') 'Will ignore and overwrite.')
return return
version = 0 version = 0
try: try:
version = raw_data['file_version'] version = raw_data['file_version']
except Exception: except Exception:
logger.warn('Missing version for credential data store. It may be ' logger.warn('Missing version for credential data store. It may be '
'corrupt or an old version. Overwriting.') 'corrupt or an old version. Overwriting.')
if version > 1: if version > 1:
raise NewerCredentialStoreError( raise NewerCredentialStoreError(
'Credential file has file_version of %d. ' 'Credential file has file_version of %d. '
'Only file_version of 1 is supported.' % version) 'Only file_version of 1 is supported.' % version)
credentials = [] credentials = []
try: try:
credentials = raw_data['data'] credentials = raw_data['data']
except (TypeError, KeyError): except (TypeError, KeyError):
pass pass
for cred_entry in credentials: for cred_entry in credentials:
try: try:
(key, credential) = self._decode_credential_from_json(cred_entry) (key, credential) = self._decode_credential_from_json(cred_entry)
self._data[key] = credential self._data[key] = credential
except: except:
# If something goes wrong loading a credential, just ignore it # If something goes wrong loading a credential, just ignore it
logger.info('Error decoding credential, skipping', exc_info=True) logger.info('Error decoding credential, skipping', exc_info=True)
def _decode_credential_from_json(self, cred_entry): def _decode_credential_from_json(self, cred_entry):
"""Load a credential from our JSON serialization. """Load a credential from our JSON serialization.
Args: Args:
cred_entry: A dict entry from the data member of our format cred_entry: A dict entry from the data member of our format
@@ -394,36 +394,36 @@ class _MultiStore(object):
(key, cred) where the key is the key tuple and the cred is the (key, cred) where the key is the key tuple and the cred is the
OAuth2Credential object. OAuth2Credential object.
""" """
raw_key = cred_entry['key'] raw_key = cred_entry['key']
key = util.dict_to_tuple_key(raw_key) key = util.dict_to_tuple_key(raw_key)
credential = None credential = None
credential = Credentials.new_from_json(json.dumps(cred_entry['credential'])) credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
return (key, credential) return (key, credential)
def _write(self): def _write(self):
"""Write the cached data back out. """Write the cached data back out.
The multistore must be locked. The multistore must be locked.
""" """
raw_data = {'file_version': 1} raw_data = {'file_version': 1}
raw_creds = [] raw_creds = []
raw_data['data'] = raw_creds raw_data['data'] = raw_creds
for (cred_key, cred) in self._data.items(): for (cred_key, cred) in self._data.items():
raw_key = dict(cred_key) raw_key = dict(cred_key)
raw_cred = json.loads(cred.to_json()) raw_cred = json.loads(cred.to_json())
raw_creds.append({'key': raw_key, 'credential': raw_cred}) raw_creds.append({'key': raw_key, 'credential': raw_cred})
self._locked_json_write(raw_data) self._locked_json_write(raw_data)
def _get_all_credential_keys(self): def _get_all_credential_keys(self):
"""Gets all the registered credential keys in the multistore. """Gets all the registered credential keys in the multistore.
Returns: Returns:
A list of dictionaries corresponding to all the keys currently registered A list of dictionaries corresponding to all the keys currently registered
""" """
return [dict(key) for key in self._data.keys()] return [dict(key) for key in self._data.keys()]
def _get_credential(self, key): def _get_credential(self, key):
"""Get a credential from the multistore. """Get a credential from the multistore.
The multistore must be locked. The multistore must be locked.
@@ -433,10 +433,10 @@ class _MultiStore(object):
Returns: Returns:
The credential specified or None if not present The credential specified or None if not present
""" """
return self._data.get(key, None) return self._data.get(key, None)
def _update_credential(self, key, cred): def _update_credential(self, key, cred):
"""Update a credential and write the multistore. """Update a credential and write the multistore.
This must be called when the multistore is locked. This must be called when the multistore is locked.
@@ -444,25 +444,25 @@ class _MultiStore(object):
key: The key used to retrieve the credential key: The key used to retrieve the credential
cred: The OAuth2Credential to update/set cred: The OAuth2Credential to update/set
""" """
self._data[key] = cred self._data[key] = cred
self._write() self._write()
def _delete_credential(self, key): def _delete_credential(self, key):
"""Delete a credential and write the multistore. """Delete a credential and write the multistore.
This must be called when the multistore is locked. This must be called when the multistore is locked.
Args: Args:
key: The key used to retrieve the credential key: The key used to retrieve the credential
""" """
try: try:
del self._data[key] del self._data[key]
except KeyError: except KeyError:
pass pass
self._write() self._write()
def _get_storage(self, key): def _get_storage(self, key):
"""Get a Storage object to get/set a credential. """Get a Storage object to get/set a credential.
This Storage is a 'view' into the multistore. This Storage is a 'view' into the multistore.
@@ -472,4 +472,4 @@ class _MultiStore(object):
Returns: Returns:
A Storage object that can be used to get/set this cred A Storage object that can be used to get/set this cred
""" """
return self._Storage(self, key) return self._Storage(self, key)

View File

@@ -30,7 +30,6 @@ from oauth2client import util
from oauth2client.tools import ClientRedirectHandler from oauth2client.tools import ClientRedirectHandler
from oauth2client.tools import ClientRedirectServer from oauth2client.tools import ClientRedirectServer
FLAGS = gflags.FLAGS FLAGS = gflags.FLAGS
gflags.DEFINE_boolean('auth_local_webserver', True, gflags.DEFINE_boolean('auth_local_webserver', True,
@@ -48,7 +47,7 @@ gflags.DEFINE_multi_int('auth_host_port', [8080, 8090],
@util.positional(2) @util.positional(2)
def run(flow, storage, http=None): def run(flow, storage, http=None):
"""Core code for a command-line application. """Core code for a command-line application.
The ``run()`` function is called from your application and runs The ``run()`` function is called from your application and runs
through all the steps to obtain credentials. It takes a ``Flow`` through all the steps to obtain credentials. It takes a ``Flow``
@@ -86,76 +85,76 @@ def run(flow, storage, http=None):
Returns: Returns:
Credentials, the obtained credential. Credentials, the obtained credential.
""" """
logging.warning('This function, oauth2client.tools.run(), and the use of ' logging.warning('This function, oauth2client.tools.run(), and the use of '
'the gflags library are deprecated and will be removed in a future ' 'the gflags library are deprecated and will be removed in a future '
'version of the library.') 'version of the library.')
if FLAGS.auth_local_webserver: if FLAGS.auth_local_webserver:
success = False success = False
port_number = 0 port_number = 0
for port in FLAGS.auth_host_port: for port in FLAGS.auth_host_port:
port_number = port port_number = port
try: try:
httpd = ClientRedirectServer((FLAGS.auth_host_name, port), httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
ClientRedirectHandler) ClientRedirectHandler)
except socket.error as e: except socket.error as e:
pass pass
else: else:
success = True success = True
break break
FLAGS.auth_local_webserver = success FLAGS.auth_local_webserver = success
if not success: if not success:
print('Failed to start a local webserver listening on either port 8080') print('Failed to start a local webserver listening on either port 8080')
print('or port 9090. Please check your firewall settings and locally') print('or port 9090. Please check your firewall settings and locally')
print('running programs that may be blocking or using those ports.') print('running programs that may be blocking or using those ports.')
print() print()
print('Falling back to --noauth_local_webserver and continuing with') print('Falling back to --noauth_local_webserver and continuing with')
print('authorization.') print('authorization.')
print() print()
if FLAGS.auth_local_webserver: if FLAGS.auth_local_webserver:
oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number) oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
else:
oauth_callback = client.OOB_CALLBACK_URN
flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
if FLAGS.auth_local_webserver:
webbrowser.open(authorize_url, new=1, autoraise=True)
print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run')
print('this application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
code = None
if FLAGS.auth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else: else:
print('Failed to find "code" in the query parameters of the redirect.') oauth_callback = client.OOB_CALLBACK_URN
sys.exit('Try running with --noauth_local_webserver.') flow.redirect_uri = oauth_callback
else: authorize_url = flow.step1_get_authorize_url()
code = input('Enter verification code: ').strip()
try: if FLAGS.auth_local_webserver:
credential = flow.step2_exchange(code, http=http) webbrowser.open(authorize_url, new=1, autoraise=True)
except client.FlowExchangeError as e: print('Your browser has been opened to visit:')
sys.exit('Authentication has failed: %s' % e) print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run')
print('this application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
storage.put(credential) code = None
credential.set_store(storage) if FLAGS.auth_local_webserver:
print('Authentication successful.') httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else:
print('Failed to find "code" in the query parameters of the redirect.')
sys.exit('Try running with --noauth_local_webserver.')
else:
code = input('Enter verification code: ').strip()
return credential try:
credential = flow.step2_exchange(code, http=http)
except client.FlowExchangeError as e:
sys.exit('Authentication has failed: %s' % e)
storage.put(credential)
credential.set_store(storage)
print('Authentication successful.')
return credential

View File

@@ -34,83 +34,83 @@ from oauth2client.client import AssertionCredentials
class _ServiceAccountCredentials(AssertionCredentials): class _ServiceAccountCredentials(AssertionCredentials):
"""Class representing a service account (signed JWT) credential.""" """Class representing a service account (signed JWT) credential."""
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
def __init__(self, service_account_id, service_account_email, private_key_id, def __init__(self, service_account_id, service_account_email, private_key_id,
private_key_pkcs8_text, scopes, user_agent=None, private_key_pkcs8_text, scopes, user_agent=None,
token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI, token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI,
**kwargs): **kwargs):
super(_ServiceAccountCredentials, self).__init__( super(_ServiceAccountCredentials, self).__init__(
None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri) None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri)
self._service_account_id = service_account_id self._service_account_id = service_account_id
self._service_account_email = service_account_email self._service_account_email = service_account_email
self._private_key_id = private_key_id self._private_key_id = private_key_id
self._private_key = _get_private_key(private_key_pkcs8_text) self._private_key = _get_private_key(private_key_pkcs8_text)
self._private_key_pkcs8_text = private_key_pkcs8_text self._private_key_pkcs8_text = private_key_pkcs8_text
self._scopes = util.scopes_to_string(scopes) self._scopes = util.scopes_to_string(scopes)
self._user_agent = user_agent self._user_agent = user_agent
self._token_uri = token_uri self._token_uri = token_uri
self._revoke_uri = revoke_uri self._revoke_uri = revoke_uri
self._kwargs = kwargs self._kwargs = kwargs
def _generate_assertion(self): def _generate_assertion(self):
"""Generate the assertion that will be used in the request.""" """Generate the assertion that will be used in the request."""
header = { header = {
'alg': 'RS256', 'alg': 'RS256',
'typ': 'JWT', 'typ': 'JWT',
'kid': self._private_key_id 'kid': self._private_key_id
} }
now = int(time.time()) now = int(time.time())
payload = { payload = {
'aud': self._token_uri, 'aud': self._token_uri,
'scope': self._scopes, 'scope': self._scopes,
'iat': now, 'iat': now,
'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS, 'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS,
'iss': self._service_account_email 'iss': self._service_account_email
} }
payload.update(self._kwargs) payload.update(self._kwargs)
first_segment = _urlsafe_b64encode(_json_encode(header)) first_segment = _urlsafe_b64encode(_json_encode(header))
second_segment = _urlsafe_b64encode(_json_encode(payload)) second_segment = _urlsafe_b64encode(_json_encode(payload))
assertion_input = first_segment + b'.' + second_segment assertion_input = first_segment + b'.' + second_segment
# Sign the assertion. # Sign the assertion.
rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256') rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=') signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
return assertion_input + b'.' + signature return assertion_input + b'.' + signature
def sign_blob(self, blob): def sign_blob(self, blob):
# Ensure that it is bytes # Ensure that it is bytes
blob = _to_bytes(blob, encoding='utf-8') blob = _to_bytes(blob, encoding='utf-8')
return (self._private_key_id, return (self._private_key_id,
rsa.pkcs1.sign(blob, self._private_key, 'SHA-256')) rsa.pkcs1.sign(blob, self._private_key, 'SHA-256'))
@property @property
def service_account_email(self): def service_account_email(self):
return self._service_account_email return self._service_account_email
@property @property
def serialization_data(self): def serialization_data(self):
return { return {
'type': 'service_account', 'type': 'service_account',
'client_id': self._service_account_id, 'client_id': self._service_account_id,
'client_email': self._service_account_email, 'client_email': self._service_account_email,
'private_key_id': self._private_key_id, 'private_key_id': self._private_key_id,
'private_key': self._private_key_pkcs8_text 'private_key': self._private_key_pkcs8_text
} }
def create_scoped_required(self): def create_scoped_required(self):
return not self._scopes return not self._scopes
def create_scoped(self, scopes): def create_scoped(self, scopes):
return _ServiceAccountCredentials(self._service_account_id, return _ServiceAccountCredentials(self._service_account_id,
self._service_account_email, self._service_account_email,
self._private_key_id, self._private_key_id,
self._private_key_pkcs8_text, self._private_key_pkcs8_text,
@@ -122,10 +122,10 @@ class _ServiceAccountCredentials(AssertionCredentials):
def _get_private_key(private_key_pkcs8_text): def _get_private_key(private_key_pkcs8_text):
"""Get an RSA private key object from a pkcs8 representation.""" """Get an RSA private key object from a pkcs8 representation."""
private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text) private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
return rsa.PrivateKey.load_pkcs1( return rsa.PrivateKey.load_pkcs1(
asn1_private_key.getComponentByName('privateKey').asOctets(), asn1_private_key.getComponentByName('privateKey').asOctets(),
format='DER') format='DER')

View File

@@ -35,7 +35,6 @@ from six.moves import input
from oauth2client import client from oauth2client import client
from oauth2client import util from oauth2client import util
_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0 _CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0
To make this sample run you will need to populate the client_secrets.json file To make this sample run you will need to populate the client_secrets.json file
@@ -47,22 +46,23 @@ with information from the APIs Console <https://code.google.com/apis/console>.
""" """
def _CreateArgumentParser(): def _CreateArgumentParser():
try: try:
import argparse import argparse
except ImportError: except ImportError:
return None return None
parser = argparse.ArgumentParser(add_help=False) parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--auth_host_name', default='localhost', parser.add_argument('--auth_host_name', default='localhost',
help='Hostname when running a local web server.') help='Hostname when running a local web server.')
parser.add_argument('--noauth_local_webserver', action='store_true', parser.add_argument('--noauth_local_webserver', action='store_true',
default=False, help='Do not run a local web server.') default=False, help='Do not run a local web server.')
parser.add_argument('--auth_host_port', default=[8080, 8090], type=int, parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
nargs='*', help='Port web server should listen on.') nargs='*', help='Port web server should listen on.')
parser.add_argument('--logging_level', default='ERROR', parser.add_argument('--logging_level', default='ERROR',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
help='Set the logging level of detail.') help='Set the logging level of detail.')
return parser return parser
# argparser is an ArgumentParser that contains command-line options expected # argparser is an ArgumentParser that contains command-line options expected
# by tools.run(). Pass it in as part of the 'parents' argument to your own # by tools.run(). Pass it in as part of the 'parents' argument to your own
@@ -71,45 +71,45 @@ argparser = _CreateArgumentParser()
class ClientRedirectServer(BaseHTTPServer.HTTPServer): class ClientRedirectServer(BaseHTTPServer.HTTPServer):
"""A server to handle OAuth 2.0 redirects back to localhost. """A server to handle OAuth 2.0 redirects back to localhost.
Waits for a single request and parses the query parameters Waits for a single request and parses the query parameters
into query_params and then stops serving. into query_params and then stops serving.
""" """
query_params = {} query_params = {}
class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""A handler for OAuth 2.0 redirects back to localhost. """A handler for OAuth 2.0 redirects back to localhost.
Waits for a single request and parses the query parameters Waits for a single request and parses the query parameters
into the servers query_params and then stops serving. into the servers query_params and then stops serving.
""" """
def do_GET(self): def do_GET(self):
"""Handle a GET request. """Handle a GET request.
Parses the query parameters and prints a message Parses the query parameters and prints a message
if the flow has completed. Note that we can't detect if the flow has completed. Note that we can't detect
if an error occurred. if an error occurred.
""" """
self.send_response(200) self.send_response(200)
self.send_header("Content-type", "text/html") self.send_header("Content-type", "text/html")
self.end_headers() self.end_headers()
query = self.path.split('?', 1)[-1] query = self.path.split('?', 1)[-1]
query = dict(urllib.parse.parse_qsl(query)) query = dict(urllib.parse.parse_qsl(query))
self.server.query_params = query self.server.query_params = query
self.wfile.write(b"<html><head><title>Authentication Status</title></head>") self.wfile.write(b"<html><head><title>Authentication Status</title></head>")
self.wfile.write(b"<body><p>The authentication flow has completed.</p>") self.wfile.write(b"<body><p>The authentication flow has completed.</p>")
self.wfile.write(b"</body></html>") self.wfile.write(b"</body></html>")
def log_message(self, format, *args): def log_message(self, format, *args):
"""Do not log messages to stdout while running as command line program.""" """Do not log messages to stdout while running as command line program."""
@util.positional(3) @util.positional(3)
def run_flow(flow, storage, flags, http=None): def run_flow(flow, storage, flags, http=None):
"""Core code for a command-line application. """Core code for a command-line application.
The ``run()`` function is called from your application and runs The ``run()`` function is called from your application and runs
through all the steps to obtain credentials. It takes a ``Flow`` through all the steps to obtain credentials. It takes a ``Flow``
@@ -159,91 +159,91 @@ def run_flow(flow, storage, flags, http=None):
Returns: Returns:
Credentials, the obtained credential. Credentials, the obtained credential.
""" """
logging.getLogger().setLevel(getattr(logging, flags.logging_level)) logging.getLogger().setLevel(getattr(logging, flags.logging_level))
if not flags.noauth_local_webserver: if not flags.noauth_local_webserver:
success = False success = False
port_number = 0 port_number = 0
for port in flags.auth_host_port: for port in flags.auth_host_port:
port_number = port port_number = port
try: try:
httpd = ClientRedirectServer((flags.auth_host_name, port), httpd = ClientRedirectServer((flags.auth_host_name, port),
ClientRedirectHandler) ClientRedirectHandler)
except socket.error: except socket.error:
pass pass
else: else:
success = True success = True
break break
flags.noauth_local_webserver = not success flags.noauth_local_webserver = not success
if not success: if not success:
print('Failed to start a local webserver listening on either port 8080') print('Failed to start a local webserver listening on either port 8080')
print('or port 9090. Please check your firewall settings and locally') print('or port 9090. Please check your firewall settings and locally')
print('running programs that may be blocking or using those ports.') print('running programs that may be blocking or using those ports.')
print() print()
print('Falling back to --noauth_local_webserver and continuing with') print('Falling back to --noauth_local_webserver and continuing with')
print('authorization.') print('authorization.')
print() print()
if not flags.noauth_local_webserver: if not flags.noauth_local_webserver:
oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number) oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
else:
oauth_callback = client.OOB_CALLBACK_URN
flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
if not flags.noauth_local_webserver:
import webbrowser
webbrowser.open(authorize_url, new=1, autoraise=True)
print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run this')
print('application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
code = None
if not flags.noauth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else: else:
print('Failed to find "code" in the query parameters of the redirect.') oauth_callback = client.OOB_CALLBACK_URN
sys.exit('Try running with --noauth_local_webserver.') flow.redirect_uri = oauth_callback
else: authorize_url = flow.step1_get_authorize_url()
code = input('Enter verification code: ').strip()
try: if not flags.noauth_local_webserver:
credential = flow.step2_exchange(code, http=http) import webbrowser
except client.FlowExchangeError as e: webbrowser.open(authorize_url, new=1, autoraise=True)
sys.exit('Authentication has failed: %s' % e) print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run this')
print('application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
storage.put(credential) code = None
credential.set_store(storage) if not flags.noauth_local_webserver:
print('Authentication successful.') httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else:
print('Failed to find "code" in the query parameters of the redirect.')
sys.exit('Try running with --noauth_local_webserver.')
else:
code = input('Enter verification code: ').strip()
return credential try:
credential = flow.step2_exchange(code, http=http)
except client.FlowExchangeError as e:
sys.exit('Authentication has failed: %s' % e)
storage.put(credential)
credential.set_store(storage)
print('Authentication successful.')
return credential
def message_if_missing(filename): def message_if_missing(filename):
"""Helpful message to display if the CLIENT_SECRETS file is missing.""" """Helpful message to display if the CLIENT_SECRETS file is missing."""
return _CLIENT_SECRETS_MESSAGE % filename return _CLIENT_SECRETS_MESSAGE % filename
try: try:
from oauth2client.old_run import run from oauth2client.old_run import run
from oauth2client.old_run import FLAGS from oauth2client.old_run import FLAGS
except ImportError: except ImportError:
def run(*args, **kwargs): def run(*args, **kwargs):
raise NotImplementedError( raise NotImplementedError(
'The gflags library must be installed to use tools.run(). ' 'The gflags library must be installed to use tools.run(). '
'Please install gflags or preferrably switch to using ' 'Please install gflags or preferrably switch to using '
'tools.run_flow().') 'tools.run_flow().')

View File

@@ -38,7 +38,6 @@ import types
import six import six
from six.moves import urllib from six.moves import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
POSITIONAL_WARNING = 'WARNING' POSITIONAL_WARNING = 'WARNING'
@@ -49,8 +48,9 @@ POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION,
positional_parameters_enforcement = POSITIONAL_WARNING positional_parameters_enforcement = POSITIONAL_WARNING
def positional(max_positional_args): def positional(max_positional_args):
"""A decorator to declare that only the first N arguments my be positional. """A decorator to declare that only the first N arguments my be positional.
This decorator makes it easy to support Python 3 style keyword-only This decorator makes it easy to support Python 3 style keyword-only
parameters. For example, in Python 3 it is possible to write:: parameters. For example, in Python 3 it is possible to write::
@@ -119,33 +119,34 @@ def positional(max_positional_args):
POSITIONAL_EXCEPTION. POSITIONAL_EXCEPTION.
""" """
def positional_decorator(wrapped):
@functools.wraps(wrapped)
def positional_wrapper(*args, **kwargs):
if len(args) > max_positional_args:
plural_s = ''
if max_positional_args != 1:
plural_s = 's'
message = '%s() takes at most %d positional argument%s (%d given)' % (
wrapped.__name__, max_positional_args, plural_s, len(args))
if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
raise TypeError(message)
elif positional_parameters_enforcement == POSITIONAL_WARNING:
logger.warning(message)
else: # IGNORE
pass
return wrapped(*args, **kwargs)
return positional_wrapper
if isinstance(max_positional_args, six.integer_types): def positional_decorator(wrapped):
return positional_decorator @functools.wraps(wrapped)
else: def positional_wrapper(*args, **kwargs):
args, _, _, defaults = inspect.getargspec(max_positional_args) if len(args) > max_positional_args:
return positional(len(args) - len(defaults))(max_positional_args) plural_s = ''
if max_positional_args != 1:
plural_s = 's'
message = '%s() takes at most %d positional argument%s (%d given)' % (
wrapped.__name__, max_positional_args, plural_s, len(args))
if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
raise TypeError(message)
elif positional_parameters_enforcement == POSITIONAL_WARNING:
logger.warning(message)
else: # IGNORE
pass
return wrapped(*args, **kwargs)
return positional_wrapper
if isinstance(max_positional_args, six.integer_types):
return positional_decorator
else:
args, _, _, defaults = inspect.getargspec(max_positional_args)
return positional(len(args) - len(defaults))(max_positional_args)
def scopes_to_string(scopes): def scopes_to_string(scopes):
"""Converts scope value to a string. """Converts scope value to a string.
If scopes is a string then it is simply passed through. If scopes is an If scopes is a string then it is simply passed through. If scopes is an
iterable then a string is returned that is all the individual scopes iterable then a string is returned that is all the individual scopes
@@ -157,14 +158,14 @@ def scopes_to_string(scopes):
Returns: Returns:
The scopes formatted as a single string. The scopes formatted as a single string.
""" """
if isinstance(scopes, six.string_types): if isinstance(scopes, six.string_types):
return scopes return scopes
else: else:
return ' '.join(scopes) return ' '.join(scopes)
def string_to_scopes(scopes): def string_to_scopes(scopes):
"""Converts stringifed scope value to a list. """Converts stringifed scope value to a list.
If scopes is a list then it is simply passed through. If scopes is an If scopes is a list then it is simply passed through. If scopes is an
string then a list of each individual scope is returned. string then a list of each individual scope is returned.
@@ -175,16 +176,16 @@ def string_to_scopes(scopes):
Returns: Returns:
The scopes in a list. The scopes in a list.
""" """
if not scopes: if not scopes:
return [] return []
if isinstance(scopes, six.string_types): if isinstance(scopes, six.string_types):
return scopes.split(' ') return scopes.split(' ')
else: else:
return scopes return scopes
def dict_to_tuple_key(dictionary): def dict_to_tuple_key(dictionary):
"""Converts a dictionary to a tuple that can be used as an immutable key. """Converts a dictionary to a tuple that can be used as an immutable key.
The resulting key is always sorted so that logically equivalent dictionaries The resulting key is always sorted so that logically equivalent dictionaries
always produce an identical tuple for a key. always produce an identical tuple for a key.
@@ -195,11 +196,11 @@ def dict_to_tuple_key(dictionary):
Returns: Returns:
A tuple representing the dictionary in it's naturally sorted ordering. A tuple representing the dictionary in it's naturally sorted ordering.
""" """
return tuple(sorted(dictionary.items())) return tuple(sorted(dictionary.items()))
def _add_query_parameter(url, name, value): def _add_query_parameter(url, name, value):
"""Adds a query parameter to a url. """Adds a query parameter to a url.
Replaces the current value if it already exists in the URL. Replaces the current value if it already exists in the URL.
@@ -211,11 +212,11 @@ def _add_query_parameter(url, name, value):
Returns: Returns:
Updated query parameter. Does not update the url if value is None. Updated query parameter. Does not update the url if value is None.
""" """
if value is None: if value is None:
return url return url
else: else:
parsed = list(urllib.parse.urlparse(url)) parsed = list(urllib.parse.urlparse(url))
q = dict(urllib.parse.parse_qsl(parsed[4])) q = dict(urllib.parse.parse_qsl(parsed[4]))
q[name] = value q[name] = value
parsed[4] = urllib.parse.urlencode(q) parsed[4] = urllib.parse.urlencode(q)
return urllib.parse.urlunparse(parsed) return urllib.parse.urlunparse(parsed)

View File

@@ -20,7 +20,6 @@ __authors__ = [
'"Joe Gregorio" <jcgregorio@google.com>', '"Joe Gregorio" <jcgregorio@google.com>',
] ]
import base64 import base64
import hmac import hmac
import time import time
@@ -28,13 +27,11 @@ import time
import six import six
from oauth2client import util from oauth2client import util
# Delimiter character # Delimiter character
DELIMITER = b':' DELIMITER = b':'
# 1 hour in seconds # 1 hour in seconds
DEFAULT_TIMEOUT_SECS = 1*60*60 DEFAULT_TIMEOUT_SECS = 1 * 60 * 60
def _force_bytes(s): def _force_bytes(s):
@@ -48,7 +45,7 @@ def _force_bytes(s):
@util.positional(2) @util.positional(2)
def generate_token(key, user_id, action_id="", when=None): def generate_token(key, user_id, action_id="", when=None):
"""Generates a URL-safe token for the given user, action, time tuple. """Generates a URL-safe token for the given user, action, time tuple.
Args: Args:
key: secret key to use. key: secret key to use.
@@ -61,22 +58,22 @@ def generate_token(key, user_id, action_id="", when=None):
Returns: Returns:
A string XSRF protection token. A string XSRF protection token.
""" """
when = _force_bytes(when or int(time.time())) when = _force_bytes(when or int(time.time()))
digester = hmac.new(_force_bytes(key)) digester = hmac.new(_force_bytes(key))
digester.update(_force_bytes(user_id)) digester.update(_force_bytes(user_id))
digester.update(DELIMITER) digester.update(DELIMITER)
digester.update(_force_bytes(action_id)) digester.update(_force_bytes(action_id))
digester.update(DELIMITER) digester.update(DELIMITER)
digester.update(when) digester.update(when)
digest = digester.digest() digest = digester.digest()
token = base64.urlsafe_b64encode(digest + DELIMITER + when) token = base64.urlsafe_b64encode(digest + DELIMITER + when)
return token return token
@util.positional(3) @util.positional(3)
def validate_token(key, token, user_id, action_id="", current_time=None): def validate_token(key, token, user_id, action_id="", current_time=None):
"""Validates that the given token authorizes the user for the action. """Validates that the given token authorizes the user for the action.
Tokens are invalid if the time of issue is too old or if the token Tokens are invalid if the time of issue is too old or if the token
does not match what generateToken outputs (i.e. the token was forged). does not match what generateToken outputs (i.e. the token was forged).
@@ -92,27 +89,27 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
A boolean - True if the user is authorized for the action, False A boolean - True if the user is authorized for the action, False
otherwise. otherwise.
""" """
if not token: if not token:
return False return False
try: try:
decoded = base64.urlsafe_b64decode(token) decoded = base64.urlsafe_b64decode(token)
token_time = int(decoded.split(DELIMITER)[-1]) token_time = int(decoded.split(DELIMITER)[-1])
except (TypeError, ValueError): except (TypeError, ValueError):
return False return False
if current_time is None: if current_time is None:
current_time = time.time() current_time = time.time()
# If the token is too old it's not valid. # If the token is too old it's not valid.
if current_time - token_time > DEFAULT_TIMEOUT_SECS: if current_time - token_time > DEFAULT_TIMEOUT_SECS:
return False return False
# The given token should match the generated one with the same time. # The given token should match the generated one with the same time.
expected_token = generate_token(key, user_id, action_id=action_id, expected_token = generate_token(key, user_id, action_id=action_id,
when=token_time) when=token_time)
if len(token) != len(expected_token): if len(token) != len(expected_token):
return False return False
# Perform constant time comparison to avoid timing attacks # Perform constant time comparison to avoid timing attacks
different = 0 different = 0
for x, y in zip(bytearray(token), bytearray(expected_token)): for x, y in zip(bytearray(token), bytearray(expected_token)):
different |= x ^ y different |= x ^ y
return not different return not different

View File

@@ -16,6 +16,7 @@ __author__ = 'afshar@google.com (Ali Afshar)'
import oauth2client.util import oauth2client.util
def setup_package(): def setup_package():
"""Run on testing package.""" """Run on testing package."""
oauth2client.util.positional_parameters_enforcement = 'EXCEPTION' oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'

View File

@@ -20,47 +20,45 @@ import httplib2
# TODO(craigcitro): Find a cleaner way to share this code with googleapiclient. # TODO(craigcitro): Find a cleaner way to share this code with googleapiclient.
class HttpMock(object): class HttpMock(object):
"""Mock of httplib2.Http""" """Mock of httplib2.Http"""
def __init__(self, filename=None, headers=None): def __init__(self, filename=None, headers=None):
""" """
Args: Args:
filename: string, absolute filename to read response from filename: string, absolute filename to read response from
headers: dict, header to return with response headers: dict, header to return with response
""" """
if headers is None: if headers is None:
headers = {'status': '200 OK'} headers = {'status': '200 OK'}
if filename: if filename:
f = file(filename, 'r') f = file(filename, 'r')
self.data = f.read() self.data = f.read()
f.close() f.close()
else: else:
self.data = None self.data = None
self.response_headers = headers self.response_headers = headers
self.headers = None self.headers = None
self.uri = None self.uri = None
self.method = None self.method = None
self.body = None self.body = None
self.headers = None self.headers = None
def request(self, uri,
def request(self, uri,
method='GET', method='GET',
body=None, body=None,
headers=None, headers=None,
redirections=1, redirections=1,
connection_type=None): connection_type=None):
self.uri = uri self.uri = uri
self.method = method self.method = method
self.body = body self.body = body
self.headers = headers self.headers = headers
return httplib2.Response(self.response_headers), self.data return httplib2.Response(self.response_headers), self.data
class HttpMockSequence(object): class HttpMockSequence(object):
"""Mock of httplib2.Http """Mock of httplib2.Http
Mocks a sequence of calls to request returning different responses for each Mocks a sequence of calls to request returning different responses for each
call. Create an instance initialized with the desired response headers call. Create an instance initialized with the desired response headers
@@ -83,33 +81,33 @@ class HttpMockSequence(object):
'echo_request_uri' means return the request uri in the response body 'echo_request_uri' means return the request uri in the response body
""" """
def __init__(self, iterable): def __init__(self, iterable):
""" """
Args: Args:
iterable: iterable, a sequence of pairs of (headers, body) iterable: iterable, a sequence of pairs of (headers, body)
""" """
self._iterable = iterable self._iterable = iterable
self.follow_redirects = True self.follow_redirects = True
self.requests = [] self.requests = []
def request(self, uri, def request(self, uri,
method='GET', method='GET',
body=None, body=None,
headers=None, headers=None,
redirections=1, redirections=1,
connection_type=None): connection_type=None):
resp, content = self._iterable.pop(0) resp, content = self._iterable.pop(0)
self.requests.append({'uri': uri, 'body': body, 'headers': headers}) self.requests.append({'uri': uri, 'body': body, 'headers': headers})
# Read any underlying stream before sending the request. # Read any underlying stream before sending the request.
body_stream_content = body.read() if getattr(body, 'read', None) else None body_stream_content = body.read() if getattr(body, 'read', None) else None
if content == 'echo_request_headers': if content == 'echo_request_headers':
content = headers content = headers
elif content == 'echo_request_headers_as_json': elif content == 'echo_request_headers_as_json':
content = json.dumps(headers) content = json.dumps(headers)
elif content == 'echo_request_body': elif content == 'echo_request_body':
content = body if body_stream_content is None else body_stream_content content = body if body_stream_content is None else body_stream_content
elif content == 'echo_request_uri': elif content == 'echo_request_uri':
content = uri content = uri
elif not isinstance(content, bytes): elif not isinstance(content, bytes):
raise TypeError('http content should be bytes: %r' % (content,)) raise TypeError('http content should be bytes: %r' % (content, ))
return httplib2.Response(resp), content return httplib2.Response(resp), content

View File

@@ -25,93 +25,93 @@ from oauth2client._helpers import _urlsafe_b64encode
class Test__parse_pem_key(unittest.TestCase): class Test__parse_pem_key(unittest.TestCase):
def test_valid_input(self): def test_valid_input(self):
test_string = b'1234-----BEGIN FOO BAR BAZ' test_string = b'1234-----BEGIN FOO BAR BAZ'
result = _parse_pem_key(test_string) result = _parse_pem_key(test_string)
self.assertEqual(result, test_string[4:]) self.assertEqual(result, test_string[4:])
def test_bad_input(self): def test_bad_input(self):
test_string = b'DOES NOT HAVE DASHES' test_string = b'DOES NOT HAVE DASHES'
result = _parse_pem_key(test_string) result = _parse_pem_key(test_string)
self.assertEqual(result, None) self.assertEqual(result, None)
class Test__json_encode(unittest.TestCase): class Test__json_encode(unittest.TestCase):
def test_dictionary_input(self): def test_dictionary_input(self):
# Use only a single key since dictionary hash order # Use only a single key since dictionary hash order
# is non-deterministic. # is non-deterministic.
data = {u'foo': 10} data = {u'foo': 10}
result = _json_encode(data) result = _json_encode(data)
self.assertEqual(result, """{"foo":10}""") self.assertEqual(result, """{"foo":10}""")
def test_list_input(self): def test_list_input(self):
data = [42, 1337] data = [42, 1337]
result = _json_encode(data) result = _json_encode(data)
self.assertEqual(result, """[42,1337]""") self.assertEqual(result, """[42,1337]""")
class Test__to_bytes(unittest.TestCase): class Test__to_bytes(unittest.TestCase):
def test_with_bytes(self): def test_with_bytes(self):
value = b'bytes-val' value = b'bytes-val'
self.assertEqual(_to_bytes(value), value) self.assertEqual(_to_bytes(value), value)
def test_with_unicode(self): def test_with_unicode(self):
value = u'string-val' value = u'string-val'
encoded_value = b'string-val' encoded_value = b'string-val'
self.assertEqual(_to_bytes(value), encoded_value) self.assertEqual(_to_bytes(value), encoded_value)
def test_with_nonstring_type(self): def test_with_nonstring_type(self):
value = object() value = object()
self.assertRaises(ValueError, _to_bytes, value) self.assertRaises(ValueError, _to_bytes, value)
class Test__from_bytes(unittest.TestCase): class Test__from_bytes(unittest.TestCase):
def test_with_unicode(self): def test_with_unicode(self):
value = u'bytes-val' value = u'bytes-val'
self.assertEqual(_from_bytes(value), value) self.assertEqual(_from_bytes(value), value)
def test_with_bytes(self): def test_with_bytes(self):
value = b'string-val' value = b'string-val'
decoded_value = u'string-val' decoded_value = u'string-val'
self.assertEqual(_from_bytes(value), decoded_value) self.assertEqual(_from_bytes(value), decoded_value)
def test_with_nonstring_type(self): def test_with_nonstring_type(self):
value = object() value = object()
self.assertRaises(ValueError, _from_bytes, value) self.assertRaises(ValueError, _from_bytes, value)
class Test__urlsafe_b64encode(unittest.TestCase): class Test__urlsafe_b64encode(unittest.TestCase):
DEADBEEF_ENCODED = b'ZGVhZGJlZWY' DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
def test_valid_input_bytes(self): def test_valid_input_bytes(self):
test_string = b'deadbeef' test_string = b'deadbeef'
result = _urlsafe_b64encode(test_string) result = _urlsafe_b64encode(test_string)
self.assertEqual(result, self.DEADBEEF_ENCODED) self.assertEqual(result, self.DEADBEEF_ENCODED)
def test_valid_input_unicode(self): def test_valid_input_unicode(self):
test_string = u'deadbeef' test_string = u'deadbeef'
result = _urlsafe_b64encode(test_string) result = _urlsafe_b64encode(test_string)
self.assertEqual(result, self.DEADBEEF_ENCODED) self.assertEqual(result, self.DEADBEEF_ENCODED)
class Test__urlsafe_b64decode(unittest.TestCase): class Test__urlsafe_b64decode(unittest.TestCase):
def test_valid_input_bytes(self): def test_valid_input_bytes(self):
test_string = b'ZGVhZGJlZWY' test_string = b'ZGVhZGJlZWY'
result = _urlsafe_b64decode(test_string) result = _urlsafe_b64decode(test_string)
self.assertEqual(result, b'deadbeef') self.assertEqual(result, b'deadbeef')
def test_valid_input_unicode(self): def test_valid_input_unicode(self):
test_string = b'ZGVhZGJlZWY' test_string = b'ZGVhZGJlZWY'
result = _urlsafe_b64decode(test_string) result = _urlsafe_b64decode(test_string)
self.assertEqual(result, b'deadbeef') self.assertEqual(result, b'deadbeef')
def test_bad_input(self): def test_bad_input(self):
import binascii import binascii
bad_string = b'+' bad_string = b'+'
self.assertRaises((TypeError, binascii.Error), self.assertRaises((TypeError, binascii.Error),
_urlsafe_b64decode, bad_string) _urlsafe_b64decode, bad_string)

View File

@@ -22,42 +22,42 @@ from oauth2client.crypt import PyCryptoVerifier
class TestPyCryptoVerifier(unittest.TestCase): class TestPyCryptoVerifier(unittest.TestCase):
PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__), PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
'data', 'publickey.pem') 'data', 'publickey.pem')
PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__), PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
'data', 'privatekey.pem') 'data', 'privatekey.pem')
def _load_public_key_bytes(self): def _load_public_key_bytes(self):
with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh: with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
return fh.read() return fh.read()
def _load_private_key_bytes(self): def _load_private_key_bytes(self):
with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh: with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
return fh.read() return fh.read()
def test_verify_success(self): def test_verify_success(self):
to_sign = b'foo' to_sign = b'foo'
signer = PyCryptoSigner.from_string(self._load_private_key_bytes()) signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
actual_signature = signer.sign(to_sign) actual_signature = signer.sign(to_sign)
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(), verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True) is_x509_cert=True)
self.assertTrue(verifier.verify(to_sign, actual_signature)) self.assertTrue(verifier.verify(to_sign, actual_signature))
def test_verify_failure(self): def test_verify_failure(self):
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(), verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True) is_x509_cert=True)
bad_signature = b'' bad_signature = b''
self.assertFalse(verifier.verify(b'foo', bad_signature)) self.assertFalse(verifier.verify(b'foo', bad_signature))
def test_verify_bad_key(self): def test_verify_bad_key(self):
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(), verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True) is_x509_cert=True)
bad_signature = b'' bad_signature = b''
self.assertFalse(verifier.verify(b'foo', bad_signature)) self.assertFalse(verifier.verify(b'foo', bad_signature))
def test_from_string_unicode_key(self): def test_from_string_unicode_key(self):
public_key = self._load_public_key_bytes() public_key = self._load_public_key_bytes()
public_key = public_key.decode('utf-8') public_key = public_key.decode('utf-8')
verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True) verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
self.assertTrue(isinstance(verifier, PyCryptoVerifier)) self.assertTrue(isinstance(verifier, PyCryptoVerifier))

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,6 @@
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import os import os
import unittest import unittest
from io import StringIO from io import StringIO
@@ -25,22 +24,22 @@ import httplib2
from oauth2client import clientsecrets from oauth2client import clientsecrets
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
VALID_FILE = os.path.join(DATA_DIR, 'client_secrets.json') VALID_FILE = os.path.join(DATA_DIR, 'client_secrets.json')
INVALID_FILE = os.path.join(DATA_DIR, 'unfilled_client_secrets.json') INVALID_FILE = os.path.join(DATA_DIR, 'unfilled_client_secrets.json')
NONEXISTENT_FILE = os.path.join(__file__, '..', 'afilethatisntthere.json') NONEXISTENT_FILE = os.path.join(__file__, '..', 'afilethatisntthere.json')
class OAuth2CredentialsTests(unittest.TestCase): class OAuth2CredentialsTests(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
def tearDown(self): def tearDown(self):
pass pass
def test_validate_error(self): def test_validate_error(self):
ERRORS = [ ERRORS = [
('{}', 'Invalid'), ('{}', 'Invalid'),
('{"foo": {}}', 'Unknown'), ('{"foo": {}}', 'Unknown'),
('{"web": {}}', 'Missing'), ('{"web": {}}', 'Missing'),
@@ -56,95 +55,95 @@ class OAuth2CredentialsTests(unittest.TestCase):
} }
""", 'Property'), """, 'Property'),
] ]
for src, match in ERRORS: for src, match in ERRORS:
# Ensure that it is unicode # Ensure that it is unicode
try: try:
src = src.decode('utf-8') src = src.decode('utf-8')
except AttributeError: except AttributeError:
pass pass
# Test load(s) # Test load(s)
try: try:
clientsecrets.loads(src) clientsecrets.loads(src)
self.fail(src + ' should not be a valid client_secrets file.') self.fail(src + ' should not be a valid client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e: except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith(match)) self.assertTrue(str(e).startswith(match))
# Test loads(fp) # Test loads(fp)
try: try:
fp = StringIO(src) fp = StringIO(src)
clientsecrets.load(fp) clientsecrets.load(fp)
self.fail(src + ' should not be a valid client_secrets file.') self.fail(src + ' should not be a valid client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e: except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith(match)) self.assertTrue(str(e).startswith(match))
def test_load_by_filename(self): def test_load_by_filename(self):
try: try:
clientsecrets._loadfile(NONEXISTENT_FILE) clientsecrets._loadfile(NONEXISTENT_FILE)
self.fail('should fail to load a missing client_secrets file.') self.fail('should fail to load a missing client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e: except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith('File')) self.assertTrue(str(e).startswith('File'))
class CachedClientsecretsTests(unittest.TestCase): class CachedClientsecretsTests(unittest.TestCase):
class CacheMock(object): class CacheMock(object):
def __init__(self): def __init__(self):
self.cache = {} self.cache = {}
self.last_get_ns = None self.last_get_ns = None
self.last_set_ns = None self.last_set_ns = None
def get(self, key, namespace=''): def get(self, key, namespace=''):
# ignoring namespace for easier testing # ignoring namespace for easier testing
self.last_get_ns = namespace self.last_get_ns = namespace
return self.cache.get(key, None) return self.cache.get(key, None)
def set(self, key, value, namespace=''): def set(self, key, value, namespace=''):
# ignoring namespace for easier testing # ignoring namespace for easier testing
self.last_set_ns = namespace self.last_set_ns = namespace
self.cache[key] = value self.cache[key] = value
def setUp(self): def setUp(self):
self.cache_mock = self.CacheMock() self.cache_mock = self.CacheMock()
def test_cache_miss(self): def test_cache_miss(self):
client_type, client_info = clientsecrets.loadfile( client_type, client_info = clientsecrets.loadfile(
VALID_FILE, cache=self.cache_mock) VALID_FILE, cache=self.cache_mock)
self.assertEqual('web', client_type) self.assertEqual('web', client_type)
self.assertEqual('foo_client_secret', client_info['client_secret']) self.assertEqual('foo_client_secret', client_info['client_secret'])
cached = self.cache_mock.cache[VALID_FILE] cached = self.cache_mock.cache[VALID_FILE]
self.assertEqual({client_type: client_info}, cached) self.assertEqual({client_type: client_info}, cached)
# make sure we're using non-empty namespace # make sure we're using non-empty namespace
ns = self.cache_mock.last_set_ns ns = self.cache_mock.last_set_ns
self.assertTrue(bool(ns)) self.assertTrue(bool(ns))
# make sure they're equal # make sure they're equal
self.assertEqual(ns, self.cache_mock.last_get_ns) self.assertEqual(ns, self.cache_mock.last_get_ns)
def test_cache_hit(self): def test_cache_hit(self):
self.cache_mock.cache[NONEXISTENT_FILE] = { 'web': 'secret info' } self.cache_mock.cache[NONEXISTENT_FILE] = {'web': 'secret info'}
client_type, client_info = clientsecrets.loadfile( client_type, client_info = clientsecrets.loadfile(
NONEXISTENT_FILE, cache=self.cache_mock) NONEXISTENT_FILE, cache=self.cache_mock)
self.assertEqual('web', client_type) self.assertEqual('web', client_type)
self.assertEqual('secret info', client_info) self.assertEqual('secret info', client_info)
# make sure we didn't do any set() RPCs # make sure we didn't do any set() RPCs
self.assertEqual(None, self.cache_mock.last_set_ns) self.assertEqual(None, self.cache_mock.last_set_ns)
def test_validation(self): def test_validation(self):
try: try:
clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock) clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
self.fail('Expected InvalidClientSecretsError to be raised ' self.fail('Expected InvalidClientSecretsError to be raised '
'while loading %s' % INVALID_FILE) 'while loading %s' % INVALID_FILE)
except clientsecrets.InvalidClientSecretsError: except clientsecrets.InvalidClientSecretsError:
pass pass
def test_without_cache(self): def test_without_cache(self):
# this also ensures loadfile() is backward compatible # this also ensures loadfile() is backward compatible
client_type, client_info = clientsecrets.loadfile(VALID_FILE) client_type, client_info = clientsecrets.loadfile(VALID_FILE)
self.assertEqual('web', client_type) self.assertEqual('web', client_type)
self.assertEqual('foo_client_secret', client_info['client_secret']) self.assertEqual('foo_client_secret', client_info['client_secret'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -18,10 +18,10 @@ import sys
import unittest import unittest
try: try:
reload reload
except NameError: except NameError:
# For Python3 (though importlib should be used, silly 3.3). # For Python3 (though importlib should be used, silly 3.3).
from imp import reload from imp import reload
from oauth2client import _helpers from oauth2client import _helpers
from oauth2client.client import HAS_OPENSSL from oauth2client.client import HAS_OPENSSL
@@ -30,44 +30,44 @@ from oauth2client import crypt
def datafile(filename): def datafile(filename):
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb') f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read() data = f.read()
f.close() f.close()
return data return data
class Test_pkcs12_key_as_pem(unittest.TestCase): class Test_pkcs12_key_as_pem(unittest.TestCase):
def _make_signed_jwt_creds(self, private_key_file='privatekey.p12', def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
private_key=None): private_key=None):
private_key = private_key or datafile(private_key_file) private_key = private_key or datafile(private_key_file)
return SignedJwtAssertionCredentials( return SignedJwtAssertionCredentials(
'some_account@example.com', 'some_account@example.com',
private_key, private_key,
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
def _succeeds_helper(self, password=None): def _succeeds_helper(self, password=None):
self.assertEqual(True, HAS_OPENSSL) self.assertEqual(True, HAS_OPENSSL)
credentials = self._make_signed_jwt_creds() credentials = self._make_signed_jwt_creds()
if password is None: if password is None:
password = credentials.private_key_password password = credentials.private_key_password
pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password) pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem') pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem) pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
alternate_pem = datafile('pem_from_pkcs12_alternate.pem') alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem]) self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
def test_succeeds(self): def test_succeeds(self):
self._succeeds_helper() self._succeeds_helper()
def test_succeeds_with_unicode_password(self): def test_succeeds_with_unicode_password(self):
password = u'notasecret' password = u'notasecret'
self._succeeds_helper(password) self._succeeds_helper(password)
def test_with_nonsense_key(self): def test_with_nonsense_key(self):
from OpenSSL import crypto from OpenSSL import crypto
credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY') credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem, self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
credentials.private_key, credentials.private_key_password) credentials.private_key, credentials.private_key_password)

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for oauth2client.devshell.""" """Tests for oauth2client.devshell."""
import os import os
@@ -30,110 +29,110 @@ from oauth2client.devshell import NoDevshellServer
class _AuthReferenceServer(threading.Thread): class _AuthReferenceServer(threading.Thread):
def __init__(self, response=None): def __init__(self, response=None):
super(_AuthReferenceServer, self).__init__(None) super(_AuthReferenceServer, self).__init__(None)
self.response = (response or self.response = (response or
'["joe@example.com", "fooproj", "sometoken"]') '["joe@example.com", "fooproj", "sometoken"]')
def __enter__(self): def __enter__(self):
self.start_server() self.start_server()
def start_server(self): def start_server(self):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.bind(('localhost', 0)) self._socket.bind(('localhost', 0))
port = self._socket.getsockname()[1] port = self._socket.getsockname()[1]
os.environ[DEVSHELL_ENV] = str(port) os.environ[DEVSHELL_ENV] = str(port)
self._socket.listen(0) self._socket.listen(0)
self.daemon = True self.daemon = True
self.start() self.start()
return self return self
def __exit__(self, e_type, value, traceback): def __exit__(self, e_type, value, traceback):
self.stop_server() self.stop_server()
def stop_server(self): def stop_server(self):
del os.environ[DEVSHELL_ENV] del os.environ[DEVSHELL_ENV]
self._socket.close() self._socket.close()
def run(self): def run(self):
s = None s = None
try: try:
# Do not set the timeout on the socket, leave it in the blocking mode as # Do not set the timeout on the socket, leave it in the blocking mode as
# setting the timeout seems to cause spurious EAGAIN errors on OSX. # setting the timeout seems to cause spurious EAGAIN errors on OSX.
self._socket.settimeout(None) self._socket.settimeout(None)
s, unused_addr = self._socket.accept() s, unused_addr = self._socket.accept()
resp_buffer = '' resp_buffer = ''
resp_1 = s.recv(6).decode() resp_1 = s.recv(6).decode()
if '\n' not in resp_1: if '\n' not in resp_1:
raise Exception('invalid request data') raise Exception('invalid request data')
nstr, extra = resp_1.split('\n', 1) nstr, extra = resp_1.split('\n', 1)
resp_buffer = extra resp_buffer = extra
n = int(nstr) n = int(nstr)
to_read = n-len(extra) to_read = n - len(extra)
if to_read > 0: if to_read > 0:
resp_buffer += s.recv(to_read, socket.MSG_WAITALL) resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON: if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
raise Exception('bad request') raise Exception('bad request')
l = len(self.response) l = len(self.response)
s.sendall(('%d\n%s' % (l, self.response)).encode()) s.sendall(('%d\n%s' % (l, self.response)).encode())
finally: finally:
if s: if s:
s.close() s.close()
class DevshellCredentialsTests(unittest.TestCase): class DevshellCredentialsTests(unittest.TestCase):
def test_signals_no_server(self): def test_signals_no_server(self):
self.assertRaises(NoDevshellServer, DevshellCredentials) self.assertRaises(NoDevshellServer, DevshellCredentials)
def test_request_response(self): def test_request_response(self):
with _AuthReferenceServer(): with _AuthReferenceServer():
response = _SendRecv() response = _SendRecv()
self.assertEqual(response.user_email, 'joe@example.com') self.assertEqual(response.user_email, 'joe@example.com')
self.assertEqual(response.project_id, 'fooproj') self.assertEqual(response.project_id, 'fooproj')
self.assertEqual(response.access_token, 'sometoken') self.assertEqual(response.access_token, 'sometoken')
def test_no_refresh_token(self): def test_no_refresh_token(self):
with _AuthReferenceServer(): with _AuthReferenceServer():
creds = DevshellCredentials() creds = DevshellCredentials()
self.assertEquals(None, creds.refresh_token) self.assertEquals(None, creds.refresh_token)
def test_reads_credentials(self): def test_reads_credentials(self):
with _AuthReferenceServer(): with _AuthReferenceServer():
creds = DevshellCredentials() creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email) self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual('fooproj', creds.project_id) self.assertEqual('fooproj', creds.project_id)
self.assertEqual('sometoken', creds.access_token) self.assertEqual('sometoken', creds.access_token)
def test_handles_skipped_fields(self): def test_handles_skipped_fields(self):
with _AuthReferenceServer('["joe@example.com"]'): with _AuthReferenceServer('["joe@example.com"]'):
creds = DevshellCredentials() creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email) self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual(None, creds.project_id) self.assertEqual(None, creds.project_id)
self.assertEqual(None, creds.access_token) self.assertEqual(None, creds.access_token)
def test_handles_tiny_response(self): def test_handles_tiny_response(self):
with _AuthReferenceServer('[]'): with _AuthReferenceServer('[]'):
creds = DevshellCredentials() creds = DevshellCredentials()
self.assertEqual(None, creds.user_email) self.assertEqual(None, creds.user_email)
self.assertEqual(None, creds.project_id) self.assertEqual(None, creds.project_id)
self.assertEqual(None, creds.access_token) self.assertEqual(None, creds.access_token)
def test_handles_ignores_extra_fields(self): def test_handles_ignores_extra_fields(self):
with _AuthReferenceServer( with _AuthReferenceServer(
'["joe@example.com", "fooproj", "sometoken", "extra"]'): '["joe@example.com", "fooproj", "sometoken", "extra"]'):
creds = DevshellCredentials() creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email) self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual('fooproj', creds.project_id) self.assertEqual('fooproj', creds.project_id)
self.assertEqual('sometoken', creds.access_token) self.assertEqual('sometoken', creds.access_token)
def test_refuses_to_save_to_well_known_file(self): def test_refuses_to_save_to_well_known_file(self):
ORIGINAL_ISDIR = os.path.isdir ORIGINAL_ISDIR = os.path.isdir
try: try:
os.path.isdir = lambda path: True os.path.isdir = lambda path: True
with _AuthReferenceServer(): with _AuthReferenceServer():
creds = DevshellCredentials() creds = DevshellCredentials()
self.assertRaises(NotImplementedError, save_to_well_known_file, creds) self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
finally: finally:
os.path.isdir = ORIGINAL_ISDIR os.path.isdir = ORIGINAL_ISDIR

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Discovery document tests """Discovery document tests
Unit tests for objects created from discovery documents. Unit tests for objects created from discovery documents.
@@ -31,10 +30,10 @@ import unittest
# Ensure that if app engine is available, we use the correct django from it # Ensure that if app engine is available, we use the correct django from it
try: try:
from google.appengine.dist import use_library from google.appengine.dist import use_library
use_library('django', '1.5') use_library('django', '1.5')
except ImportError: except ImportError:
pass pass
from oauth2client.client import Credentials from oauth2client.client import Credentials
from oauth2client.client import Flow from oauth2client.client import Flow
@@ -51,39 +50,39 @@ from oauth2client.django_orm import FlowField
class TestCredentialsField(unittest.TestCase): class TestCredentialsField(unittest.TestCase):
def setUp(self): def setUp(self):
self.field = CredentialsField() self.field = CredentialsField()
self.credentials = Credentials() self.credentials = Credentials()
self.pickle = base64.b64encode(pickle.dumps(self.credentials)) self.pickle = base64.b64encode(pickle.dumps(self.credentials))
def test_field_is_text(self): def test_field_is_text(self):
self.assertEquals(self.field.get_internal_type(), 'TextField') self.assertEquals(self.field.get_internal_type(), 'TextField')
def test_field_unpickled(self): def test_field_unpickled(self):
self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials)) self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
def test_field_pickled(self): def test_field_pickled(self):
prep_value = self.field.get_db_prep_value(self.credentials, prep_value = self.field.get_db_prep_value(self.credentials,
connection=None) connection=None)
self.assertEqual(prep_value, self.pickle) self.assertEqual(prep_value, self.pickle)
class TestFlowField(unittest.TestCase): class TestFlowField(unittest.TestCase):
def setUp(self): def setUp(self):
self.field = FlowField() self.field = FlowField()
self.flow = Flow() self.flow = Flow()
self.pickle = base64.b64encode(pickle.dumps(self.flow)) self.pickle = base64.b64encode(pickle.dumps(self.flow))
def test_field_is_text(self): def test_field_is_text(self):
self.assertEquals(self.field.get_internal_type(), 'TextField') self.assertEquals(self.field.get_internal_type(), 'TextField')
def test_field_unpickled(self): def test_field_unpickled(self):
self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow)) self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
def test_field_pickled(self): def test_field_pickled(self):
prep_value = self.field.get_db_prep_value(self.flow, connection=None) prep_value = self.field.get_db_prep_value(self.flow, connection=None)
self.assertEqual(prep_value, self.pickle) self.assertEqual(prep_value, self.pickle)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Oauth2client.file tests """Oauth2client.file tests
Unit tests for oauth2client.file Unit tests for oauth2client.file
@@ -42,363 +41,362 @@ from oauth2client.client import AccessTokenCredentials
from oauth2client.client import OAuth2Credentials from oauth2client.client import OAuth2Credentials
from six.moves import http_client from six.moves import http_client
try: try:
# Python2 # Python2
from future_builtins import oct from future_builtins import oct
except: except:
pass pass
FILENAME = tempfile.mktemp('oauth2client_test.data') FILENAME = tempfile.mktemp('oauth2client_test.data')
class OAuth2ClientFileTests(unittest.TestCase): class OAuth2ClientFileTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
try: try:
os.unlink(FILENAME) os.unlink(FILENAME)
except OSError: except OSError:
pass pass
def setUp(self): def setUp(self):
try: try:
os.unlink(FILENAME) os.unlink(FILENAME)
except OSError: except OSError:
pass pass
def create_test_credentials(self, client_id='some_client_id', def create_test_credentials(self, client_id='some_client_id',
expiration=None): expiration=None):
access_token = 'foo' access_token = 'foo'
client_secret = 'cOuDdkfjxxnv+' client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0' refresh_token = '1/0/a.df219fjls0'
token_expiry = expiration or datetime.datetime.utcnow() token_expiry = expiration or datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token' token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0' user_agent = 'refresh_checker/1.0'
credentials = OAuth2Credentials( credentials = OAuth2Credentials(
access_token, client_id, client_secret, access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri, refresh_token, token_expiry, token_uri,
user_agent) user_agent)
return credentials return credentials
def test_non_existent_file_storage(self): def test_non_existent_file_storage(self):
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
credentials = s.get() credentials = s.get()
self.assertEquals(None, credentials) self.assertEquals(None, credentials)
def test_no_sym_link_credentials(self): def test_no_sym_link_credentials(self):
if hasattr(os, 'symlink'): if hasattr(os, 'symlink'):
SYMFILENAME = FILENAME + '.sym' SYMFILENAME = FILENAME + '.sym'
os.symlink(FILENAME, SYMFILENAME) os.symlink(FILENAME, SYMFILENAME)
s = file.Storage(SYMFILENAME) s = file.Storage(SYMFILENAME)
try: try:
s.get() s.get()
self.fail('Should have raised an exception.') self.fail('Should have raised an exception.')
except file.CredentialsFileSymbolicLinkError: except file.CredentialsFileSymbolicLinkError:
pass pass
finally: finally:
os.unlink(SYMFILENAME) os.unlink(SYMFILENAME)
def test_pickle_and_json_interop(self): def test_pickle_and_json_interop(self):
# Write a file with a pickled OAuth2Credentials. # Write a file with a pickled OAuth2Credentials.
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
f = open(FILENAME, 'wb') f = open(FILENAME, 'wb')
pickle.dump(credentials, f) pickle.dump(credentials, f)
f.close() f.close()
# Storage should be not be able to read that object, as the capability to # Storage should be not be able to read that object, as the capability to
# read and write credentials as pickled objects has been removed. # read and write credentials as pickled objects has been removed.
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
read_credentials = s.get() read_credentials = s.get()
self.assertEquals(None, read_credentials) self.assertEquals(None, read_credentials)
# Now write it back out and confirm it has been rewritten as JSON # Now write it back out and confirm it has been rewritten as JSON
s.put(credentials) s.put(credentials)
with open(FILENAME) as f: with open(FILENAME) as f:
data = json.load(f) data = json.load(f)
self.assertEquals(data['access_token'], 'foo') self.assertEquals(data['access_token'], 'foo')
self.assertEquals(data['_class'], 'OAuth2Credentials') self.assertEquals(data['_class'], 'OAuth2Credentials')
self.assertEquals(data['_module'], OAuth2Credentials.__module__) self.assertEquals(data['_module'], OAuth2Credentials.__module__)
def test_token_refresh_store_expired(self): def test_token_refresh_store_expired(self):
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15) expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration) credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
s.put(credentials) s.put(credentials)
credentials = s.get() credentials = s.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
new_cred.access_token = 'bar' new_cred.access_token = 'bar'
s.put(new_cred) s.put(new_cred)
access_token = '1/3w' access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600} token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response).encode('utf-8')), ({'status': '200'}, json.dumps(token_response).encode('utf-8')),
]) ])
credentials._refresh(http.request) credentials._refresh(http.request)
self.assertEquals(credentials.access_token, access_token) self.assertEquals(credentials.access_token, access_token)
def test_token_refresh_store_expires_soon(self): def test_token_refresh_store_expires_soon(self):
# Tests the case where an access token that is valid when it is read from # Tests the case where an access token that is valid when it is read from
# the store expires before the original request succeeds. # the store expires before the original request succeeds.
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration) credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
s.put(credentials) s.put(credentials)
credentials = s.get() credentials = s.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
new_cred.access_token = 'bar' new_cred.access_token = 'bar'
s.put(new_cred) s.put(new_cred)
access_token = '1/3w' access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600} token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'), ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'), ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)}, ({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')), json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, ({'status': str(http_client.OK)},
b'Valid response to original request') b'Valid response to original request')
]) ])
credentials.authorize(http) credentials.authorize(http)
http.request('https://example.com') http.request('https://example.com')
self.assertEqual(credentials.access_token, access_token) self.assertEqual(credentials.access_token, access_token)
def test_token_refresh_good_store(self): def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration) credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
s.put(credentials) s.put(credentials)
credentials = s.get() credentials = s.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
new_cred.access_token = 'bar' new_cred.access_token = 'bar'
s.put(new_cred) s.put(new_cred)
credentials._refresh(lambda x: x) credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar') self.assertEquals(credentials.access_token, 'bar')
def test_token_refresh_stream_body(self): def test_token_refresh_stream_body(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration) credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
s.put(credentials) s.put(credentials)
credentials = s.get() credentials = s.get()
new_cred = copy.copy(credentials) new_cred = copy.copy(credentials)
new_cred.access_token = 'bar' new_cred.access_token = 'bar'
s.put(new_cred) s.put(new_cred)
valid_access_token = '1/3w' valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, 'expires_in': 3600} token_response = {'access_token': valid_access_token, 'expires_in': 3600}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'), ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'), ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)}, ({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')), json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, 'echo_request_body') ({'status': str(http_client.OK)}, 'echo_request_body')
]) ])
body = six.StringIO('streaming body') body = six.StringIO('streaming body')
credentials.authorize(http) credentials.authorize(http)
_, content = http.request('https://example.com', body=body) _, content = http.request('https://example.com', body=body)
self.assertEqual(content, 'streaming body') self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token) self.assertEqual(credentials.access_token, valid_access_token)
def test_credentials_delete(self): def test_credentials_delete(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
s.put(credentials) s.put(credentials)
credentials = s.get() credentials = s.get()
self.assertNotEquals(None, credentials) self.assertNotEquals(None, credentials)
s.delete() s.delete()
credentials = s.get() credentials = s.get()
self.assertEquals(None, credentials) self.assertEquals(None, credentials)
def test_access_token_credentials(self): def test_access_token_credentials(self):
access_token = 'foo' access_token = 'foo'
user_agent = 'refresh_checker/1.0' user_agent = 'refresh_checker/1.0'
credentials = AccessTokenCredentials(access_token, user_agent) credentials = AccessTokenCredentials(access_token, user_agent)
s = file.Storage(FILENAME) s = file.Storage(FILENAME)
credentials = s.put(credentials) credentials = s.put(credentials)
credentials = s.get() credentials = s.get()
self.assertNotEquals(None, credentials) self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token) self.assertEquals('foo', credentials.access_token)
mode = os.stat(FILENAME).st_mode mode = os.stat(FILENAME).st_mode
if os.name == 'posix': if os.name == 'posix':
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode))) self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
def test_read_only_file_fail_lock(self): def test_read_only_file_fail_lock(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
open(FILENAME, 'a+b').close() open(FILENAME, 'a+b').close()
os.chmod(FILENAME, 0o400) os.chmod(FILENAME, 0o400)
store = multistore_file.get_credential_storage( store = multistore_file.get_credential_storage(
FILENAME, FILENAME,
credentials.client_id, credentials.client_id,
credentials.user_agent, credentials.user_agent,
['some-scope', 'some-other-scope']) ['some-scope', 'some-other-scope'])
store.put(credentials) store.put(credentials)
if os.name == 'posix': if os.name == 'posix':
self.assertTrue(store._multistore._read_only) self.assertTrue(store._multistore._read_only)
os.chmod(FILENAME, 0o600) os.chmod(FILENAME, 0o600)
def test_multistore_no_symbolic_link_files(self): def test_multistore_no_symbolic_link_files(self):
if hasattr(os, 'symlink'): if hasattr(os, 'symlink'):
SYMFILENAME = FILENAME + 'sym' SYMFILENAME = FILENAME + 'sym'
os.symlink(FILENAME, SYMFILENAME) os.symlink(FILENAME, SYMFILENAME)
store = multistore_file.get_credential_storage( store = multistore_file.get_credential_storage(
SYMFILENAME, SYMFILENAME,
'some_client_id', 'some_client_id',
'user-agent/1.0', 'user-agent/1.0',
['some-scope', 'some-other-scope']) ['some-scope', 'some-other-scope'])
try: try:
store.get() store.get()
self.fail('Should have raised an exception.') self.fail('Should have raised an exception.')
except locked_file.CredentialsFileSymbolicLinkError: except locked_file.CredentialsFileSymbolicLinkError:
pass pass
finally: finally:
os.unlink(SYMFILENAME) os.unlink(SYMFILENAME)
def test_multistore_non_existent_file(self): def test_multistore_non_existent_file(self):
store = multistore_file.get_credential_storage( store = multistore_file.get_credential_storage(
FILENAME, FILENAME,
'some_client_id', 'some_client_id',
'user-agent/1.0', 'user-agent/1.0',
['some-scope', 'some-other-scope']) ['some-scope', 'some-other-scope'])
credentials = store.get() credentials = store.get()
self.assertEquals(None, credentials) self.assertEquals(None, credentials)
def test_multistore_file(self): def test_multistore_file(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
store = multistore_file.get_credential_storage( store = multistore_file.get_credential_storage(
FILENAME, FILENAME,
credentials.client_id, credentials.client_id,
credentials.user_agent, credentials.user_agent,
['some-scope', 'some-other-scope']) ['some-scope', 'some-other-scope'])
store.put(credentials) store.put(credentials)
credentials = store.get() credentials = store.get()
self.assertNotEquals(None, credentials) self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token) self.assertEquals('foo', credentials.access_token)
store.delete() store.delete()
credentials = store.get() credentials = store.get()
self.assertEquals(None, credentials) self.assertEquals(None, credentials)
if os.name == 'posix': if os.name == 'posix':
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode))) self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
def test_multistore_file_custom_key(self): def test_multistore_file_custom_key(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
custom_key = {'myapp': 'testing', 'clientid': 'some client'} custom_key = {'myapp': 'testing', 'clientid': 'some client'}
store = multistore_file.get_credential_storage_custom_key( store = multistore_file.get_credential_storage_custom_key(
FILENAME, custom_key) FILENAME, custom_key)
store.put(credentials) store.put(credentials)
stored_credentials = store.get() stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials) self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token) self.assertEqual(credentials.access_token, stored_credentials.access_token)
store.delete() store.delete()
stored_credentials = store.get() stored_credentials = store.get()
self.assertEquals(None, stored_credentials) self.assertEquals(None, stored_credentials)
def test_multistore_file_custom_string_key(self): def test_multistore_file_custom_string_key(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
# store with string key # store with string key
store = multistore_file.get_credential_storage_custom_string_key( store = multistore_file.get_credential_storage_custom_string_key(
FILENAME, 'mykey') FILENAME, 'mykey')
store.put(credentials) store.put(credentials)
stored_credentials = store.get() stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials) self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token) self.assertEqual(credentials.access_token, stored_credentials.access_token)
# try retrieving with a dictionary # try retrieving with a dictionary
store_dict = multistore_file.get_credential_storage_custom_string_key( store_dict = multistore_file.get_credential_storage_custom_string_key(
FILENAME, {'key': 'mykey'}) FILENAME, {'key': 'mykey'})
stored_credentials = store.get() stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials) self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token) self.assertEqual(credentials.access_token, stored_credentials.access_token)
store.delete() store.delete()
stored_credentials = store.get() stored_credentials = store.get()
self.assertEquals(None, stored_credentials) self.assertEquals(None, stored_credentials)
def test_multistore_file_backwards_compatibility(self): def test_multistore_file_backwards_compatibility(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()
scopes = ['scope1', 'scope2'] scopes = ['scope1', 'scope2']
# store the credentials using the legacy key method # store the credentials using the legacy key method
store = multistore_file.get_credential_storage( store = multistore_file.get_credential_storage(
FILENAME, 'client_id', 'user_agent', scopes) FILENAME, 'client_id', 'user_agent', scopes)
store.put(credentials) store.put(credentials)
# retrieve the credentials using a custom key that matches the legacy key # retrieve the credentials using a custom key that matches the legacy key
key = {'clientId': 'client_id', 'userAgent': 'user_agent', key = {'clientId': 'client_id', 'userAgent': 'user_agent',
'scope': util.scopes_to_string(scopes)} 'scope': util.scopes_to_string(scopes)}
store = multistore_file.get_credential_storage_custom_key(FILENAME, key) store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
stored_credentials = store.get() stored_credentials = store.get()
self.assertEqual(credentials.access_token, stored_credentials.access_token) self.assertEqual(credentials.access_token, stored_credentials.access_token)
def test_multistore_file_get_all_keys(self): def test_multistore_file_get_all_keys(self):
# start with no keys # start with no keys
keys = multistore_file.get_all_credential_keys(FILENAME) keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([], keys) self.assertEquals([], keys)
# store credentials # store credentials
credentials = self.create_test_credentials(client_id='client1') credentials = self.create_test_credentials(client_id='client1')
custom_key = {'myapp': 'testing', 'clientid': 'client1'} custom_key = {'myapp': 'testing', 'clientid': 'client1'}
store1 = multistore_file.get_credential_storage_custom_key( store1 = multistore_file.get_credential_storage_custom_key(
FILENAME, custom_key) FILENAME, custom_key)
store1.put(credentials) store1.put(credentials)
keys = multistore_file.get_all_credential_keys(FILENAME) keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([custom_key], keys) self.assertEquals([custom_key], keys)
# store more credentials # store more credentials
credentials = self.create_test_credentials(client_id='client2') credentials = self.create_test_credentials(client_id='client2')
string_key = 'string_key' string_key = 'string_key'
store2 = multistore_file.get_credential_storage_custom_string_key( store2 = multistore_file.get_credential_storage_custom_string_key(
FILENAME, string_key) FILENAME, string_key)
store2.put(credentials) store2.put(credentials)
keys = multistore_file.get_all_credential_keys(FILENAME) keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals(2, len(keys)) self.assertEquals(2, len(keys))
self.assertTrue(custom_key in keys) self.assertTrue(custom_key in keys)
self.assertTrue({'key': string_key} in keys) self.assertTrue({'key': string_key} in keys)
# back to no keys # back to no keys
store1.delete() store1.delete()
store2.delete() store2.delete()
keys = multistore_file.get_all_credential_keys(FILENAME) keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([], keys) self.assertEquals([], keys)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Unit tests for the Flask utilities""" """Unit tests for the Flask utilities"""
__author__ = 'jonwayne@google.com (Jon Wayne Parrott)' __author__ = 'jonwayne@google.com (Jon Wayne Parrott)'
@@ -35,6 +34,7 @@ from oauth2client.client import OAuth2Credentials
class Http2Mock(object): class Http2Mock(object):
"""Mock httplib2.Http for code exchange / refresh""" """Mock httplib2.Http for code exchange / refresh"""
def __init__(self, status=httplib.OK, **kwargs): def __init__(self, status=httplib.OK, **kwargs):
self.status = status self.status = status
self.content = { self.content = {

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for oauth2client.gce. """Tests for oauth2client.gce.
Unit tests for oauth2client.gce. Unit tests for oauth2client.gce.
@@ -36,86 +35,86 @@ from oauth2client.gce import AppAssertionCredentials
class AssertionCredentialsTests(unittest.TestCase): class AssertionCredentialsTests(unittest.TestCase):
def _refresh_success_helper(self, bytes_response=False): def _refresh_success_helper(self, bytes_response=False):
access_token = u'this-is-a-token' access_token = u'this-is-a-token'
return_val = json.dumps({u'accessToken': access_token}) return_val = json.dumps({u'accessToken': access_token})
if bytes_response: if bytes_response:
return_val = _to_bytes(return_val) return_val = _to_bytes(return_val)
http = mock.MagicMock() http = mock.MagicMock()
http.request = mock.MagicMock( http.request = mock.MagicMock(
return_value=(mock.Mock(status=200), return_val)) return_value=(mock.Mock(status=200), return_val))
scopes = ['http://example.com/a', 'http://example.com/b'] scopes = ['http://example.com/a', 'http://example.com/b']
credentials = AppAssertionCredentials(scope=scopes) credentials = AppAssertionCredentials(scope=scopes)
self.assertEquals(None, credentials.access_token) self.assertEquals(None, credentials.access_token)
credentials.refresh(http) credentials.refresh(http)
self.assertEquals(access_token, credentials.access_token) self.assertEquals(access_token, credentials.access_token)
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/' base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
'service-accounts/default/acquire') 'service-accounts/default/acquire')
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='') escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
request_uri = base_metadata_uri + '?scope=' + escaped_scopes request_uri = base_metadata_uri + '?scope=' + escaped_scopes
http.request.assert_called_once_with(request_uri) http.request.assert_called_once_with(request_uri)
def test_refresh_success(self): def test_refresh_success(self):
self._refresh_success_helper(bytes_response=False) self._refresh_success_helper(bytes_response=False)
def test_refresh_success_bytes(self): def test_refresh_success_bytes(self):
self._refresh_success_helper(bytes_response=True) self._refresh_success_helper(bytes_response=True)
def test_fail_refresh(self): def test_fail_refresh(self):
http = mock.MagicMock() http = mock.MagicMock()
http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}')) http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
c = AppAssertionCredentials(scope=['http://example.com/a', c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b']) 'http://example.com/b'])
self.assertRaises(AccessTokenRefreshError, c.refresh, http) self.assertRaises(AccessTokenRefreshError, c.refresh, http)
def test_to_from_json(self): def test_to_from_json(self):
c = AppAssertionCredentials(scope=['http://example.com/a', c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b']) 'http://example.com/b'])
json = c.to_json() json = c.to_json()
c2 = Credentials.new_from_json(json) c2 = Credentials.new_from_json(json)
self.assertEqual(c.access_token, c2.access_token) self.assertEqual(c.access_token, c2.access_token)
def test_create_scoped_required_without_scopes(self): def test_create_scoped_required_without_scopes(self):
credentials = AppAssertionCredentials([]) credentials = AppAssertionCredentials([])
self.assertTrue(credentials.create_scoped_required()) self.assertTrue(credentials.create_scoped_required())
def test_create_scoped_required_with_scopes(self): def test_create_scoped_required_with_scopes(self):
credentials = AppAssertionCredentials(['dummy_scope']) credentials = AppAssertionCredentials(['dummy_scope'])
self.assertFalse(credentials.create_scoped_required()) self.assertFalse(credentials.create_scoped_required())
def test_create_scoped(self): def test_create_scoped(self):
credentials = AppAssertionCredentials([]) credentials = AppAssertionCredentials([])
new_credentials = credentials.create_scoped(['dummy_scope']) new_credentials = credentials.create_scoped(['dummy_scope'])
self.assertNotEqual(credentials, new_credentials) self.assertNotEqual(credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, AppAssertionCredentials)) self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
self.assertEqual('dummy_scope', new_credentials.scope) self.assertEqual('dummy_scope', new_credentials.scope)
def test_get_access_token(self): def test_get_access_token(self):
http = mock.MagicMock() http = mock.MagicMock()
http.request = mock.MagicMock( http.request = mock.MagicMock(
return_value=(mock.Mock(status=200), return_value=(mock.Mock(status=200),
'{"accessToken": "this-is-a-token"}')) '{"accessToken": "this-is-a-token"}'))
credentials = AppAssertionCredentials(['dummy_scope']) credentials = AppAssertionCredentials(['dummy_scope'])
token = credentials.get_access_token(http=http) token = credentials.get_access_token(http=http)
self.assertEqual('this-is-a-token', token.access_token) self.assertEqual('this-is-a-token', token.access_token)
self.assertEqual(None, token.expires_in) self.assertEqual(None, token.expires_in)
http.request.assert_called_once_with( http.request.assert_called_once_with(
'http://metadata.google.internal/0.1/meta-data/service-accounts/' 'http://metadata.google.internal/0.1/meta-data/service-accounts/'
'default/acquire?scope=dummy_scope') 'default/acquire?scope=dummy_scope')
def test_save_to_well_known_file(self): def test_save_to_well_known_file(self):
import os import os
ORIGINAL_ISDIR = os.path.isdir ORIGINAL_ISDIR = os.path.isdir
try: try:
os.path.isdir = lambda path: True os.path.isdir = lambda path: True
credentials = AppAssertionCredentials([]) credentials = AppAssertionCredentials([])
self.assertRaises(NotImplementedError, save_to_well_known_file, self.assertRaises(NotImplementedError, save_to_well_known_file,
credentials) credentials)
finally: finally:
os.path.isdir = ORIGINAL_ISDIR os.path.isdir = ORIGINAL_ISDIR

View File

@@ -19,9 +19,9 @@ import unittest
class ImportTest(unittest.TestCase): class ImportTest(unittest.TestCase):
def test_tools_import(self): def test_tools_import(self):
import oauth2client.tools import oauth2client.tools
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Oauth2client tests """Oauth2client tests
Unit tests for oauth2client. Unit tests for oauth2client.
@@ -42,295 +41,295 @@ from oauth2client.file import Storage
def datafile(filename): def datafile(filename):
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb') f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read() data = f.read()
f.close() f.close()
return data return data
class CryptTests(unittest.TestCase): class CryptTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.format = 'p12' self.format = 'p12'
self.signer = crypt.OpenSSLSigner self.signer = crypt.OpenSSLSigner
self.verifier = crypt.OpenSSLVerifier self.verifier = crypt.OpenSSLVerifier
def test_sign_and_verify(self): def test_sign_and_verify(self):
self._check_sign_and_verify('privatekey.%s' % self.format) self._check_sign_and_verify('privatekey.%s' % self.format)
def test_sign_and_verify_from_converted_pkcs12(self): def test_sign_and_verify_from_converted_pkcs12(self):
# Tests that following instructions to convert from PKCS12 to PEM works. # Tests that following instructions to convert from PKCS12 to PEM works.
if self.format == 'pem': if self.format == 'pem':
self._check_sign_and_verify('pem_from_pkcs12.pem') self._check_sign_and_verify('pem_from_pkcs12.pem')
def _check_sign_and_verify(self, private_key_file): def _check_sign_and_verify(self, private_key_file):
private_key = datafile(private_key_file) private_key = datafile(private_key_file)
public_key = datafile('publickey.pem') public_key = datafile('publickey.pem')
# We pass in a non-bytes password to make sure all branches # We pass in a non-bytes password to make sure all branches
# are traversed in tests. # are traversed in tests.
signer = self.signer.from_string(private_key, signer = self.signer.from_string(private_key,
password=u'notasecret') password=u'notasecret')
signature = signer.sign('foo') signature = signer.sign('foo')
verifier = self.verifier.from_string(public_key, True) verifier = self.verifier.from_string(public_key, True)
self.assertTrue(verifier.verify(b'foo', signature)) self.assertTrue(verifier.verify(b'foo', signature))
self.assertFalse(verifier.verify(b'bar', signature)) self.assertFalse(verifier.verify(b'bar', signature))
self.assertFalse(verifier.verify(b'foo', b'bad signagure')) self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
self.assertFalse(verifier.verify(b'foo', u'bad signagure')) self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
def _check_jwt_failure(self, jwt, expected_error): def _check_jwt_failure(self, jwt, expected_error):
public_key = datafile('publickey.pem') public_key = datafile('publickey.pem')
certs = {'foo': public_key} certs = {'foo': public_key}
audience = ('https://www.googleapis.com/auth/id?client_id=' audience = ('https://www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com') 'external_public_key@testing.gserviceaccount.com')
try: try:
crypt.verify_signed_jwt_with_certs(jwt, certs, audience) crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.fail() self.fail()
except crypt.AppIdentityError as e: except crypt.AppIdentityError as e:
self.assertTrue(expected_error in str(e)) self.assertTrue(expected_error in str(e))
def _create_signed_jwt(self): def _create_signed_jwt(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
signer = self.signer.from_string(private_key) signer = self.signer.from_string(private_key)
audience = 'some_audience_address@testing.gserviceaccount.com' audience = 'some_audience_address@testing.gserviceaccount.com'
now = int(time.time()) now = int(time.time())
return crypt.make_signed_jwt(signer, { return crypt.make_signed_jwt(signer, {
'aud': audience, 'aud': audience,
'iat': now, 'iat': now,
'exp': now + 300, 'exp': now + 300,
'user': 'billy bob', 'user': 'billy bob',
'metadata': {'meta': 'data'}, 'metadata': {'meta': 'data'},
}) })
def test_verify_id_token(self): def test_verify_id_token(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
public_key = datafile('publickey.pem') public_key = datafile('publickey.pem')
certs = {'foo': public_key} certs = {'foo': public_key}
audience = 'some_audience_address@testing.gserviceaccount.com' audience = 'some_audience_address@testing.gserviceaccount.com'
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience) contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri(self): def test_verify_id_token_with_certs_uri(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, datafile('certs.json')), ({'status': '200'}, datafile('certs.json')),
]) ])
contents = verify_id_token( contents = verify_id_token(
jwt, 'some_audience_address@testing.gserviceaccount.com', http=http) jwt, 'some_audience_address@testing.gserviceaccount.com', http=http)
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri_fails(self): def test_verify_id_token_with_certs_uri_fails(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '404'}, datafile('certs.json')), ({'status': '404'}, datafile('certs.json')),
]) ])
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt, self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
'some_audience_address@testing.gserviceaccount.com', 'some_audience_address@testing.gserviceaccount.com',
http=http) http=http)
def test_verify_id_token_bad_tokens(self): def test_verify_id_token_bad_tokens(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
# Wrong number of segments # Wrong number of segments
self._check_jwt_failure('foo', 'Wrong number of segments') self._check_jwt_failure('foo', 'Wrong number of segments')
# Not json # Not json
self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token') self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
# Bad signature # Bad signature
jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz']) jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
self._check_jwt_failure(jwt, 'Invalid token signature') self._check_jwt_failure(jwt, 'Invalid token signature')
# No expiration # No expiration
signer = self.signer.from_string(private_key) signer = self.signer.from_string(private_key)
audience = ('https:#www.googleapis.com/auth/id?client_id=' audience = ('https:#www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com') 'external_public_key@testing.gserviceaccount.com')
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': audience, 'aud': audience,
'iat': time.time(), 'iat': time.time(),
}) })
self._check_jwt_failure(jwt, 'No exp field in token') self._check_jwt_failure(jwt, 'No exp field in token')
# No issued at # No issued at
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience', 'aud': 'audience',
'exp': time.time() + 400, 'exp': time.time() + 400,
}) })
self._check_jwt_failure(jwt, 'No iat field in token') self._check_jwt_failure(jwt, 'No iat field in token')
# Too early # Too early
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience', 'aud': 'audience',
'iat': time.time() + 301, 'iat': time.time() + 301,
'exp': time.time() + 400, 'exp': time.time() + 400,
}) })
self._check_jwt_failure(jwt, 'Token used too early') self._check_jwt_failure(jwt, 'Token used too early')
# Too late # Too late
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience', 'aud': 'audience',
'iat': time.time() - 500, 'iat': time.time() - 500,
'exp': time.time() - 301, 'exp': time.time() - 301,
}) })
self._check_jwt_failure(jwt, 'Token used too late') self._check_jwt_failure(jwt, 'Token used too late')
# Wrong target # Wrong target
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': 'somebody else', 'aud': 'somebody else',
'iat': time.time(), 'iat': time.time(),
'exp': time.time() + 300, 'exp': time.time() + 300,
}) })
self._check_jwt_failure(jwt, 'Wrong recipient') self._check_jwt_failure(jwt, 'Wrong recipient')
def test_from_string_non_509_cert(self): def test_from_string_non_509_cert(self):
# Use a private key instead of a certificate to test the other branch # Use a private key instead of a certificate to test the other branch
# of from_string(). # of from_string().
public_key = datafile('privatekey.pem') public_key = datafile('privatekey.pem')
verifier = self.verifier.from_string(public_key, is_x509_cert=False) verifier = self.verifier.from_string(public_key, is_x509_cert=False)
self.assertTrue(isinstance(verifier, self.verifier)) self.assertTrue(isinstance(verifier, self.verifier))
class PEMCryptTestsPyCrypto(CryptTests): class PEMCryptTestsPyCrypto(CryptTests):
def setUp(self): def setUp(self):
self.format = 'pem' self.format = 'pem'
self.signer = crypt.PyCryptoSigner self.signer = crypt.PyCryptoSigner
self.verifier = crypt.PyCryptoVerifier self.verifier = crypt.PyCryptoVerifier
class PEMCryptTestsOpenSSL(CryptTests): class PEMCryptTestsOpenSSL(CryptTests):
def setUp(self): def setUp(self):
self.format = 'pem' self.format = 'pem'
self.signer = crypt.OpenSSLSigner self.signer = crypt.OpenSSLSigner
self.verifier = crypt.OpenSSLVerifier self.verifier = crypt.OpenSSLVerifier
class SignedJwtAssertionCredentialsTests(unittest.TestCase): class SignedJwtAssertionCredentialsTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.format = 'p12' self.format = 'p12'
crypt.Signer = crypt.OpenSSLSigner crypt.Signer = crypt.OpenSSLSigner
def test_credentials_good(self): def test_credentials_good(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials( credentials = SignedJwtAssertionCredentials(
'some_account@example.com', 'some_account@example.com',
private_key, private_key,
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
resp, content = http.request('http://example.org') resp, content = http.request('http://example.org')
self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_credentials_to_from_json(self): def test_credentials_to_from_json(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials( credentials = SignedJwtAssertionCredentials(
'some_account@example.com', 'some_account@example.com',
private_key, private_key,
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
json = credentials.to_json() json = credentials.to_json()
restored = Credentials.new_from_json(json) restored = Credentials.new_from_json(json)
self.assertEqual(credentials.private_key, restored.private_key) self.assertEqual(credentials.private_key, restored.private_key)
self.assertEqual(credentials.private_key_password, self.assertEqual(credentials.private_key_password,
restored.private_key_password) restored.private_key_password)
self.assertEqual(credentials.kwargs, restored.kwargs) self.assertEqual(credentials.kwargs, restored.kwargs)
def _credentials_refresh(self, credentials): def _credentials_refresh(self, credentials):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '401'}, b''), ({'status': '401'}, b''),
({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'), ({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
_, content = http.request('http://example.org') _, content = http.request('http://example.org')
return content return content
def test_credentials_refresh_without_storage(self): def test_credentials_refresh_without_storage(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials( credentials = SignedJwtAssertionCredentials(
'some_account@example.com', 'some_account@example.com',
private_key, private_key,
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
content = self._credentials_refresh(credentials) content = self._credentials_refresh(credentials)
self.assertEqual(b'Bearer 3/3w', content[b'Authorization']) self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
def test_credentials_refresh_with_storage(self): def test_credentials_refresh_with_storage(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials( credentials = SignedJwtAssertionCredentials(
'some_account@example.com', 'some_account@example.com',
private_key, private_key,
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
(filehandle, filename) = tempfile.mkstemp() (filehandle, filename) = tempfile.mkstemp()
os.close(filehandle) os.close(filehandle)
store = Storage(filename) store = Storage(filename)
store.put(credentials) store.put(credentials)
credentials.set_store(store) credentials.set_store(store)
content = self._credentials_refresh(credentials) content = self._credentials_refresh(credentials)
self.assertEqual(b'Bearer 3/3w', content[b'Authorization']) self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
os.unlink(filename) os.unlink(filename)
class PEMSignedJwtAssertionCredentialsOpenSSLTests( class PEMSignedJwtAssertionCredentialsOpenSSLTests(
SignedJwtAssertionCredentialsTests): SignedJwtAssertionCredentialsTests):
def setUp(self): def setUp(self):
self.format = 'pem' self.format = 'pem'
crypt.Signer = crypt.OpenSSLSigner crypt.Signer = crypt.OpenSSLSigner
class PEMSignedJwtAssertionCredentialsPyCryptoTests( class PEMSignedJwtAssertionCredentialsPyCryptoTests(
SignedJwtAssertionCredentialsTests): SignedJwtAssertionCredentialsTests):
def setUp(self): def setUp(self):
self.format = 'pem' self.format = 'pem'
crypt.Signer = crypt.PyCryptoSigner crypt.Signer = crypt.PyCryptoSigner
class PKCSSignedJwtAssertionCredentialsPyCryptoTests(unittest.TestCase): class PKCSSignedJwtAssertionCredentialsPyCryptoTests(unittest.TestCase):
def test_for_failure(self): def test_for_failure(self):
crypt.Signer = crypt.PyCryptoSigner crypt.Signer = crypt.PyCryptoSigner
private_key = datafile('privatekey.p12') private_key = datafile('privatekey.p12')
credentials = SignedJwtAssertionCredentials( credentials = SignedJwtAssertionCredentials(
'some_account@example.com', 'some_account@example.com',
private_key, private_key,
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
try: try:
credentials._generate_assertion() credentials._generate_assertion()
self.fail() self.fail()
except NotImplementedError: except NotImplementedError:
pass pass
class TestHasOpenSSLFlag(unittest.TestCase): class TestHasOpenSSLFlag(unittest.TestCase):
def test_true(self): def test_true(self):
self.assertEqual(True, HAS_OPENSSL) self.assertEqual(True, HAS_OPENSSL)
self.assertEqual(True, HAS_CRYPTO) self.assertEqual(True, HAS_CRYPTO)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for oauth2client.keyring_storage tests. """Tests for oauth2client.keyring_storage tests.
Unit tests for oauth2client.keyring_storage. Unit tests for oauth2client.keyring_storage.
@@ -33,59 +32,59 @@ from oauth2client.keyring_storage import Storage
class OAuth2ClientKeyringTests(unittest.TestCase): class OAuth2ClientKeyringTests(unittest.TestCase):
def test_non_existent_credentials_storage(self): def test_non_existent_credentials_storage(self):
with mock.patch.object(keyring, 'get_password', with mock.patch.object(keyring, 'get_password',
return_value=None, return_value=None,
autospec=True) as get_password: autospec=True) as get_password:
s = Storage('my_unit_test', 'me') s = Storage('my_unit_test', 'me')
credentials = s.get() credentials = s.get()
self.assertEquals(None, credentials) self.assertEquals(None, credentials)
get_password.assert_called_once_with('my_unit_test', 'me') get_password.assert_called_once_with('my_unit_test', 'me')
def test_malformed_credentials_in_storage(self): def test_malformed_credentials_in_storage(self):
with mock.patch.object(keyring, 'get_password', with mock.patch.object(keyring, 'get_password',
return_value='{', return_value='{',
autospec=True) as get_password: autospec=True) as get_password:
s = Storage('my_unit_test', 'me') s = Storage('my_unit_test', 'me')
credentials = s.get() credentials = s.get()
self.assertEquals(None, credentials) self.assertEquals(None, credentials)
get_password.assert_called_once_with('my_unit_test', 'me') get_password.assert_called_once_with('my_unit_test', 'me')
def test_json_credentials_storage(self): def test_json_credentials_storage(self):
access_token = 'foo' access_token = 'foo'
client_id = 'some_client_id' client_id = 'some_client_id'
client_secret = 'cOuDdkfjxxnv+' client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0' refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow() token_expiry = datetime.datetime.utcnow()
user_agent = 'refresh_checker/1.0' user_agent = 'refresh_checker/1.0'
credentials = OAuth2Credentials( credentials = OAuth2Credentials(
access_token, client_id, client_secret, access_token, client_id, client_secret,
refresh_token, token_expiry, GOOGLE_TOKEN_URI, refresh_token, token_expiry, GOOGLE_TOKEN_URI,
user_agent) user_agent)
# Setting autospec on a mock with an iterable side_effect is # Setting autospec on a mock with an iterable side_effect is
# currently broken (http://bugs.python.org/issue17826), so instead # currently broken (http://bugs.python.org/issue17826), so instead
# we patch twice. # we patch twice.
with mock.patch.object(keyring, 'get_password', with mock.patch.object(keyring, 'get_password',
return_value=None, return_value=None,
autospec=True) as get_password: autospec=True) as get_password:
with mock.patch.object(keyring, 'set_password', with mock.patch.object(keyring, 'set_password',
return_value=None, return_value=None,
autospec=True) as set_password: autospec=True) as set_password:
s = Storage('my_unit_test', 'me') s = Storage('my_unit_test', 'me')
self.assertEquals(None, s.get()) self.assertEquals(None, s.get())
s.put(credentials) s.put(credentials)
set_password.assert_called_once_with( set_password.assert_called_once_with(
'my_unit_test', 'me', credentials.to_json()) 'my_unit_test', 'me', credentials.to_json())
get_password.assert_called_once_with('my_unit_test', 'me') get_password.assert_called_once_with('my_unit_test', 'me')
with mock.patch.object(keyring, 'get_password', with mock.patch.object(keyring, 'get_password',
return_value=credentials.to_json(), return_value=credentials.to_json(),
autospec=True) as get_password: autospec=True) as get_password:
restored = s.get() restored = s.get()
self.assertEqual('foo', restored.access_token) self.assertEqual('foo', restored.access_token)
self.assertEqual('some_client_id', restored.client_id) self.assertEqual('some_client_id', restored.client_id)
get_password.assert_called_once_with('my_unit_test', 'me') get_password.assert_called_once_with('my_unit_test', 'me')

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Oauth2client tests. """Oauth2client tests.
Unit tests for service account credentials implemented using RSA. Unit tests for service account credentials implemented using RSA.
@@ -31,94 +30,94 @@ from oauth2client.service_account import _ServiceAccountCredentials
def datafile(filename): def datafile(filename):
# TODO(orestica): Refactor this using pkgutil.get_data # TODO(orestica): Refactor this using pkgutil.get_data
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb') f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read() data = f.read()
f.close() f.close()
return data return data
class ServiceAccountCredentialsTests(unittest.TestCase): class ServiceAccountCredentialsTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.service_account_id = '123' self.service_account_id = '123'
self.service_account_email = 'dummy@google.com' self.service_account_email = 'dummy@google.com'
self.private_key_id = 'ABCDEF' self.private_key_id = 'ABCDEF'
self.private_key = datafile('pem_from_pkcs12.pem') self.private_key = datafile('pem_from_pkcs12.pem')
self.scopes = ['dummy_scope'] self.scopes = ['dummy_scope']
self.credentials = _ServiceAccountCredentials(self.service_account_id, self.credentials = _ServiceAccountCredentials(self.service_account_id,
self.service_account_email, self.service_account_email,
self.private_key_id, self.private_key_id,
self.private_key, self.private_key,
[]) [])
def test_sign_blob(self): def test_sign_blob(self):
private_key_id, signature = self.credentials.sign_blob('Google') private_key_id, signature = self.credentials.sign_blob('Google')
self.assertEqual( self.private_key_id, private_key_id) self.assertEqual(self.private_key_id, private_key_id)
pub_key = rsa.PublicKey.load_pkcs1_openssl_pem( pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
datafile('publickey_openssl.pem')) datafile('publickey_openssl.pem'))
self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key)) self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
try: try:
rsa.pkcs1.verify(b'Orest', signature, pub_key) rsa.pkcs1.verify(b'Orest', signature, pub_key)
self.fail('Verification should have failed!') self.fail('Verification should have failed!')
except rsa.pkcs1.VerificationError: except rsa.pkcs1.VerificationError:
pass # Expected pass # Expected
try: try:
rsa.pkcs1.verify(b'Google', b'bad signature', pub_key) rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
self.fail('Verification should have failed!') self.fail('Verification should have failed!')
except rsa.pkcs1.VerificationError: except rsa.pkcs1.VerificationError:
pass # Expected pass # Expected
def test_service_account_email(self): def test_service_account_email(self):
self.assertEqual(self.service_account_email, self.assertEqual(self.service_account_email,
self.credentials.service_account_email) self.credentials.service_account_email)
def test_create_scoped_required_without_scopes(self): def test_create_scoped_required_without_scopes(self):
self.assertTrue(self.credentials.create_scoped_required()) self.assertTrue(self.credentials.create_scoped_required())
def test_create_scoped_required_with_scopes(self): def test_create_scoped_required_with_scopes(self):
self.credentials = _ServiceAccountCredentials(self.service_account_id, self.credentials = _ServiceAccountCredentials(self.service_account_id,
self.service_account_email, self.service_account_email,
self.private_key_id, self.private_key_id,
self.private_key, self.private_key,
self.scopes) self.scopes)
self.assertFalse(self.credentials.create_scoped_required()) self.assertFalse(self.credentials.create_scoped_required())
def test_create_scoped(self): def test_create_scoped(self):
new_credentials = self.credentials.create_scoped(self.scopes) new_credentials = self.credentials.create_scoped(self.scopes)
self.assertNotEqual(self.credentials, new_credentials) self.assertNotEqual(self.credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials)) self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
self.assertEqual('dummy_scope', new_credentials._scopes) self.assertEqual('dummy_scope', new_credentials._scopes)
def test_access_token(self): def test_access_token(self):
S = 2 # number of seconds in which the token expires S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S} token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token', 'expires_in': S} token_response_second = {'access_token': 'second_token', 'expires_in': S}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode('utf-8')), ({'status': '200'}, json.dumps(token_response_first).encode('utf-8')),
({'status': '200'}, json.dumps(token_response_second).encode('utf-8')), ({'status': '200'}, json.dumps(token_response_second).encode('utf-8')),
]) ])
token = self.credentials.get_access_token(http=http) token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token) self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in) self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
token = self.credentials.get_access_token(http=http) token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token) self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in) self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
time.sleep(S + 0.5) # some margin to avoid flakiness time.sleep(S + 0.5) # some margin to avoid flakiness
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http) token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token) self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in) self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response) self.assertEqual(token_response_second, self.credentials.token_response)

View File

@@ -5,6 +5,7 @@ from oauth2client import tools
from six.moves.urllib import request from six.moves.urllib import request
import threading import threading
class TestClientRedirectServer(unittest.TestCase): class TestClientRedirectServer(unittest.TestCase):
"""Test the ClientRedirectServer and ClientRedirectHandler classes.""" """Test the ClientRedirectServer and ClientRedirectHandler classes."""
@@ -15,16 +16,15 @@ class TestClientRedirectServer(unittest.TestCase):
httpd = tools.ClientRedirectServer(('localhost', 0), tools.ClientRedirectHandler) httpd = tools.ClientRedirectServer(('localhost', 0), tools.ClientRedirectHandler)
code = 'foo' code = 'foo'
url = 'http://localhost:%i?code=%s' % (httpd.server_address[1], code) url = 'http://localhost:%i?code=%s' % (httpd.server_address[1], code)
t = threading.Thread(target = httpd.handle_request) t = threading.Thread(target=httpd.handle_request)
t.setDaemon(True) t.setDaemon(True)
t.start() t.start()
f = request.urlopen( url ) f = request.urlopen(url)
self.assertTrue(f.read()) self.assertTrue(f.read())
t.join() t.join()
httpd.server_close() httpd.server_close()
self.assertEqual(httpd.query_params.get('code'),code) self.assertEqual(httpd.query_params.get('code'), code)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -9,8 +9,8 @@ from oauth2client import util
class ScopeToStringTests(unittest.TestCase): class ScopeToStringTests(unittest.TestCase):
def test_iterables(self): def test_iterables(self):
cases = [ cases = [
('', ''), ('', ''),
('', ()), ('', ()),
('', []), ('', []),
@@ -22,36 +22,37 @@ class ScopeToStringTests(unittest.TestCase):
('a b', ('a', 'b')), ('a b', ('a', 'b')),
('a b', 'a b'), ('a b', 'a b'),
('a b', (s for s in ['a', 'b'])), ('a b', (s for s in ['a', 'b'])),
] ]
for expected, case in cases: for expected, case in cases:
self.assertEqual(expected, util.scopes_to_string(case)) self.assertEqual(expected, util.scopes_to_string(case))
class StringToScopeTests(unittest.TestCase): class StringToScopeTests(unittest.TestCase):
def test_conversion(self): def test_conversion(self):
cases = [ cases = [
(['a', 'b'], ['a', 'b']), (['a', 'b'], ['a', 'b']),
('', []), ('', []),
('a', ['a']), ('a', ['a']),
('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']), ('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']),
] ]
for case, expected in cases:
self.assertEqual(expected, util.string_to_scopes(case))
for case, expected in cases:
self.assertEqual(expected, util.string_to_scopes(case))
class KeyConversionTests(unittest.TestCase): class KeyConversionTests(unittest.TestCase):
def test_key_conversions(self): def test_key_conversions(self):
d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'} d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
tuple_key = util.dict_to_tuple_key(d) tuple_key = util.dict_to_tuple_key(d)
# the resulting key should be naturally sorted # the resulting key should be naturally sorted
self.assertEqual( self.assertEqual(
(('another', 'something else'), (('another', 'something else'),
('onemore', 'foo'), ('onemore', 'foo'),
('somekey', 'some value')), ('somekey', 'some value')),
tuple_key) tuple_key)
# check we get the original dictionary back # check we get the original dictionary back
self.assertEqual(d, dict(tuple_key)) self.assertEqual(d, dict(tuple_key))

View File

@@ -34,78 +34,78 @@ TEST_EXTRA_INFO_2 = 'more_extra_info'
class XsrfUtilTests(unittest.TestCase): class XsrfUtilTests(unittest.TestCase):
"""Test xsrfutil functions.""" """Test xsrfutil functions."""
def testGenerateAndValidateToken(self): def testGenerateAndValidateToken(self):
"""Test generating and validating a token.""" """Test generating and validating a token."""
token = xsrfutil.generate_token(TEST_KEY, token = xsrfutil.generate_token(TEST_KEY,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
when=TEST_TIME) when=TEST_TIME)
# Check that the token is considered valid when it should be. # Check that the token is considered valid when it should be.
self.assertTrue(xsrfutil.validate_token(TEST_KEY, self.assertTrue(xsrfutil.validate_token(TEST_KEY,
token, token,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=TEST_TIME)) current_time=TEST_TIME))
# Should still be valid 15 minutes later. # Should still be valid 15 minutes later.
later15mins = TEST_TIME + 15*60 later15mins = TEST_TIME + 15 * 60
self.assertTrue(xsrfutil.validate_token(TEST_KEY, self.assertTrue(xsrfutil.validate_token(TEST_KEY,
token, token,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=later15mins)) current_time=later15mins))
# But not if beyond the timeout. # But not if beyond the timeout.
later2hours = TEST_TIME + 2*60*60 later2hours = TEST_TIME + 2 * 60 * 60
self.assertFalse(xsrfutil.validate_token(TEST_KEY, self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token, token,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=later2hours)) current_time=later2hours))
# Or if the key is different. # Or if the key is different.
self.assertFalse(xsrfutil.validate_token('another key', self.assertFalse(xsrfutil.validate_token('another key',
token, token,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=later15mins)) current_time=later15mins))
# Or the user ID.... # Or the user ID....
self.assertFalse(xsrfutil.validate_token(TEST_KEY, self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token, token,
TEST_USER_ID_2, TEST_USER_ID_2,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=later15mins)) current_time=later15mins))
# Or the action ID... # Or the action ID...
self.assertFalse(xsrfutil.validate_token(TEST_KEY, self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token, token,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_2, action_id=TEST_ACTION_ID_2,
current_time=later15mins)) current_time=later15mins))
# Invalid when truncated # Invalid when truncated
self.assertFalse(xsrfutil.validate_token(TEST_KEY, self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token[:-1], token[:-1],
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=later15mins)) current_time=later15mins))
# Invalid with extra garbage # Invalid with extra garbage
self.assertFalse(xsrfutil.validate_token(TEST_KEY, self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token + b'x', token + b'x',
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1, action_id=TEST_ACTION_ID_1,
current_time=later15mins)) current_time=later15mins))
# Invalid with token of None # Invalid with token of None
self.assertFalse(xsrfutil.validate_token(TEST_KEY, self.assertFalse(xsrfutil.validate_token(TEST_KEY,
None, None,
TEST_USER_ID_1, TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1)) action_id=TEST_ACTION_ID_1))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()