Raw pep8ify changes.

Simply ran

  pep8ify -w oauth2client/
  pep8ify -w tests/
This commit is contained in:
Danny Hermes
2015-08-19 22:03:50 -07:00
parent 043e066e54
commit 34c1ff543d
40 changed files with 4635 additions and 4654 deletions

View File

@@ -19,7 +19,7 @@ import six
def _parse_pem_key(raw_key_input):
"""Identify and extract PEM keys.
"""Identify and extract PEM keys.
Determines whether the given key is in the format of PEM key, and extracts
the relevant part of the key if it is.
@@ -30,17 +30,17 @@ def _parse_pem_key(raw_key_input):
Returns:
string, The actual key if the contents are from a PEM file, or else None.
"""
offset = raw_key_input.find(b'-----BEGIN ')
if offset != -1:
return raw_key_input[offset:]
offset = raw_key_input.find(b'-----BEGIN ')
if offset != -1:
return raw_key_input[offset:]
def _json_encode(data):
return json.dumps(data, separators=(',', ':'))
return json.dumps(data, separators=(',', ':'))
def _to_bytes(value, encoding='ascii'):
"""Converts a string value to bytes, if necessary.
"""Converts a string value to bytes, if necessary.
Unfortunately, ``six.b`` is insufficient for this task since in
Python2 it does not modify ``unicode`` objects.
@@ -60,16 +60,16 @@ def _to_bytes(value, encoding='ascii'):
Raises:
ValueError if the value could not be converted to bytes.
"""
result = (value.encode(encoding)
result = (value.encode(encoding)
if isinstance(value, six.text_type) else value)
if isinstance(result, six.binary_type):
return result
else:
raise ValueError('%r could not be converted to bytes' % (value,))
if isinstance(result, six.binary_type):
return result
else:
raise ValueError('%r could not be converted to bytes' % (value, ))
def _from_bytes(value):
"""Converts bytes to a string value, if necessary.
"""Converts bytes to a string value, if necessary.
Args:
value: The string/bytes value to be converted.
@@ -81,21 +81,21 @@ def _from_bytes(value):
Raises:
ValueError if the value could not be converted to unicode.
"""
result = (value.decode('utf-8')
result = (value.decode('utf-8')
if isinstance(value, six.binary_type) else value)
if isinstance(result, six.text_type):
return result
else:
raise ValueError('%r could not be converted to unicode' % (value,))
if isinstance(result, six.text_type):
return result
else:
raise ValueError('%r could not be converted to unicode' % (value, ))
def _urlsafe_b64encode(raw_bytes):
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
def _urlsafe_b64decode(b64string):
# Guard against unicode strings, which base64 can't handle.
b64string = _to_bytes(b64string)
padded = b64string + b'=' * (4 - len(b64string) % 4)
return base64.urlsafe_b64decode(padded)
# Guard against unicode strings, which base64 can't handle.
b64string = _to_bytes(b64string)
padded = b64string + b'=' * (4 - len(b64string) % 4)
return base64.urlsafe_b64decode(padded)

View File

@@ -22,18 +22,18 @@ from oauth2client._helpers import _to_bytes
class OpenSSLVerifier(object):
"""Verifies the signature on a message."""
"""Verifies the signature on a message."""
def __init__(self, pubkey):
"""Constructor.
def __init__(self, pubkey):
"""Constructor.
Args:
pubkey, OpenSSL.crypto.PKey, The public key to verify with.
"""
self._pubkey = pubkey
self._pubkey = pubkey
def verify(self, message, signature):
"""Verifies a message against a signature.
def verify(self, message, signature):
"""Verifies a message against a signature.
Args:
message: string or bytes, The message to verify. If string, will be
@@ -45,17 +45,17 @@ class OpenSSLVerifier(object):
True if message was signed by the private key associated with the public
key that this object was constructed with.
"""
message = _to_bytes(message, encoding='utf-8')
signature = _to_bytes(signature, encoding='utf-8')
try:
crypto.verify(self._pubkey, signature, message, 'sha256')
return True
except crypto.Error:
return False
message = _to_bytes(message, encoding='utf-8')
signature = _to_bytes(signature, encoding='utf-8')
try:
crypto.verify(self._pubkey, signature, message, 'sha256')
return True
except crypto.Error:
return False
@staticmethod
@staticmethod
def from_string(key_pem, is_x509_cert):
"""Construct a Verified instance from a string.
"""Construct a Verified instance from a string.
Args:
key_pem: string, public key in PEM format.
@@ -68,26 +68,26 @@ class OpenSSLVerifier(object):
Raises:
OpenSSL.crypto.Error if the key_pem can't be parsed.
"""
if is_x509_cert:
pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
else:
pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
return OpenSSLVerifier(pubkey)
if is_x509_cert:
pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem)
else:
pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
return OpenSSLVerifier(pubkey)
class OpenSSLSigner(object):
"""Signs messages with a private key."""
"""Signs messages with a private key."""
def __init__(self, pkey):
"""Constructor.
def __init__(self, pkey):
"""Constructor.
Args:
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
"""
self._key = pkey
self._key = pkey
def sign(self, message):
"""Signs a message.
def sign(self, message):
"""Signs a message.
Args:
message: bytes, Message to be signed.
@@ -95,12 +95,12 @@ class OpenSSLSigner(object):
Returns:
string, The signature of the message for the given key.
"""
message = _to_bytes(message, encoding='utf-8')
return crypto.sign(self._key, message, 'sha256')
message = _to_bytes(message, encoding='utf-8')
return crypto.sign(self._key, message, 'sha256')
@staticmethod
def from_string(key, password=b'notasecret'):
"""Construct a Signer instance from a string.
@staticmethod
def from_string(key, password=b'notasecret'):
"""Construct a Signer instance from a string.
Args:
key: string, private key in PKCS12 or PEM format.
@@ -112,17 +112,17 @@ class OpenSSLSigner(object):
Raises:
OpenSSL.crypto.Error if the key can't be parsed.
"""
parsed_pem_key = _parse_pem_key(key)
if parsed_pem_key:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
else:
password = _to_bytes(password, encoding='utf-8')
pkey = crypto.load_pkcs12(key, password).get_privatekey()
return OpenSSLSigner(pkey)
parsed_pem_key = _parse_pem_key(key)
if parsed_pem_key:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
else:
password = _to_bytes(password, encoding='utf-8')
pkey = crypto.load_pkcs12(key, password).get_privatekey()
return OpenSSLSigner(pkey)
def pkcs12_key_as_pem(private_key_text, private_key_password):
"""Convert the contents of a PKCS12 key to PEM using OpenSSL.
"""Convert the contents of a PKCS12 key to PEM using OpenSSL.
Args:
private_key_text: String. Private key.
@@ -131,9 +131,9 @@ def pkcs12_key_as_pem(private_key_text, private_key_password):
Returns:
String. PEM contents of ``private_key_text``.
"""
decoded_body = base64.b64decode(private_key_text)
private_key_password = _to_bytes(private_key_password)
decoded_body = base64.b64decode(private_key_text)
private_key_password = _to_bytes(private_key_password)
pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
return crypto.dump_privatekey(crypto.FILETYPE_PEM,
pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password)
return crypto.dump_privatekey(crypto.FILETYPE_PEM,
pkcs12.get_privatekey())

View File

@@ -25,18 +25,18 @@ from oauth2client._helpers import _urlsafe_b64decode
class PyCryptoVerifier(object):
"""Verifies the signature on a message."""
"""Verifies the signature on a message."""
def __init__(self, pubkey):
"""Constructor.
def __init__(self, pubkey):
"""Constructor.
Args:
pubkey, OpenSSL.crypto.PKey (or equiv), The public key to verify with.
"""
self._pubkey = pubkey
self._pubkey = pubkey
def verify(self, message, signature):
"""Verifies a message against a signature.
def verify(self, message, signature):
"""Verifies a message against a signature.
Args:
message: string or bytes, The message to verify. If string, will be
@@ -47,13 +47,13 @@ class PyCryptoVerifier(object):
True if message was signed by the private key associated with the public
key that this object was constructed with.
"""
message = _to_bytes(message, encoding='utf-8')
return PKCS1_v1_5.new(self._pubkey).verify(
message = _to_bytes(message, encoding='utf-8')
return PKCS1_v1_5.new(self._pubkey).verify(
SHA256.new(message), signature)
@staticmethod
def from_string(key_pem, is_x509_cert):
"""Construct a Verified instance from a string.
@staticmethod
def from_string(key_pem, is_x509_cert):
"""Construct a Verified instance from a string.
Args:
key_pem: string, public key in PEM format.
@@ -63,33 +63,33 @@ class PyCryptoVerifier(object):
Returns:
Verifier instance.
"""
if is_x509_cert:
key_pem = _to_bytes(key_pem)
pemLines = key_pem.replace(b' ', b'').split()
certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
certSeq = DerSequence()
certSeq.decode(certDer)
tbsSeq = DerSequence()
tbsSeq.decode(certSeq[0])
pubkey = RSA.importKey(tbsSeq[6])
else:
pubkey = RSA.importKey(key_pem)
return PyCryptoVerifier(pubkey)
if is_x509_cert:
key_pem = _to_bytes(key_pem)
pemLines = key_pem.replace(b' ', b'').split()
certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1]))
certSeq = DerSequence()
certSeq.decode(certDer)
tbsSeq = DerSequence()
tbsSeq.decode(certSeq[0])
pubkey = RSA.importKey(tbsSeq[6])
else:
pubkey = RSA.importKey(key_pem)
return PyCryptoVerifier(pubkey)
class PyCryptoSigner(object):
"""Signs messages with a private key."""
"""Signs messages with a private key."""
def __init__(self, pkey):
"""Constructor.
def __init__(self, pkey):
"""Constructor.
Args:
pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with.
"""
self._key = pkey
self._key = pkey
def sign(self, message):
"""Signs a message.
def sign(self, message):
"""Signs a message.
Args:
message: string, Message to be signed.
@@ -97,12 +97,12 @@ class PyCryptoSigner(object):
Returns:
string, The signature of the message for the given key.
"""
message = _to_bytes(message, encoding='utf-8')
return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
message = _to_bytes(message, encoding='utf-8')
return PKCS1_v1_5.new(self._key).sign(SHA256.new(message))
@staticmethod
def from_string(key, password='notasecret'):
"""Construct a Signer instance from a string.
@staticmethod
def from_string(key, password='notasecret'):
"""Construct a Signer instance from a string.
Args:
key: string, private key in PEM format.
@@ -114,13 +114,13 @@ class PyCryptoSigner(object):
Raises:
NotImplementedError if the key isn't in PEM format.
"""
parsed_pem_key = _parse_pem_key(key)
if parsed_pem_key:
pkey = RSA.importKey(parsed_pem_key)
else:
raise NotImplementedError(
parsed_pem_key = _parse_pem_key(key)
if parsed_pem_key:
pkey = RSA.importKey(parsed_pem_key)
else:
raise NotImplementedError(
'PKCS12 format is not supported by the PyCrypto library. '
'Try converting to a "PEM" '
'(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) '
'or using PyOpenSSL if native code is an option.')
return PyCryptoSigner(pkey)
return PyCryptoSigner(pkey)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,6 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import json
import six
# Properties that make a client_secrets.json file valid.
TYPE_WEB = 'web'
TYPE_INSTALLED = 'installed'
@@ -59,65 +58,65 @@ VALID_CLIENT = {
class Error(Exception):
"""Base error for this module."""
pass
"""Base error for this module."""
pass
class InvalidClientSecretsError(Error):
"""Format of ClientSecrets file is invalid."""
pass
"""Format of ClientSecrets file is invalid."""
pass
def _validate_clientsecrets(obj):
_INVALID_FILE_FORMAT_MSG = (
_INVALID_FILE_FORMAT_MSG = (
'Invalid file format. See '
'https://developers.google.com/api-client-library/'
'python/guide/aaa_client_secrets')
if obj is None:
raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
if len(obj) != 1:
raise InvalidClientSecretsError(
if obj is None:
raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG)
if len(obj) != 1:
raise InvalidClientSecretsError(
_INVALID_FILE_FORMAT_MSG + ' '
'Expected a JSON object with a single property for a "web" or '
'"installed" application')
client_type = tuple(obj)[0]
if client_type not in VALID_CLIENT:
raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type,))
client_info = obj[client_type]
for prop_name in VALID_CLIENT[client_type]['required']:
if prop_name not in client_info:
raise InvalidClientSecretsError(
client_type = tuple(obj)[0]
if client_type not in VALID_CLIENT:
raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type, ))
client_info = obj[client_type]
for prop_name in VALID_CLIENT[client_type]['required']:
if prop_name not in client_info:
raise InvalidClientSecretsError(
'Missing property "%s" in a client type of "%s".' % (prop_name,
client_type))
for prop_name in VALID_CLIENT[client_type]['string']:
if client_info[prop_name].startswith('[['):
raise InvalidClientSecretsError(
for prop_name in VALID_CLIENT[client_type]['string']:
if client_info[prop_name].startswith('[['):
raise InvalidClientSecretsError(
'Property "%s" is not configured.' % prop_name)
return client_type, client_info
return client_type, client_info
def load(fp):
obj = json.load(fp)
return _validate_clientsecrets(obj)
obj = json.load(fp)
return _validate_clientsecrets(obj)
def loads(s):
obj = json.loads(s)
return _validate_clientsecrets(obj)
obj = json.loads(s)
return _validate_clientsecrets(obj)
def _loadfile(filename):
try:
with open(filename, 'r') as fp:
obj = json.load(fp)
except IOError:
raise InvalidClientSecretsError('File not found: "%s"' % filename)
return _validate_clientsecrets(obj)
try:
with open(filename, 'r') as fp:
obj = json.load(fp)
except IOError:
raise InvalidClientSecretsError('File not found: "%s"' % filename)
return _validate_clientsecrets(obj)
def loadfile(filename, cache=None):
"""Loading of client_secrets JSON file, optionally backed by a cache.
"""Loading of client_secrets JSON file, optionally backed by a cache.
Typical cache storage would be App Engine memcache service,
but you can pass in any other cache client that implements
@@ -149,15 +148,15 @@ def loadfile(filename, cache=None):
JSON contents is validated only during first load. Cache hits are not
validated.
"""
_SECRET_NAMESPACE = 'oauth2client:secrets#ns'
_SECRET_NAMESPACE = 'oauth2client:secrets#ns'
if not cache:
return _loadfile(filename)
if not cache:
return _loadfile(filename)
obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
if obj is None:
client_type, client_info = _loadfile(filename)
obj = {client_type: client_info}
cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
obj = cache.get(filename, namespace=_SECRET_NAMESPACE)
if obj is None:
client_type, client_info = _loadfile(filename)
obj = {client_type: client_info}
cache.set(filename, obj, namespace=_SECRET_NAMESPACE)
return next(six.iteritems(obj))
return next(six.iteritems(obj))

View File

@@ -25,51 +25,51 @@ from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
from oauth2client._helpers import _urlsafe_b64encode
CLOCK_SKEW_SECS = 300 # 5 minutes in seconds
AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds
MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds
logger = logging.getLogger(__name__)
class AppIdentityError(Exception):
pass
pass
try:
from oauth2client._openssl_crypt import OpenSSLVerifier
from oauth2client._openssl_crypt import OpenSSLSigner
from oauth2client._openssl_crypt import pkcs12_key_as_pem
from oauth2client._openssl_crypt import OpenSSLVerifier
from oauth2client._openssl_crypt import OpenSSLSigner
from oauth2client._openssl_crypt import pkcs12_key_as_pem
except ImportError:
OpenSSLVerifier = None
OpenSSLSigner = None
def pkcs12_key_as_pem(*args, **kwargs):
raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
OpenSSLVerifier = None
OpenSSLSigner = None
def pkcs12_key_as_pem(*args, **kwargs):
raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.')
try:
from oauth2client._pycrypto_crypt import PyCryptoVerifier
from oauth2client._pycrypto_crypt import PyCryptoSigner
from oauth2client._pycrypto_crypt import PyCryptoVerifier
from oauth2client._pycrypto_crypt import PyCryptoSigner
except ImportError:
PyCryptoVerifier = None
PyCryptoSigner = None
PyCryptoVerifier = None
PyCryptoSigner = None
if OpenSSLSigner:
Signer = OpenSSLSigner
Verifier = OpenSSLVerifier
Signer = OpenSSLSigner
Verifier = OpenSSLVerifier
elif PyCryptoSigner:
Signer = PyCryptoSigner
Verifier = PyCryptoVerifier
Signer = PyCryptoSigner
Verifier = PyCryptoVerifier
else:
raise ImportError('No encryption library found. Please install either '
raise ImportError('No encryption library found. Please install either '
'PyOpenSSL, or PyCrypto 2.6 or later')
def make_signed_jwt(signer, payload):
"""Make a signed JWT.
"""Make a signed JWT.
See http://self-issued.info/docs/draft-jones-json-web-token.html.
@@ -80,24 +80,24 @@ def make_signed_jwt(signer, payload):
Returns:
string, The JWT for the payload.
"""
header = {'typ': 'JWT', 'alg': 'RS256'}
header = {'typ': 'JWT', 'alg': 'RS256'}
segments = [
segments = [
_urlsafe_b64encode(_json_encode(header)),
_urlsafe_b64encode(_json_encode(payload)),
]
signing_input = b'.'.join(segments)
]
signing_input = b'.'.join(segments)
signature = signer.sign(signing_input)
segments.append(_urlsafe_b64encode(signature))
signature = signer.sign(signing_input)
segments.append(_urlsafe_b64encode(signature))
logger.debug(str(segments))
logger.debug(str(segments))
return b'.'.join(segments)
return b'.'.join(segments)
def verify_signed_jwt_with_certs(jwt, certs, audience):
"""Verify a JWT against public certs.
"""Verify a JWT against public certs.
See http://self-issued.info/docs/draft-jones-json-web-token.html.
@@ -113,61 +113,61 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
Raises:
AppIdentityError if any checks are failed.
"""
jwt = _to_bytes(jwt)
segments = jwt.split(b'.')
jwt = _to_bytes(jwt)
segments = jwt.split(b'.')
if len(segments) != 3:
raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
signed = segments[0] + b'.' + segments[1]
if len(segments) != 3:
raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
signed = segments[0] + b'.' + segments[1]
signature = _urlsafe_b64decode(segments[2])
signature = _urlsafe_b64decode(segments[2])
# Parse token.
json_body = _urlsafe_b64decode(segments[1])
try:
parsed = json.loads(_from_bytes(json_body))
except:
raise AppIdentityError('Can\'t parse token: %s' % json_body)
# Parse token.
json_body = _urlsafe_b64decode(segments[1])
try:
parsed = json.loads(_from_bytes(json_body))
except:
raise AppIdentityError('Can\'t parse token: %s' % json_body)
# Check signature.
verified = False
for pem in certs.values():
verifier = Verifier.from_string(pem, True)
if verifier.verify(signed, signature):
verified = True
break
if not verified:
raise AppIdentityError('Invalid token signature: %s' % jwt)
# Check signature.
verified = False
for pem in certs.values():
verifier = Verifier.from_string(pem, True)
if verifier.verify(signed, signature):
verified = True
break
if not verified:
raise AppIdentityError('Invalid token signature: %s' % jwt)
# Check creation timestamp.
iat = parsed.get('iat')
if iat is None:
raise AppIdentityError('No iat field in token: %s' % json_body)
earliest = iat - CLOCK_SKEW_SECS
# Check creation timestamp.
iat = parsed.get('iat')
if iat is None:
raise AppIdentityError('No iat field in token: %s' % json_body)
earliest = iat - CLOCK_SKEW_SECS
# Check expiration timestamp.
now = int(time.time())
exp = parsed.get('exp')
if exp is None:
raise AppIdentityError('No exp field in token: %s' % json_body)
if exp >= now + MAX_TOKEN_LIFETIME_SECS:
raise AppIdentityError('exp field too far in future: %s' % json_body)
latest = exp + CLOCK_SKEW_SECS
# Check expiration timestamp.
now = int(time.time())
exp = parsed.get('exp')
if exp is None:
raise AppIdentityError('No exp field in token: %s' % json_body)
if exp >= now + MAX_TOKEN_LIFETIME_SECS:
raise AppIdentityError('exp field too far in future: %s' % json_body)
latest = exp + CLOCK_SKEW_SECS
if now < earliest:
raise AppIdentityError('Token used too early, %d < %d: %s' %
if now < earliest:
raise AppIdentityError('Token used too early, %d < %d: %s' %
(now, earliest, json_body))
if now > latest:
raise AppIdentityError('Token used too late, %d > %d: %s' %
if now > latest:
raise AppIdentityError('Token used too late, %d > %d: %s' %
(now, latest, json_body))
# Check audience.
if audience is not None:
aud = parsed.get('aud')
if aud is None:
raise AppIdentityError('No aud field in token: %s' % json_body)
if aud != audience:
raise AppIdentityError('Wrong recipient, %s != %s: %s' %
# Check audience.
if audience is not None:
aud = parsed.get('aud')
if aud is None:
raise AppIdentityError('No aud field in token: %s' % json_body)
if aud != audience:
raise AppIdentityError('Wrong recipient, %s != %s: %s' %
(aud, audience, json_body))
return parsed
return parsed

View File

@@ -20,22 +20,20 @@ import os
from oauth2client._helpers import _to_bytes
from oauth2client import client
DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT'
class Error(Exception):
"""Errors for this module."""
pass
"""Errors for this module."""
pass
class CommunicationError(Error):
"""Errors for communication with the Developer Shell server."""
"""Errors for communication with the Developer Shell server."""
class NoDevshellServer(Error):
"""Error when no Developer Shell server can be contacted."""
"""Error when no Developer Shell server can be contacted."""
# The request for credential information to the Developer Shell client socket is
# always an empty PBLite-formatted JSON object, so just define it as a constant.
@@ -43,7 +41,7 @@ CREDENTIAL_INFO_REQUEST_JSON = '[]'
class CredentialInfoResponse(object):
"""Credential information response from Developer Shell server.
"""Credential information response from Developer Shell server.
The credential information response from Developer Shell socket is a
PBLite-formatted JSON array with fields encoded by their index in the array:
@@ -52,46 +50,46 @@ class CredentialInfoResponse(object):
* Index 2 - OAuth2 access token. None if there is no valid auth context.
"""
def __init__(self, json_string):
"""Initialize the response data from JSON PBLite array."""
pbl = json.loads(json_string)
if not isinstance(pbl, list):
raise ValueError('Not a list: ' + str(pbl))
pbl_len = len(pbl)
self.user_email = pbl[0] if pbl_len > 0 else None
self.project_id = pbl[1] if pbl_len > 1 else None
self.access_token = pbl[2] if pbl_len > 2 else None
def __init__(self, json_string):
"""Initialize the response data from JSON PBLite array."""
pbl = json.loads(json_string)
if not isinstance(pbl, list):
raise ValueError('Not a list: ' + str(pbl))
pbl_len = len(pbl)
self.user_email = pbl[0] if pbl_len > 0 else None
self.project_id = pbl[1] if pbl_len > 1 else None
self.access_token = pbl[2] if pbl_len > 2 else None
def _SendRecv():
"""Communicate with the Developer Shell server socket."""
"""Communicate with the Developer Shell server socket."""
port = int(os.getenv(DEVSHELL_ENV, 0))
if port == 0:
raise NoDevshellServer()
port = int(os.getenv(DEVSHELL_ENV, 0))
if port == 0:
raise NoDevshellServer()
import socket
import socket
sock = socket.socket()
sock.connect(('localhost', port))
sock = socket.socket()
sock.connect(('localhost', port))
data = CREDENTIAL_INFO_REQUEST_JSON
msg = '%s\n%s' % (len(data), data)
sock.sendall(_to_bytes(msg, encoding='utf-8'))
data = CREDENTIAL_INFO_REQUEST_JSON
msg = '%s\n%s' % (len(data), data)
sock.sendall(_to_bytes(msg, encoding='utf-8'))
header = sock.recv(6).decode()
if '\n' not in header:
raise CommunicationError('saw no newline in the first 6 bytes')
len_str, json_str = header.split('\n', 1)
to_read = int(len_str) - len(json_str)
if to_read > 0:
json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
header = sock.recv(6).decode()
if '\n' not in header:
raise CommunicationError('saw no newline in the first 6 bytes')
len_str, json_str = header.split('\n', 1)
to_read = int(len_str) - len(json_str)
if to_read > 0:
json_str += sock.recv(to_read, socket.MSG_WAITALL).decode()
return CredentialInfoResponse(json_str)
return CredentialInfoResponse(json_str)
class DevshellCredentials(client.GoogleCredentials):
"""Credentials object for Google Developer Shell environment.
"""Credentials object for Google Developer Shell environment.
This object will allow a Google Developer Shell session to identify its user
to Google and other OAuth 2.0 servers that can verify assertions. It can be
@@ -102,8 +100,8 @@ class DevshellCredentials(client.GoogleCredentials):
generate and refresh its own access tokens.
"""
def __init__(self, user_agent=None):
super(DevshellCredentials, self).__init__(
def __init__(self, user_agent=None):
super(DevshellCredentials, self).__init__(
None, # access_token, initialized below
None, # client_id
None, # client_secret
@@ -111,26 +109,26 @@ class DevshellCredentials(client.GoogleCredentials):
None, # token_expiry
None, # token_uri
user_agent)
self._refresh(None)
self._refresh(None)
def _refresh(self, http_request):
self.devshell_response = _SendRecv()
self.access_token = self.devshell_response.access_token
def _refresh(self, http_request):
self.devshell_response = _SendRecv()
self.access_token = self.devshell_response.access_token
@property
def user_email(self):
return self.devshell_response.user_email
@property
def user_email(self):
return self.devshell_response.user_email
@property
def project_id(self):
return self.devshell_response.project_id
@property
def project_id(self):
return self.devshell_response.project_id
@classmethod
def from_json(cls, json_data):
raise NotImplementedError(
@classmethod
def from_json(cls, json_data):
raise NotImplementedError(
'Cannot load Developer Shell credentials from JSON.')
@property
def serialization_data(self):
raise NotImplementedError(
@property
def serialization_data(self):
raise NotImplementedError(
'Cannot serialize Developer Shell credentials.')

View File

@@ -27,58 +27,59 @@ import pickle
from django.db import models
from oauth2client.client import Storage as BaseStorage
class CredentialsField(models.Field):
__metaclass__ = models.SubfieldBase
__metaclass__ = models.SubfieldBase
def __init__(self, *args, **kwargs):
if 'null' not in kwargs:
kwargs['null'] = True
super(CredentialsField, self).__init__(*args, **kwargs)
def __init__(self, *args, **kwargs):
if 'null' not in kwargs:
kwargs['null'] = True
super(CredentialsField, self).__init__(*args, **kwargs)
def get_internal_type(self):
return "TextField"
def get_internal_type(self):
return "TextField"
def to_python(self, value):
if value is None:
return None
if isinstance(value, oauth2client.client.Credentials):
return value
return pickle.loads(base64.b64decode(value))
def to_python(self, value):
if value is None:
return None
if isinstance(value, oauth2client.client.Credentials):
return value
return pickle.loads(base64.b64decode(value))
def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
return None
return base64.b64encode(pickle.dumps(value))
def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
return None
return base64.b64encode(pickle.dumps(value))
class FlowField(models.Field):
__metaclass__ = models.SubfieldBase
__metaclass__ = models.SubfieldBase
def __init__(self, *args, **kwargs):
if 'null' not in kwargs:
kwargs['null'] = True
super(FlowField, self).__init__(*args, **kwargs)
def __init__(self, *args, **kwargs):
if 'null' not in kwargs:
kwargs['null'] = True
super(FlowField, self).__init__(*args, **kwargs)
def get_internal_type(self):
return "TextField"
def get_internal_type(self):
return "TextField"
def to_python(self, value):
if value is None:
return None
if isinstance(value, oauth2client.client.Flow):
return value
return pickle.loads(base64.b64decode(value))
def to_python(self, value):
if value is None:
return None
if isinstance(value, oauth2client.client.Flow):
return value
return pickle.loads(base64.b64decode(value))
def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
return None
return base64.b64encode(pickle.dumps(value))
def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
return None
return base64.b64encode(pickle.dumps(value))
class Storage(BaseStorage):
"""Store and retrieve a single credential to and from
"""Store and retrieve a single credential to and from
the datastore.
This Storage helper presumes the Credentials
@@ -86,8 +87,8 @@ class Storage(BaseStorage):
on a db model class.
"""
def __init__(self, model_class, key_name, key_value, property_name):
"""Constructor for Storage.
def __init__(self, model_class, key_name, key_value, property_name):
"""Constructor for Storage.
Args:
model: db.Model, model class
@@ -95,47 +96,47 @@ class Storage(BaseStorage):
key_value: string, key value for the entity that has the credentials
property_name: string, name of the property that is an CredentialsProperty
"""
self.model_class = model_class
self.key_name = key_name
self.key_value = key_value
self.property_name = property_name
self.model_class = model_class
self.key_name = key_name
self.key_value = key_value
self.property_name = property_name
def locked_get(self):
"""Retrieve Credential from datastore.
def locked_get(self):
"""Retrieve Credential from datastore.
Returns:
oauth2client.Credentials
"""
credential = None
credential = None
query = {self.key_name: self.key_value}
entities = self.model_class.objects.filter(**query)
if len(entities) > 0:
credential = getattr(entities[0], self.property_name)
if credential and hasattr(credential, 'set_store'):
credential.set_store(self)
return credential
query = {self.key_name: self.key_value}
entities = self.model_class.objects.filter(**query)
if len(entities) > 0:
credential = getattr(entities[0], self.property_name)
if credential and hasattr(credential, 'set_store'):
credential.set_store(self)
return credential
def locked_put(self, credentials, overwrite=False):
"""Write a Credentials to the datastore.
def locked_put(self, credentials, overwrite=False):
"""Write a Credentials to the datastore.
Args:
credentials: Credentials, the credentials to store.
overwrite: Boolean, indicates whether you would like these credentials to
overwrite any existing stored credentials.
"""
args = {self.key_name: self.key_value}
args = {self.key_name: self.key_value}
if overwrite:
entity, unused_is_new = self.model_class.objects.get_or_create(**args)
else:
entity = self.model_class(**args)
if overwrite:
entity, unused_is_new = self.model_class.objects.get_or_create(**args)
else:
entity = self.model_class(**args)
setattr(entity, self.property_name, credentials)
entity.save()
setattr(entity, self.property_name, credentials)
entity.save()
def locked_delete(self):
"""Delete Credentials from the datastore."""
def locked_delete(self):
"""Delete Credentials from the datastore."""
query = {self.key_name: self.key_value}
entities = self.model_class.objects.filter(**query).delete()
query = {self.key_name: self.key_value}
entities = self.model_class.objects.filter(**query).delete()

View File

@@ -28,37 +28,37 @@ from oauth2client.client import Storage as BaseStorage
class CredentialsFileSymbolicLinkError(Exception):
"""Credentials files must not be symbolic links."""
"""Credentials files must not be symbolic links."""
class Storage(BaseStorage):
"""Store and retrieve a single credential to and from a file."""
"""Store and retrieve a single credential to and from a file."""
def __init__(self, filename):
self._filename = filename
self._lock = threading.Lock()
def __init__(self, filename):
self._filename = filename
self._lock = threading.Lock()
def _validate_file(self):
if os.path.islink(self._filename):
raise CredentialsFileSymbolicLinkError(
def _validate_file(self):
if os.path.islink(self._filename):
raise CredentialsFileSymbolicLinkError(
'File: %s is a symbolic link.' % self._filename)
def acquire_lock(self):
"""Acquires any lock necessary to access this Storage.
def acquire_lock(self):
"""Acquires any lock necessary to access this Storage.
This lock is not reentrant."""
self._lock.acquire()
self._lock.acquire()
def release_lock(self):
"""Release the Storage lock.
def release_lock(self):
"""Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
self._lock.release()
self._lock.release()
def locked_get(self):
"""Retrieve Credential from file.
def locked_get(self):
"""Retrieve Credential from file.
Returns:
oauth2client.client.Credentials
@@ -66,38 +66,38 @@ class Storage(BaseStorage):
Raises:
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
credentials = None
self._validate_file()
try:
f = open(self._filename, 'rb')
content = f.read()
f.close()
except IOError:
return credentials
credentials = None
self._validate_file()
try:
f = open(self._filename, 'rb')
content = f.read()
f.close()
except IOError:
return credentials
try:
credentials = Credentials.new_from_json(content)
credentials.set_store(self)
except ValueError:
pass
try:
credentials = Credentials.new_from_json(content)
credentials.set_store(self)
except ValueError:
pass
return credentials
return credentials
def _create_file_if_needed(self):
"""Create an empty file if necessary.
def _create_file_if_needed(self):
"""Create an empty file if necessary.
This method will not initialize the file. Instead it implements a
simple version of "touch" to ensure the file has been created.
"""
if not os.path.exists(self._filename):
old_umask = os.umask(0o177)
try:
open(self._filename, 'a+b').close()
finally:
os.umask(old_umask)
if not os.path.exists(self._filename):
old_umask = os.umask(0o177)
try:
open(self._filename, 'a+b').close()
finally:
os.umask(old_umask)
def locked_put(self, credentials):
"""Write Credentials to file.
def locked_put(self, credentials):
"""Write Credentials to file.
Args:
credentials: Credentials, the credentials to store.
@@ -106,17 +106,17 @@ class Storage(BaseStorage):
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
self._create_file_if_needed()
self._validate_file()
f = open(self._filename, 'w')
f.write(credentials.to_json())
f.close()
self._create_file_if_needed()
self._validate_file()
f = open(self._filename, 'w')
f.write(credentials.to_json())
f.close()
def locked_delete(self):
"""Delete Credentials file.
def locked_delete(self):
"""Delete Credentials file.
Args:
credentials: Credentials, the credentials to store.
"""
os.unlink(self._filename)
os.unlink(self._filename)

View File

@@ -189,8 +189,7 @@ from oauth2client.client import Storage
from oauth2client import clientsecrets
from oauth2client import util
DEFAULT_SCOPES = ('email',)
DEFAULT_SCOPES = ('email', )
class UserOAuth2(object):
@@ -461,6 +460,7 @@ class UserOAuth2(object):
be redirected to the authorization flow. Once complete, the user will
be redirected back to the original page.
"""
def curry_wrapper(wrapped_function):
@wraps(wrapped_function)
def required_wrapper(*args, **kwargs):
@@ -519,6 +519,7 @@ class FlaskSessionStorage(Storage):
credentials. We strongly recommend using a server-side session
implementation.
"""
def locked_get(self):
serialized = session.get('google_oauth2_credentials')

View File

@@ -36,7 +36,7 @@ META = ('http://metadata.google.internal/0.1/meta-data/service-accounts/'
class AppAssertionCredentials(AssertionCredentials):
"""Credentials object for Compute Engine Assertion Grants
"""Credentials object for Compute Engine Assertion Grants
This object will allow a Compute Engine instance to identify itself to
Google and other OAuth 2.0 servers that can verify assertions. It can be used
@@ -48,27 +48,27 @@ class AppAssertionCredentials(AssertionCredentials):
generate and refresh its own access tokens.
"""
@util.positional(2)
def __init__(self, scope, **kwargs):
"""Constructor for AppAssertionCredentials
@util.positional(2)
def __init__(self, scope, **kwargs):
"""Constructor for AppAssertionCredentials
Args:
scope: string or iterable of strings, scope(s) of the credentials being
requested.
"""
self.scope = util.scopes_to_string(scope)
self.kwargs = kwargs
self.scope = util.scopes_to_string(scope)
self.kwargs = kwargs
# Assertion type is no longer used, but still in the parent class signature.
super(AppAssertionCredentials, self).__init__(None)
# Assertion type is no longer used, but still in the parent class signature.
super(AppAssertionCredentials, self).__init__(None)
@classmethod
def from_json(cls, json_data):
data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope'])
@classmethod
def from_json(cls, json_data):
data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope'])
def _refresh(self, http_request):
"""Refreshes the access_token.
def _refresh(self, http_request):
"""Refreshes the access_token.
Skip all the storage hoops and just refresh using the API.
@@ -79,29 +79,29 @@ class AppAssertionCredentials(AssertionCredentials):
Raises:
AccessTokenRefreshError: When the refresh fails.
"""
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query)
response, content = http_request(uri)
content = _from_bytes(content)
if response.status == 200:
try:
d = json.loads(content)
except Exception as e:
raise AccessTokenRefreshError(str(e))
self.access_token = d['accessToken']
else:
if response.status == 404:
content += (' This can occur if a VM was created'
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query)
response, content = http_request(uri)
content = _from_bytes(content)
if response.status == 200:
try:
d = json.loads(content)
except Exception as e:
raise AccessTokenRefreshError(str(e))
self.access_token = d['accessToken']
else:
if response.status == 404:
content += (' This can occur if a VM was created'
' with no service account or scopes.')
raise AccessTokenRefreshError(content)
raise AccessTokenRefreshError(content)
@property
@property
def serialization_data(self):
raise NotImplementedError(
raise NotImplementedError(
'Cannot serialize credentials for GCE service accounts.')
def create_scoped_required(self):
return not self.scope
def create_scoped_required(self):
return not self.scope
def create_scoped(self, scopes):
return AppAssertionCredentials(scopes, **self.kwargs)
def create_scoped(self, scopes):
return AppAssertionCredentials(scopes, **self.kwargs)

View File

@@ -28,7 +28,7 @@ from oauth2client.client import Storage as BaseStorage
class Storage(BaseStorage):
"""Store and retrieve a single credential to and from the keyring.
"""Store and retrieve a single credential to and from the keyring.
To use this module you must have the keyring module installed. See
<http://pypi.python.org/pypi/keyring/>. This is an optional module and is not
@@ -48,63 +48,63 @@ class Storage(BaseStorage):
"""
def __init__(self, service_name, user_name):
"""Constructor.
def __init__(self, service_name, user_name):
"""Constructor.
Args:
service_name: string, The name of the service under which the credentials
are stored.
user_name: string, The name of the user to store credentials for.
"""
self._service_name = service_name
self._user_name = user_name
self._lock = threading.Lock()
self._service_name = service_name
self._user_name = user_name
self._lock = threading.Lock()
def acquire_lock(self):
"""Acquires any lock necessary to access this Storage.
def acquire_lock(self):
"""Acquires any lock necessary to access this Storage.
This lock is not reentrant."""
self._lock.acquire()
self._lock.acquire()
def release_lock(self):
"""Release the Storage lock.
def release_lock(self):
"""Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
self._lock.release()
self._lock.release()
def locked_get(self):
"""Retrieve Credential from file.
def locked_get(self):
"""Retrieve Credential from file.
Returns:
oauth2client.client.Credentials
"""
credentials = None
content = keyring.get_password(self._service_name, self._user_name)
credentials = None
content = keyring.get_password(self._service_name, self._user_name)
if content is not None:
try:
credentials = Credentials.new_from_json(content)
credentials.set_store(self)
except ValueError:
pass
if content is not None:
try:
credentials = Credentials.new_from_json(content)
credentials.set_store(self)
except ValueError:
pass
return credentials
return credentials
def locked_put(self, credentials):
"""Write Credentials to file.
def locked_put(self, credentials):
"""Write Credentials to file.
Args:
credentials: Credentials, the credentials to store.
"""
keyring.set_password(self._service_name, self._user_name,
keyring.set_password(self._service_name, self._user_name,
credentials.to_json())
def locked_delete(self):
"""Delete Credentials file.
def locked_delete(self):
"""Delete Credentials file.
Args:
credentials: Credentials, the credentials to store.
"""
keyring.set_password(self._service_name, self._user_name, '')
keyring.set_password(self._service_name, self._user_name, '')

View File

@@ -45,68 +45,69 @@ logger = logging.getLogger(__name__)
class CredentialsFileSymbolicLinkError(Exception):
"""Credentials files must not be symbolic links."""
"""Credentials files must not be symbolic links."""
class AlreadyLockedException(Exception):
"""Trying to lock a file that has already been locked by the LockedFile."""
pass
"""Trying to lock a file that has already been locked by the LockedFile."""
pass
def validate_file(filename):
if os.path.islink(filename):
raise CredentialsFileSymbolicLinkError(
if os.path.islink(filename):
raise CredentialsFileSymbolicLinkError(
'File: %s is a symbolic link.' % filename)
class _Opener(object):
"""Base class for different locking primitives."""
def __init__(self, filename, mode, fallback_mode):
"""Create an Opener.
class _Opener(object):
"""Base class for different locking primitives."""
def __init__(self, filename, mode, fallback_mode):
"""Create an Opener.
Args:
filename: string, The pathname of the file.
mode: string, The preferred mode to access the file with.
fallback_mode: string, The mode to use if locking fails.
"""
self._locked = False
self._filename = filename
self._mode = mode
self._fallback_mode = fallback_mode
self._fh = None
self._lock_fd = None
self._locked = False
self._filename = filename
self._mode = mode
self._fallback_mode = fallback_mode
self._fh = None
self._lock_fd = None
def is_locked(self):
"""Was the file locked."""
return self._locked
def is_locked(self):
"""Was the file locked."""
return self._locked
def file_handle(self):
"""The file handle to the file. Valid only after opened."""
return self._fh
def file_handle(self):
"""The file handle to the file. Valid only after opened."""
return self._fh
def filename(self):
"""The filename that is being locked."""
return self._filename
def filename(self):
"""The filename that is being locked."""
return self._filename
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
delay: float, How long to wait between retries.
"""
pass
pass
def unlock_and_close(self):
"""Unlock and close the file."""
pass
def unlock_and_close(self):
"""Unlock and close the file."""
pass
class _PosixOpener(_Opener):
"""Lock files using Posix advisory lock files."""
"""Lock files using Posix advisory lock files."""
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Tries to create a .lock file next to the file we're trying to open.
@@ -119,141 +120,67 @@ class _PosixOpener(_Opener):
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
self._filename)
self._locked = False
self._locked = False
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES:
self._fh = open(self._filename, self._fallback_mode)
return
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES:
self._fh = open(self._filename, self._fallback_mode)
return
lock_filename = self._posix_lockfile(self._filename)
lock_filename = self._posix_lockfile(self._filename)
start_time = time.time()
while True:
try:
self._lock_fd = os.open(lock_filename,
os.O_CREAT|os.O_EXCL|os.O_RDWR)
self._locked = True
break
try:
self._lock_fd = os.open(lock_filename,
os.O_CREAT | os.O_EXCL | os.O_RDWR)
self._locked = True
break
except OSError as e:
if e.errno != errno.EEXIST:
raise
if (time.time() - start_time) >= timeout:
logger.warn('Could not acquire lock %s in %s seconds',
except OSError as e:
if e.errno != errno.EEXIST:
raise
if (time.time() - start_time) >= timeout:
logger.warn('Could not acquire lock %s in %s seconds',
lock_filename, timeout)
# Close the file and open in fallback_mode.
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
def unlock_and_close(self):
"""Unlock a file by removing the .lock file, and close the handle."""
if self._locked:
lock_filename = self._posix_lockfile(self._filename)
os.close(self._lock_fd)
os.unlink(lock_filename)
self._locked = False
self._lock_fd = None
if self._fh:
self._fh.close()
def _posix_lockfile(self, filename):
"""The name of the lock file to use for posix locking."""
return '%s.lock' % filename
try:
import fcntl
class _FcntlOpener(_Opener):
"""Open, lock, and unlock a file using fcntl.lockf."""
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
delay: float, How long to wait between retries
Raises:
AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
self._filename)
start_time = time.time()
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno in (errno.EPERM, errno.EACCES):
self._fh = open(self._filename, self._fallback_mode)
return
# We opened in _mode, try to lock the file.
while True:
try:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
self._locked = True
return
except IOError as e:
# If not retrying, then just pass on the error.
if timeout == 0:
raise
if e.errno != errno.EACCES:
raise
# We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds',
self._filename, timeout)
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
# Close the file and open in fallback_mode.
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
def unlock_and_close(self):
"""Close and unlock the file using the fcntl.lockf primitive."""
if self._locked:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
self._locked = False
if self._fh:
self._fh.close()
except ImportError:
_FcntlOpener = None
"""Unlock a file by removing the .lock file, and close the handle."""
if self._locked:
lock_filename = self._posix_lockfile(self._filename)
os.close(self._lock_fd)
os.unlink(lock_filename)
self._locked = False
self._lock_fd = None
if self._fh:
self._fh.close()
def _posix_lockfile(self, filename):
"""The name of the lock file to use for posix locking."""
return '%s.lock' % filename
try:
import pywintypes
import win32con
import win32file
import fcntl
class _Win32Opener(_Opener):
"""Open, lock, and unlock a file using windows primitives."""
# Error #33:
# 'The process cannot access the file because another process'
FILE_IN_USE_ERROR = 33
class _FcntlOpener(_Opener):
"""Open, lock, and unlock a file using fcntl.lockf."""
# Error #158:
# 'The segment is already unlocked.'
FILE_ALREADY_UNLOCKED_ERROR = 158
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
@@ -264,71 +191,147 @@ try:
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
self._filename)
start_time = time.time()
start_time = time.time()
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES:
self._fh = open(self._filename, self._fallback_mode)
return
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno in (errno.EPERM, errno.EACCES):
self._fh = open(self._filename, self._fallback_mode)
return
# We opened in _mode, try to lock the file.
while True:
try:
hfile = win32file._get_osfhandle(self._fh.fileno())
win32file.LockFileEx(
# We opened in _mode, try to lock the file.
while True:
try:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX)
self._locked = True
return
except IOError as e:
# If not retrying, then just pass on the error.
if timeout == 0:
raise
if e.errno != errno.EACCES:
raise
# We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds',
self._filename, timeout)
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
def unlock_and_close(self):
"""Close and unlock the file using the fcntl.lockf primitive."""
if self._locked:
fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN)
self._locked = False
if self._fh:
self._fh.close()
except ImportError:
_FcntlOpener = None
try:
import pywintypes
import win32con
import win32file
class _Win32Opener(_Opener):
"""Open, lock, and unlock a file using windows primitives."""
# Error #33:
# 'The process cannot access the file because another process'
FILE_IN_USE_ERROR = 33
# Error #158:
# 'The segment is already unlocked.'
FILE_ALREADY_UNLOCKED_ERROR = 158
def open_and_lock(self, timeout, delay):
"""Open the file and lock it.
Args:
timeout: float, How long to try to lock for.
delay: float, How long to wait between retries
Raises:
AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails.
CredentialsFileSymbolicLinkError if the file is a symbolic link.
"""
if self._locked:
raise AlreadyLockedException('File %s is already locked' %
self._filename)
start_time = time.time()
validate_file(self._filename)
try:
self._fh = open(self._filename, self._mode)
except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno == errno.EACCES:
self._fh = open(self._filename, self._fallback_mode)
return
# We opened in _mode, try to lock the file.
while True:
try:
hfile = win32file._get_osfhandle(self._fh.fileno())
win32file.LockFileEx(
hfile,
(win32con.LOCKFILE_FAIL_IMMEDIATELY|
(win32con.LOCKFILE_FAIL_IMMEDIATELY |
win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000,
pywintypes.OVERLAPPED())
self._locked = True
return
except pywintypes.error as e:
if timeout == 0:
raise
self._locked = True
return
except pywintypes.error as e:
if timeout == 0:
raise
# If the error is not that the file is already in use, raise.
if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
raise
# If the error is not that the file is already in use, raise.
if e[0] != _Win32Opener.FILE_IN_USE_ERROR:
raise
# We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds' % (
# We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds' % (
self._filename, timeout))
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
if self._fh:
self._fh.close()
self._fh = open(self._filename, self._fallback_mode)
return
time.sleep(delay)
def unlock_and_close(self):
"""Close and unlock the file using the win32 primitive."""
if self._locked:
try:
hfile = win32file._get_osfhandle(self._fh.fileno())
win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
except pywintypes.error as e:
if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
raise
self._locked = False
def unlock_and_close(self):
"""Close and unlock the file using the win32 primitive."""
if self._locked:
try:
hfile = win32file._get_osfhandle(self._fh.fileno())
win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED())
except pywintypes.error as e:
if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR:
raise
self._locked = False
if self._fh:
self._fh.close()
self._fh.close()
except ImportError:
_Win32Opener = None
_Win32Opener = None
class LockedFile(object):
"""Represent a file that has exclusive access."""
"""Represent a file that has exclusive access."""
@util.positional(4)
def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
"""Construct a LockedFile.
@util.positional(4)
def __init__(self, filename, mode, fallback_mode, use_native_locking=True):
"""Construct a LockedFile.
Args:
filename: string, The path of the file to open.
@@ -336,32 +339,32 @@ class LockedFile(object):
fallback_mode: string, The mode to use if locking fails.
use_native_locking: bool, Whether or not fcntl/win32 locking is used.
"""
opener = None
if not opener and use_native_locking:
if _Win32Opener:
opener = _Win32Opener(filename, mode, fallback_mode)
if _FcntlOpener:
opener = _FcntlOpener(filename, mode, fallback_mode)
opener = None
if not opener and use_native_locking:
if _Win32Opener:
opener = _Win32Opener(filename, mode, fallback_mode)
if _FcntlOpener:
opener = _FcntlOpener(filename, mode, fallback_mode)
if not opener:
opener = _PosixOpener(filename, mode, fallback_mode)
if not opener:
opener = _PosixOpener(filename, mode, fallback_mode)
self._opener = opener
self._opener = opener
def filename(self):
"""Return the filename we were constructed with."""
return self._opener._filename
def filename(self):
"""Return the filename we were constructed with."""
return self._opener._filename
def file_handle(self):
"""Return the file_handle to the opened file."""
return self._opener.file_handle()
def file_handle(self):
"""Return the file_handle to the opened file."""
return self._opener.file_handle()
def is_locked(self):
"""Return whether we successfully locked the file."""
return self._opener.is_locked()
def is_locked(self):
"""Return whether we successfully locked the file."""
return self._opener.is_locked()
def open_and_lock(self, timeout=0, delay=0.05):
"""Open the file, trying to lock it.
def open_and_lock(self, timeout=0, delay=0.05):
"""Open the file, trying to lock it.
Args:
timeout: float, The number of seconds to try to acquire the lock.
@@ -371,8 +374,8 @@ class LockedFile(object):
AlreadyLockedException: if the lock is already acquired.
IOError: if the open fails.
"""
self._opener.open_and_lock(timeout, delay)
self._opener.open_and_lock(timeout, delay)
def unlock_and_close(self):
"""Unlock and close a file."""
self._opener.unlock_and_close()
def unlock_and_close(self):
"""Unlock and close a file."""
self._opener.unlock_and_close()

View File

@@ -65,17 +65,17 @@ _multistores_lock = threading.Lock()
class Error(Exception):
"""Base error for this module."""
"""Base error for this module."""
class NewerCredentialStoreError(Error):
"""The credential store is a newer version than supported."""
"""The credential store is a newer version than supported."""
@util.positional(4)
def get_credential_storage(filename, client_id, user_agent, scope,
warn_on_readonly=True):
"""Get a Storage instance for a credential.
"""Get a Storage instance for a credential.
Args:
filename: The JSON file storing a set of credentials
@@ -88,17 +88,17 @@ def get_credential_storage(filename, client_id, user_agent, scope,
An object derived from client.Storage for getting/setting the
credential.
"""
# Recreate the legacy key with these specific parameters
key = {'clientId': client_id, 'userAgent': user_agent,
# Recreate the legacy key with these specific parameters
key = {'clientId': client_id, 'userAgent': user_agent,
'scope': util.scopes_to_string(scope)}
return get_credential_storage_custom_key(
return get_credential_storage_custom_key(
filename, key, warn_on_readonly=warn_on_readonly)
@util.positional(2)
def get_credential_storage_custom_string_key(
filename, key_string, warn_on_readonly=True):
"""Get a Storage instance for a credential using a single string as a key.
"""Get a Storage instance for a credential using a single string as a key.
Allows you to provide a string as a custom key that will be used for
credential storage and retrieval.
@@ -112,16 +112,16 @@ def get_credential_storage_custom_string_key(
An object derived from client.Storage for getting/setting the
credential.
"""
# Create a key dictionary that can be used
key_dict = {'key': key_string}
return get_credential_storage_custom_key(
# Create a key dictionary that can be used
key_dict = {'key': key_string}
return get_credential_storage_custom_key(
filename, key_dict, warn_on_readonly=warn_on_readonly)
@util.positional(2)
def get_credential_storage_custom_key(
filename, key_dict, warn_on_readonly=True):
"""Get a Storage instance for a credential using a dictionary as a key.
"""Get a Storage instance for a credential using a dictionary as a key.
Allows you to provide a dictionary as a custom key that will be used for
credential storage and retrieval.
@@ -137,14 +137,14 @@ def get_credential_storage_custom_key(
An object derived from client.Storage for getting/setting the
credential.
"""
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
key = util.dict_to_tuple_key(key_dict)
return multistore._get_storage(key)
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
key = util.dict_to_tuple_key(key_dict)
return multistore._get_storage(key)
@util.positional(1)
def get_all_credential_keys(filename, warn_on_readonly=True):
"""Gets all the registered credential keys in the given Multistore.
"""Gets all the registered credential keys in the given Multistore.
Args:
filename: The JSON file storing a set of credentials
@@ -155,17 +155,17 @@ def get_all_credential_keys(filename, warn_on_readonly=True):
dictionaries that can be passed into get_credential_storage_custom_key to
get the actual credentials.
"""
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
multistore._lock()
try:
return multistore._get_all_credential_keys()
finally:
multistore._unlock()
multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly)
multistore._lock()
try:
return multistore._get_all_credential_keys()
finally:
multistore._unlock()
@util.positional(1)
def _get_multistore(filename, warn_on_readonly=True):
"""A helper method to initialize the multistore with proper locking.
"""A helper method to initialize the multistore with proper locking.
Args:
filename: The JSON file storing a set of credentials
@@ -174,176 +174,176 @@ def _get_multistore(filename, warn_on_readonly=True):
Returns:
A multistore object
"""
filename = os.path.expanduser(filename)
_multistores_lock.acquire()
try:
multistore = _multistores.setdefault(
filename = os.path.expanduser(filename)
_multistores_lock.acquire()
try:
multistore = _multistores.setdefault(
filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly))
finally:
_multistores_lock.release()
return multistore
finally:
_multistores_lock.release()
return multistore
class _MultiStore(object):
"""A file backed store for multiple credentials."""
"""A file backed store for multiple credentials."""
@util.positional(2)
def __init__(self, filename, warn_on_readonly=True):
"""Initialize the class.
@util.positional(2)
def __init__(self, filename, warn_on_readonly=True):
"""Initialize the class.
This will create the file if necessary.
"""
self._file = LockedFile(filename, 'r+', 'r')
self._thread_lock = threading.Lock()
self._read_only = False
self._warn_on_readonly = warn_on_readonly
self._file = LockedFile(filename, 'r+', 'r')
self._thread_lock = threading.Lock()
self._read_only = False
self._warn_on_readonly = warn_on_readonly
self._create_file_if_needed()
self._create_file_if_needed()
# Cache of deserialized store. This is only valid after the
# _MultiStore is locked or _refresh_data_cache is called. This is
# of the form of:
#
# ((key, value), (key, value)...) -> OAuth2Credential
#
# If this is None, then the store hasn't been read yet.
self._data = None
# Cache of deserialized store. This is only valid after the
# _MultiStore is locked or _refresh_data_cache is called. This is
# of the form of:
#
# ((key, value), (key, value)...) -> OAuth2Credential
#
# If this is None, then the store hasn't been read yet.
self._data = None
class _Storage(BaseStorage):
"""A Storage object that knows how to read/write a single credential."""
class _Storage(BaseStorage):
"""A Storage object that knows how to read/write a single credential."""
def __init__(self, multistore, key):
self._multistore = multistore
self._key = key
def __init__(self, multistore, key):
self._multistore = multistore
self._key = key
def acquire_lock(self):
"""Acquires any lock necessary to access this Storage.
def acquire_lock(self):
"""Acquires any lock necessary to access this Storage.
This lock is not reentrant.
"""
self._multistore._lock()
self._multistore._lock()
def release_lock(self):
"""Release the Storage lock.
def release_lock(self):
"""Release the Storage lock.
Trying to release a lock that isn't held will result in a
RuntimeError.
"""
self._multistore._unlock()
self._multistore._unlock()
def locked_get(self):
"""Retrieve credential.
def locked_get(self):
"""Retrieve credential.
The Storage lock must be held when this is called.
Returns:
oauth2client.client.Credentials
"""
credential = self._multistore._get_credential(self._key)
if credential:
credential.set_store(self)
return credential
credential = self._multistore._get_credential(self._key)
if credential:
credential.set_store(self)
return credential
def locked_put(self, credentials):
"""Write a credential.
def locked_put(self, credentials):
"""Write a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
self._multistore._update_credential(self._key, credentials)
self._multistore._update_credential(self._key, credentials)
def locked_delete(self):
"""Delete a credential.
def locked_delete(self):
"""Delete a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
self._multistore._delete_credential(self._key)
self._multistore._delete_credential(self._key)
def _create_file_if_needed(self):
"""Create an empty file if necessary.
def _create_file_if_needed(self):
"""Create an empty file if necessary.
This method will not initialize the file. Instead it implements a
simple version of "touch" to ensure the file has been created.
"""
if not os.path.exists(self._file.filename()):
old_umask = os.umask(0o177)
try:
open(self._file.filename(), 'a+b').close()
finally:
os.umask(old_umask)
if not os.path.exists(self._file.filename()):
old_umask = os.umask(0o177)
try:
open(self._file.filename(), 'a+b').close()
finally:
os.umask(old_umask)
def _lock(self):
"""Lock the entire multistore."""
self._thread_lock.acquire()
try:
self._file.open_and_lock()
except IOError as e:
if e.errno == errno.ENOSYS:
logger.warn('File system does not support locking the credentials '
def _lock(self):
"""Lock the entire multistore."""
self._thread_lock.acquire()
try:
self._file.open_and_lock()
except IOError as e:
if e.errno == errno.ENOSYS:
logger.warn('File system does not support locking the credentials '
'file.')
elif e.errno == errno.ENOLCK:
logger.warn('File system is out of resources for writing the '
elif e.errno == errno.ENOLCK:
logger.warn('File system is out of resources for writing the '
'credentials file (is your disk full?).')
else:
raise
if not self._file.is_locked():
self._read_only = True
if self._warn_on_readonly:
logger.warn('The credentials file (%s) is not writable. Opening in '
else:
raise
if not self._file.is_locked():
self._read_only = True
if self._warn_on_readonly:
logger.warn('The credentials file (%s) is not writable. Opening in '
'read-only mode. Any refreshed credentials will only be '
'valid for this run.', self._file.filename())
if os.path.getsize(self._file.filename()) == 0:
logger.debug('Initializing empty multistore file')
# The multistore is empty so write out an empty file.
self._data = {}
self._write()
elif not self._read_only or self._data is None:
# Only refresh the data if we are read/write or we haven't
# cached the data yet. If we are readonly, we assume is isn't
# changing out from under us and that we only have to read it
# once. This prevents us from whacking any new access keys that
# we have cached in memory but were unable to write out.
self._refresh_data_cache()
if os.path.getsize(self._file.filename()) == 0:
logger.debug('Initializing empty multistore file')
# The multistore is empty so write out an empty file.
self._data = {}
self._write()
elif not self._read_only or self._data is None:
# Only refresh the data if we are read/write or we haven't
# cached the data yet. If we are readonly, we assume is isn't
# changing out from under us and that we only have to read it
# once. This prevents us from whacking any new access keys that
# we have cached in memory but were unable to write out.
self._refresh_data_cache()
def _unlock(self):
"""Release the lock on the multistore."""
self._file.unlock_and_close()
self._thread_lock.release()
def _unlock(self):
"""Release the lock on the multistore."""
self._file.unlock_and_close()
self._thread_lock.release()
def _locked_json_read(self):
"""Get the raw content of the multistore file.
def _locked_json_read(self):
"""Get the raw content of the multistore file.
The multistore must be locked when this is called.
Returns:
The contents of the multistore decoded as JSON.
"""
assert self._thread_lock.locked()
self._file.file_handle().seek(0)
return json.load(self._file.file_handle())
assert self._thread_lock.locked()
self._file.file_handle().seek(0)
return json.load(self._file.file_handle())
def _locked_json_write(self, data):
"""Write a JSON serializable data structure to the multistore.
def _locked_json_write(self, data):
"""Write a JSON serializable data structure to the multistore.
The multistore must be locked when this is called.
Args:
data: The data to be serialized and written.
"""
assert self._thread_lock.locked()
if self._read_only:
return
self._file.file_handle().seek(0)
json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
self._file.file_handle().truncate()
assert self._thread_lock.locked()
if self._read_only:
return
self._file.file_handle().seek(0)
json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': '))
self._file.file_handle().truncate()
def _refresh_data_cache(self):
"""Refresh the contents of the multistore.
def _refresh_data_cache(self):
"""Refresh the contents of the multistore.
The multistore must be locked when this is called.
@@ -351,41 +351,41 @@ class _MultiStore(object):
NewerCredentialStoreError: Raised when a newer client has written the
store.
"""
self._data = {}
try:
raw_data = self._locked_json_read()
except Exception:
logger.warn('Credential data store could not be loaded. '
self._data = {}
try:
raw_data = self._locked_json_read()
except Exception:
logger.warn('Credential data store could not be loaded. '
'Will ignore and overwrite.')
return
return
version = 0
try:
version = raw_data['file_version']
except Exception:
logger.warn('Missing version for credential data store. It may be '
version = 0
try:
version = raw_data['file_version']
except Exception:
logger.warn('Missing version for credential data store. It may be '
'corrupt or an old version. Overwriting.')
if version > 1:
raise NewerCredentialStoreError(
if version > 1:
raise NewerCredentialStoreError(
'Credential file has file_version of %d. '
'Only file_version of 1 is supported.' % version)
credentials = []
try:
credentials = raw_data['data']
except (TypeError, KeyError):
pass
credentials = []
try:
credentials = raw_data['data']
except (TypeError, KeyError):
pass
for cred_entry in credentials:
try:
(key, credential) = self._decode_credential_from_json(cred_entry)
self._data[key] = credential
except:
# If something goes wrong loading a credential, just ignore it
logger.info('Error decoding credential, skipping', exc_info=True)
for cred_entry in credentials:
try:
(key, credential) = self._decode_credential_from_json(cred_entry)
self._data[key] = credential
except:
# If something goes wrong loading a credential, just ignore it
logger.info('Error decoding credential, skipping', exc_info=True)
def _decode_credential_from_json(self, cred_entry):
"""Load a credential from our JSON serialization.
def _decode_credential_from_json(self, cred_entry):
"""Load a credential from our JSON serialization.
Args:
cred_entry: A dict entry from the data member of our format
@@ -394,36 +394,36 @@ class _MultiStore(object):
(key, cred) where the key is the key tuple and the cred is the
OAuth2Credential object.
"""
raw_key = cred_entry['key']
key = util.dict_to_tuple_key(raw_key)
credential = None
credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
return (key, credential)
raw_key = cred_entry['key']
key = util.dict_to_tuple_key(raw_key)
credential = None
credential = Credentials.new_from_json(json.dumps(cred_entry['credential']))
return (key, credential)
def _write(self):
"""Write the cached data back out.
def _write(self):
"""Write the cached data back out.
The multistore must be locked.
"""
raw_data = {'file_version': 1}
raw_creds = []
raw_data['data'] = raw_creds
for (cred_key, cred) in self._data.items():
raw_key = dict(cred_key)
raw_cred = json.loads(cred.to_json())
raw_creds.append({'key': raw_key, 'credential': raw_cred})
self._locked_json_write(raw_data)
raw_data = {'file_version': 1}
raw_creds = []
raw_data['data'] = raw_creds
for (cred_key, cred) in self._data.items():
raw_key = dict(cred_key)
raw_cred = json.loads(cred.to_json())
raw_creds.append({'key': raw_key, 'credential': raw_cred})
self._locked_json_write(raw_data)
def _get_all_credential_keys(self):
"""Gets all the registered credential keys in the multistore.
def _get_all_credential_keys(self):
"""Gets all the registered credential keys in the multistore.
Returns:
A list of dictionaries corresponding to all the keys currently registered
"""
return [dict(key) for key in self._data.keys()]
return [dict(key) for key in self._data.keys()]
def _get_credential(self, key):
"""Get a credential from the multistore.
def _get_credential(self, key):
"""Get a credential from the multistore.
The multistore must be locked.
@@ -433,10 +433,10 @@ class _MultiStore(object):
Returns:
The credential specified or None if not present
"""
return self._data.get(key, None)
return self._data.get(key, None)
def _update_credential(self, key, cred):
"""Update a credential and write the multistore.
def _update_credential(self, key, cred):
"""Update a credential and write the multistore.
This must be called when the multistore is locked.
@@ -444,25 +444,25 @@ class _MultiStore(object):
key: The key used to retrieve the credential
cred: The OAuth2Credential to update/set
"""
self._data[key] = cred
self._write()
self._data[key] = cred
self._write()
def _delete_credential(self, key):
"""Delete a credential and write the multistore.
def _delete_credential(self, key):
"""Delete a credential and write the multistore.
This must be called when the multistore is locked.
Args:
key: The key used to retrieve the credential
"""
try:
del self._data[key]
except KeyError:
pass
self._write()
try:
del self._data[key]
except KeyError:
pass
self._write()
def _get_storage(self, key):
"""Get a Storage object to get/set a credential.
def _get_storage(self, key):
"""Get a Storage object to get/set a credential.
This Storage is a 'view' into the multistore.
@@ -472,4 +472,4 @@ class _MultiStore(object):
Returns:
A Storage object that can be used to get/set this cred
"""
return self._Storage(self, key)
return self._Storage(self, key)

View File

@@ -30,7 +30,6 @@ from oauth2client import util
from oauth2client.tools import ClientRedirectHandler
from oauth2client.tools import ClientRedirectServer
FLAGS = gflags.FLAGS
gflags.DEFINE_boolean('auth_local_webserver', True,
@@ -48,7 +47,7 @@ gflags.DEFINE_multi_int('auth_host_port', [8080, 8090],
@util.positional(2)
def run(flow, storage, http=None):
"""Core code for a command-line application.
"""Core code for a command-line application.
The ``run()`` function is called from your application and runs
through all the steps to obtain credentials. It takes a ``Flow``
@@ -86,76 +85,76 @@ def run(flow, storage, http=None):
Returns:
Credentials, the obtained credential.
"""
logging.warning('This function, oauth2client.tools.run(), and the use of '
logging.warning('This function, oauth2client.tools.run(), and the use of '
'the gflags library are deprecated and will be removed in a future '
'version of the library.')
if FLAGS.auth_local_webserver:
success = False
port_number = 0
for port in FLAGS.auth_host_port:
port_number = port
try:
httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
if FLAGS.auth_local_webserver:
success = False
port_number = 0
for port in FLAGS.auth_host_port:
port_number = port
try:
httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
ClientRedirectHandler)
except socket.error as e:
pass
else:
success = True
break
FLAGS.auth_local_webserver = success
except socket.error as e:
pass
else:
success = True
break
FLAGS.auth_local_webserver = success
if not success:
print('Failed to start a local webserver listening on either port 8080')
print('or port 9090. Please check your firewall settings and locally')
print('running programs that may be blocking or using those ports.')
print()
print('Falling back to --noauth_local_webserver and continuing with')
print('authorization.')
print()
print('Failed to start a local webserver listening on either port 8080')
print('or port 9090. Please check your firewall settings and locally')
print('running programs that may be blocking or using those ports.')
print()
print('Falling back to --noauth_local_webserver and continuing with')
print('authorization.')
print()
if FLAGS.auth_local_webserver:
oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
else:
oauth_callback = client.OOB_CALLBACK_URN
flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
if FLAGS.auth_local_webserver:
webbrowser.open(authorize_url, new=1, autoraise=True)
print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run')
print('this application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
code = None
if FLAGS.auth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
if FLAGS.auth_local_webserver:
oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number)
else:
print('Failed to find "code" in the query parameters of the redirect.')
sys.exit('Try running with --noauth_local_webserver.')
else:
code = input('Enter verification code: ').strip()
oauth_callback = client.OOB_CALLBACK_URN
flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
try:
credential = flow.step2_exchange(code, http=http)
except client.FlowExchangeError as e:
sys.exit('Authentication has failed: %s' % e)
if FLAGS.auth_local_webserver:
webbrowser.open(authorize_url, new=1, autoraise=True)
print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run')
print('this application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
storage.put(credential)
credential.set_store(storage)
print('Authentication successful.')
code = None
if FLAGS.auth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else:
print('Failed to find "code" in the query parameters of the redirect.')
sys.exit('Try running with --noauth_local_webserver.')
else:
code = input('Enter verification code: ').strip()
return credential
try:
credential = flow.step2_exchange(code, http=http)
except client.FlowExchangeError as e:
sys.exit('Authentication has failed: %s' % e)
storage.put(credential)
credential.set_store(storage)
print('Authentication successful.')
return credential

View File

@@ -34,83 +34,83 @@ from oauth2client.client import AssertionCredentials
class _ServiceAccountCredentials(AssertionCredentials):
"""Class representing a service account (signed JWT) credential."""
"""Class representing a service account (signed JWT) credential."""
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
def __init__(self, service_account_id, service_account_email, private_key_id,
def __init__(self, service_account_id, service_account_email, private_key_id,
private_key_pkcs8_text, scopes, user_agent=None,
token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI,
**kwargs):
super(_ServiceAccountCredentials, self).__init__(
super(_ServiceAccountCredentials, self).__init__(
None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri)
self._service_account_id = service_account_id
self._service_account_email = service_account_email
self._private_key_id = private_key_id
self._private_key = _get_private_key(private_key_pkcs8_text)
self._private_key_pkcs8_text = private_key_pkcs8_text
self._scopes = util.scopes_to_string(scopes)
self._user_agent = user_agent
self._token_uri = token_uri
self._revoke_uri = revoke_uri
self._kwargs = kwargs
self._service_account_id = service_account_id
self._service_account_email = service_account_email
self._private_key_id = private_key_id
self._private_key = _get_private_key(private_key_pkcs8_text)
self._private_key_pkcs8_text = private_key_pkcs8_text
self._scopes = util.scopes_to_string(scopes)
self._user_agent = user_agent
self._token_uri = token_uri
self._revoke_uri = revoke_uri
self._kwargs = kwargs
def _generate_assertion(self):
"""Generate the assertion that will be used in the request."""
def _generate_assertion(self):
"""Generate the assertion that will be used in the request."""
header = {
header = {
'alg': 'RS256',
'typ': 'JWT',
'kid': self._private_key_id
}
}
now = int(time.time())
payload = {
now = int(time.time())
payload = {
'aud': self._token_uri,
'scope': self._scopes,
'iat': now,
'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS,
'iss': self._service_account_email
}
payload.update(self._kwargs)
}
payload.update(self._kwargs)
first_segment = _urlsafe_b64encode(_json_encode(header))
second_segment = _urlsafe_b64encode(_json_encode(payload))
assertion_input = first_segment + b'.' + second_segment
first_segment = _urlsafe_b64encode(_json_encode(header))
second_segment = _urlsafe_b64encode(_json_encode(payload))
assertion_input = first_segment + b'.' + second_segment
# Sign the assertion.
rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
# Sign the assertion.
rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256')
signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
return assertion_input + b'.' + signature
return assertion_input + b'.' + signature
def sign_blob(self, blob):
# Ensure that it is bytes
blob = _to_bytes(blob, encoding='utf-8')
return (self._private_key_id,
def sign_blob(self, blob):
# Ensure that it is bytes
blob = _to_bytes(blob, encoding='utf-8')
return (self._private_key_id,
rsa.pkcs1.sign(blob, self._private_key, 'SHA-256'))
@property
def service_account_email(self):
return self._service_account_email
@property
def service_account_email(self):
return self._service_account_email
@property
def serialization_data(self):
return {
@property
def serialization_data(self):
return {
'type': 'service_account',
'client_id': self._service_account_id,
'client_email': self._service_account_email,
'private_key_id': self._private_key_id,
'private_key': self._private_key_pkcs8_text
}
}
def create_scoped_required(self):
return not self._scopes
def create_scoped_required(self):
return not self._scopes
def create_scoped(self, scopes):
return _ServiceAccountCredentials(self._service_account_id,
def create_scoped(self, scopes):
return _ServiceAccountCredentials(self._service_account_id,
self._service_account_email,
self._private_key_id,
self._private_key_pkcs8_text,
@@ -122,10 +122,10 @@ class _ServiceAccountCredentials(AssertionCredentials):
def _get_private_key(private_key_pkcs8_text):
"""Get an RSA private key object from a pkcs8 representation."""
private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
return rsa.PrivateKey.load_pkcs1(
"""Get an RSA private key object from a pkcs8 representation."""
private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
return rsa.PrivateKey.load_pkcs1(
asn1_private_key.getComponentByName('privateKey').asOctets(),
format='DER')

View File

@@ -35,7 +35,6 @@ from six.moves import input
from oauth2client import client
from oauth2client import util
_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0
To make this sample run you will need to populate the client_secrets.json file
@@ -47,22 +46,23 @@ with information from the APIs Console <https://code.google.com/apis/console>.
"""
def _CreateArgumentParser():
try:
import argparse
except ImportError:
return None
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--auth_host_name', default='localhost',
try:
import argparse
except ImportError:
return None
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--auth_host_name', default='localhost',
help='Hostname when running a local web server.')
parser.add_argument('--noauth_local_webserver', action='store_true',
parser.add_argument('--noauth_local_webserver', action='store_true',
default=False, help='Do not run a local web server.')
parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
parser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
nargs='*', help='Port web server should listen on.')
parser.add_argument('--logging_level', default='ERROR',
parser.add_argument('--logging_level', default='ERROR',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
help='Set the logging level of detail.')
return parser
return parser
# argparser is an ArgumentParser that contains command-line options expected
# by tools.run(). Pass it in as part of the 'parents' argument to your own
@@ -71,45 +71,45 @@ argparser = _CreateArgumentParser()
class ClientRedirectServer(BaseHTTPServer.HTTPServer):
"""A server to handle OAuth 2.0 redirects back to localhost.
"""A server to handle OAuth 2.0 redirects back to localhost.
Waits for a single request and parses the query parameters
into query_params and then stops serving.
"""
query_params = {}
query_params = {}
class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""A handler for OAuth 2.0 redirects back to localhost.
"""A handler for OAuth 2.0 redirects back to localhost.
Waits for a single request and parses the query parameters
into the servers query_params and then stops serving.
"""
def do_GET(self):
"""Handle a GET request.
def do_GET(self):
"""Handle a GET request.
Parses the query parameters and prints a message
if the flow has completed. Note that we can't detect
if an error occurred.
"""
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
query = self.path.split('?', 1)[-1]
query = dict(urllib.parse.parse_qsl(query))
self.server.query_params = query
self.wfile.write(b"<html><head><title>Authentication Status</title></head>")
self.wfile.write(b"<body><p>The authentication flow has completed.</p>")
self.wfile.write(b"</body></html>")
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
query = self.path.split('?', 1)[-1]
query = dict(urllib.parse.parse_qsl(query))
self.server.query_params = query
self.wfile.write(b"<html><head><title>Authentication Status</title></head>")
self.wfile.write(b"<body><p>The authentication flow has completed.</p>")
self.wfile.write(b"</body></html>")
def log_message(self, format, *args):
"""Do not log messages to stdout while running as command line program."""
def log_message(self, format, *args):
"""Do not log messages to stdout while running as command line program."""
@util.positional(3)
def run_flow(flow, storage, flags, http=None):
"""Core code for a command-line application.
"""Core code for a command-line application.
The ``run()`` function is called from your application and runs
through all the steps to obtain credentials. It takes a ``Flow``
@@ -159,91 +159,91 @@ def run_flow(flow, storage, flags, http=None):
Returns:
Credentials, the obtained credential.
"""
logging.getLogger().setLevel(getattr(logging, flags.logging_level))
if not flags.noauth_local_webserver:
success = False
port_number = 0
for port in flags.auth_host_port:
port_number = port
try:
httpd = ClientRedirectServer((flags.auth_host_name, port),
logging.getLogger().setLevel(getattr(logging, flags.logging_level))
if not flags.noauth_local_webserver:
success = False
port_number = 0
for port in flags.auth_host_port:
port_number = port
try:
httpd = ClientRedirectServer((flags.auth_host_name, port),
ClientRedirectHandler)
except socket.error:
pass
else:
success = True
break
flags.noauth_local_webserver = not success
except socket.error:
pass
else:
success = True
break
flags.noauth_local_webserver = not success
if not success:
print('Failed to start a local webserver listening on either port 8080')
print('or port 9090. Please check your firewall settings and locally')
print('running programs that may be blocking or using those ports.')
print()
print('Falling back to --noauth_local_webserver and continuing with')
print('authorization.')
print()
print('Failed to start a local webserver listening on either port 8080')
print('or port 9090. Please check your firewall settings and locally')
print('running programs that may be blocking or using those ports.')
print()
print('Falling back to --noauth_local_webserver and continuing with')
print('authorization.')
print()
if not flags.noauth_local_webserver:
oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
else:
oauth_callback = client.OOB_CALLBACK_URN
flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
if not flags.noauth_local_webserver:
import webbrowser
webbrowser.open(authorize_url, new=1, autoraise=True)
print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run this')
print('application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
code = None
if not flags.noauth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
if not flags.noauth_local_webserver:
oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number)
else:
print('Failed to find "code" in the query parameters of the redirect.')
sys.exit('Try running with --noauth_local_webserver.')
else:
code = input('Enter verification code: ').strip()
oauth_callback = client.OOB_CALLBACK_URN
flow.redirect_uri = oauth_callback
authorize_url = flow.step1_get_authorize_url()
try:
credential = flow.step2_exchange(code, http=http)
except client.FlowExchangeError as e:
sys.exit('Authentication has failed: %s' % e)
if not flags.noauth_local_webserver:
import webbrowser
webbrowser.open(authorize_url, new=1, autoraise=True)
print('Your browser has been opened to visit:')
print()
print(' ' + authorize_url)
print()
print('If your browser is on a different machine then exit and re-run this')
print('application with the command-line parameter ')
print()
print(' --noauth_local_webserver')
print()
else:
print('Go to the following link in your browser:')
print()
print(' ' + authorize_url)
print()
storage.put(credential)
credential.set_store(storage)
print('Authentication successful.')
code = None
if not flags.noauth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else:
print('Failed to find "code" in the query parameters of the redirect.')
sys.exit('Try running with --noauth_local_webserver.')
else:
code = input('Enter verification code: ').strip()
return credential
try:
credential = flow.step2_exchange(code, http=http)
except client.FlowExchangeError as e:
sys.exit('Authentication has failed: %s' % e)
storage.put(credential)
credential.set_store(storage)
print('Authentication successful.')
return credential
def message_if_missing(filename):
"""Helpful message to display if the CLIENT_SECRETS file is missing."""
"""Helpful message to display if the CLIENT_SECRETS file is missing."""
return _CLIENT_SECRETS_MESSAGE % filename
return _CLIENT_SECRETS_MESSAGE % filename
try:
from oauth2client.old_run import run
from oauth2client.old_run import FLAGS
from oauth2client.old_run import run
from oauth2client.old_run import FLAGS
except ImportError:
def run(*args, **kwargs):
raise NotImplementedError(
def run(*args, **kwargs):
raise NotImplementedError(
'The gflags library must be installed to use tools.run(). '
'Please install gflags or preferrably switch to using '
'tools.run_flow().')

View File

@@ -38,7 +38,6 @@ import types
import six
from six.moves import urllib
logger = logging.getLogger(__name__)
POSITIONAL_WARNING = 'WARNING'
@@ -49,8 +48,9 @@ POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION,
positional_parameters_enforcement = POSITIONAL_WARNING
def positional(max_positional_args):
"""A decorator to declare that only the first N arguments my be positional.
"""A decorator to declare that only the first N arguments my be positional.
This decorator makes it easy to support Python 3 style keyword-only
parameters. For example, in Python 3 it is possible to write::
@@ -119,33 +119,34 @@ def positional(max_positional_args):
POSITIONAL_EXCEPTION.
"""
def positional_decorator(wrapped):
@functools.wraps(wrapped)
def positional_wrapper(*args, **kwargs):
if len(args) > max_positional_args:
plural_s = ''
if max_positional_args != 1:
plural_s = 's'
message = '%s() takes at most %d positional argument%s (%d given)' % (
wrapped.__name__, max_positional_args, plural_s, len(args))
if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
raise TypeError(message)
elif positional_parameters_enforcement == POSITIONAL_WARNING:
logger.warning(message)
else: # IGNORE
pass
return wrapped(*args, **kwargs)
return positional_wrapper
if isinstance(max_positional_args, six.integer_types):
return positional_decorator
else:
args, _, _, defaults = inspect.getargspec(max_positional_args)
return positional(len(args) - len(defaults))(max_positional_args)
def positional_decorator(wrapped):
@functools.wraps(wrapped)
def positional_wrapper(*args, **kwargs):
if len(args) > max_positional_args:
plural_s = ''
if max_positional_args != 1:
plural_s = 's'
message = '%s() takes at most %d positional argument%s (%d given)' % (
wrapped.__name__, max_positional_args, plural_s, len(args))
if positional_parameters_enforcement == POSITIONAL_EXCEPTION:
raise TypeError(message)
elif positional_parameters_enforcement == POSITIONAL_WARNING:
logger.warning(message)
else: # IGNORE
pass
return wrapped(*args, **kwargs)
return positional_wrapper
if isinstance(max_positional_args, six.integer_types):
return positional_decorator
else:
args, _, _, defaults = inspect.getargspec(max_positional_args)
return positional(len(args) - len(defaults))(max_positional_args)
def scopes_to_string(scopes):
"""Converts scope value to a string.
"""Converts scope value to a string.
If scopes is a string then it is simply passed through. If scopes is an
iterable then a string is returned that is all the individual scopes
@@ -157,14 +158,14 @@ def scopes_to_string(scopes):
Returns:
The scopes formatted as a single string.
"""
if isinstance(scopes, six.string_types):
return scopes
else:
return ' '.join(scopes)
if isinstance(scopes, six.string_types):
return scopes
else:
return ' '.join(scopes)
def string_to_scopes(scopes):
"""Converts stringifed scope value to a list.
"""Converts stringifed scope value to a list.
If scopes is a list then it is simply passed through. If scopes is an
string then a list of each individual scope is returned.
@@ -175,16 +176,16 @@ def string_to_scopes(scopes):
Returns:
The scopes in a list.
"""
if not scopes:
return []
if isinstance(scopes, six.string_types):
return scopes.split(' ')
else:
return scopes
if not scopes:
return []
if isinstance(scopes, six.string_types):
return scopes.split(' ')
else:
return scopes
def dict_to_tuple_key(dictionary):
"""Converts a dictionary to a tuple that can be used as an immutable key.
"""Converts a dictionary to a tuple that can be used as an immutable key.
The resulting key is always sorted so that logically equivalent dictionaries
always produce an identical tuple for a key.
@@ -195,11 +196,11 @@ def dict_to_tuple_key(dictionary):
Returns:
A tuple representing the dictionary in it's naturally sorted ordering.
"""
return tuple(sorted(dictionary.items()))
return tuple(sorted(dictionary.items()))
def _add_query_parameter(url, name, value):
"""Adds a query parameter to a url.
"""Adds a query parameter to a url.
Replaces the current value if it already exists in the URL.
@@ -211,11 +212,11 @@ def _add_query_parameter(url, name, value):
Returns:
Updated query parameter. Does not update the url if value is None.
"""
if value is None:
return url
else:
parsed = list(urllib.parse.urlparse(url))
q = dict(urllib.parse.parse_qsl(parsed[4]))
q[name] = value
parsed[4] = urllib.parse.urlencode(q)
return urllib.parse.urlunparse(parsed)
if value is None:
return url
else:
parsed = list(urllib.parse.urlparse(url))
q = dict(urllib.parse.parse_qsl(parsed[4]))
q[name] = value
parsed[4] = urllib.parse.urlencode(q)
return urllib.parse.urlunparse(parsed)

View File

@@ -20,7 +20,6 @@ __authors__ = [
'"Joe Gregorio" <jcgregorio@google.com>',
]
import base64
import hmac
import time
@@ -28,13 +27,11 @@ import time
import six
from oauth2client import util
# Delimiter character
DELIMITER = b':'
# 1 hour in seconds
DEFAULT_TIMEOUT_SECS = 1*60*60
DEFAULT_TIMEOUT_SECS = 1 * 60 * 60
def _force_bytes(s):
@@ -48,7 +45,7 @@ def _force_bytes(s):
@util.positional(2)
def generate_token(key, user_id, action_id="", when=None):
"""Generates a URL-safe token for the given user, action, time tuple.
"""Generates a URL-safe token for the given user, action, time tuple.
Args:
key: secret key to use.
@@ -61,22 +58,22 @@ def generate_token(key, user_id, action_id="", when=None):
Returns:
A string XSRF protection token.
"""
when = _force_bytes(when or int(time.time()))
digester = hmac.new(_force_bytes(key))
digester.update(_force_bytes(user_id))
digester.update(DELIMITER)
digester.update(_force_bytes(action_id))
digester.update(DELIMITER)
digester.update(when)
digest = digester.digest()
when = _force_bytes(when or int(time.time()))
digester = hmac.new(_force_bytes(key))
digester.update(_force_bytes(user_id))
digester.update(DELIMITER)
digester.update(_force_bytes(action_id))
digester.update(DELIMITER)
digester.update(when)
digest = digester.digest()
token = base64.urlsafe_b64encode(digest + DELIMITER + when)
return token
token = base64.urlsafe_b64encode(digest + DELIMITER + when)
return token
@util.positional(3)
def validate_token(key, token, user_id, action_id="", current_time=None):
"""Validates that the given token authorizes the user for the action.
"""Validates that the given token authorizes the user for the action.
Tokens are invalid if the time of issue is too old or if the token
does not match what generateToken outputs (i.e. the token was forged).
@@ -92,27 +89,27 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
A boolean - True if the user is authorized for the action, False
otherwise.
"""
if not token:
return False
try:
decoded = base64.urlsafe_b64decode(token)
token_time = int(decoded.split(DELIMITER)[-1])
except (TypeError, ValueError):
return False
if current_time is None:
current_time = time.time()
# If the token is too old it's not valid.
if current_time - token_time > DEFAULT_TIMEOUT_SECS:
return False
if not token:
return False
try:
decoded = base64.urlsafe_b64decode(token)
token_time = int(decoded.split(DELIMITER)[-1])
except (TypeError, ValueError):
return False
if current_time is None:
current_time = time.time()
# If the token is too old it's not valid.
if current_time - token_time > DEFAULT_TIMEOUT_SECS:
return False
# The given token should match the generated one with the same time.
expected_token = generate_token(key, user_id, action_id=action_id,
# The given token should match the generated one with the same time.
expected_token = generate_token(key, user_id, action_id=action_id,
when=token_time)
if len(token) != len(expected_token):
return False
if len(token) != len(expected_token):
return False
# Perform constant time comparison to avoid timing attacks
different = 0
for x, y in zip(bytearray(token), bytearray(expected_token)):
different |= x ^ y
return not different
# Perform constant time comparison to avoid timing attacks
different = 0
for x, y in zip(bytearray(token), bytearray(expected_token)):
different |= x ^ y
return not different

View File

@@ -16,6 +16,7 @@ __author__ = 'afshar@google.com (Ali Afshar)'
import oauth2client.util
def setup_package():
"""Run on testing package."""
oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'
"""Run on testing package."""
oauth2client.util.positional_parameters_enforcement = 'EXCEPTION'

View File

@@ -20,47 +20,45 @@ import httplib2
# TODO(craigcitro): Find a cleaner way to share this code with googleapiclient.
class HttpMock(object):
"""Mock of httplib2.Http"""
"""Mock of httplib2.Http"""
def __init__(self, filename=None, headers=None):
"""
def __init__(self, filename=None, headers=None):
"""
Args:
filename: string, absolute filename to read response from
headers: dict, header to return with response
"""
if headers is None:
headers = {'status': '200 OK'}
if filename:
f = file(filename, 'r')
self.data = f.read()
f.close()
else:
self.data = None
self.response_headers = headers
self.headers = None
self.uri = None
self.method = None
self.body = None
self.headers = None
if headers is None:
headers = {'status': '200 OK'}
if filename:
f = file(filename, 'r')
self.data = f.read()
f.close()
else:
self.data = None
self.response_headers = headers
self.headers = None
self.uri = None
self.method = None
self.body = None
self.headers = None
def request(self, uri,
def request(self, uri,
method='GET',
body=None,
headers=None,
redirections=1,
connection_type=None):
self.uri = uri
self.method = method
self.body = body
self.headers = headers
return httplib2.Response(self.response_headers), self.data
self.uri = uri
self.method = method
self.body = body
self.headers = headers
return httplib2.Response(self.response_headers), self.data
class HttpMockSequence(object):
"""Mock of httplib2.Http
"""Mock of httplib2.Http
Mocks a sequence of calls to request returning different responses for each
call. Create an instance initialized with the desired response headers
@@ -83,33 +81,33 @@ class HttpMockSequence(object):
'echo_request_uri' means return the request uri in the response body
"""
def __init__(self, iterable):
"""
def __init__(self, iterable):
"""
Args:
iterable: iterable, a sequence of pairs of (headers, body)
"""
self._iterable = iterable
self.follow_redirects = True
self.requests = []
self._iterable = iterable
self.follow_redirects = True
self.requests = []
def request(self, uri,
def request(self, uri,
method='GET',
body=None,
headers=None,
redirections=1,
connection_type=None):
resp, content = self._iterable.pop(0)
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
# Read any underlying stream before sending the request.
body_stream_content = body.read() if getattr(body, 'read', None) else None
if content == 'echo_request_headers':
content = headers
elif content == 'echo_request_headers_as_json':
content = json.dumps(headers)
elif content == 'echo_request_body':
content = body if body_stream_content is None else body_stream_content
elif content == 'echo_request_uri':
content = uri
elif not isinstance(content, bytes):
raise TypeError('http content should be bytes: %r' % (content,))
return httplib2.Response(resp), content
resp, content = self._iterable.pop(0)
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
# Read any underlying stream before sending the request.
body_stream_content = body.read() if getattr(body, 'read', None) else None
if content == 'echo_request_headers':
content = headers
elif content == 'echo_request_headers_as_json':
content = json.dumps(headers)
elif content == 'echo_request_body':
content = body if body_stream_content is None else body_stream_content
elif content == 'echo_request_uri':
content = uri
elif not isinstance(content, bytes):
raise TypeError('http content should be bytes: %r' % (content, ))
return httplib2.Response(resp), content

View File

@@ -25,93 +25,93 @@ from oauth2client._helpers import _urlsafe_b64encode
class Test__parse_pem_key(unittest.TestCase):
def test_valid_input(self):
test_string = b'1234-----BEGIN FOO BAR BAZ'
result = _parse_pem_key(test_string)
self.assertEqual(result, test_string[4:])
def test_valid_input(self):
test_string = b'1234-----BEGIN FOO BAR BAZ'
result = _parse_pem_key(test_string)
self.assertEqual(result, test_string[4:])
def test_bad_input(self):
test_string = b'DOES NOT HAVE DASHES'
result = _parse_pem_key(test_string)
self.assertEqual(result, None)
def test_bad_input(self):
test_string = b'DOES NOT HAVE DASHES'
result = _parse_pem_key(test_string)
self.assertEqual(result, None)
class Test__json_encode(unittest.TestCase):
def test_dictionary_input(self):
# Use only a single key since dictionary hash order
# is non-deterministic.
data = {u'foo': 10}
result = _json_encode(data)
self.assertEqual(result, """{"foo":10}""")
def test_dictionary_input(self):
# Use only a single key since dictionary hash order
# is non-deterministic.
data = {u'foo': 10}
result = _json_encode(data)
self.assertEqual(result, """{"foo":10}""")
def test_list_input(self):
data = [42, 1337]
result = _json_encode(data)
self.assertEqual(result, """[42,1337]""")
def test_list_input(self):
data = [42, 1337]
result = _json_encode(data)
self.assertEqual(result, """[42,1337]""")
class Test__to_bytes(unittest.TestCase):
def test_with_bytes(self):
value = b'bytes-val'
self.assertEqual(_to_bytes(value), value)
def test_with_bytes(self):
value = b'bytes-val'
self.assertEqual(_to_bytes(value), value)
def test_with_unicode(self):
value = u'string-val'
encoded_value = b'string-val'
self.assertEqual(_to_bytes(value), encoded_value)
def test_with_unicode(self):
value = u'string-val'
encoded_value = b'string-val'
self.assertEqual(_to_bytes(value), encoded_value)
def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _to_bytes, value)
def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _to_bytes, value)
class Test__from_bytes(unittest.TestCase):
def test_with_unicode(self):
value = u'bytes-val'
self.assertEqual(_from_bytes(value), value)
def test_with_unicode(self):
value = u'bytes-val'
self.assertEqual(_from_bytes(value), value)
def test_with_bytes(self):
value = b'string-val'
decoded_value = u'string-val'
self.assertEqual(_from_bytes(value), decoded_value)
def test_with_bytes(self):
value = b'string-val'
decoded_value = u'string-val'
self.assertEqual(_from_bytes(value), decoded_value)
def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _from_bytes, value)
def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _from_bytes, value)
class Test__urlsafe_b64encode(unittest.TestCase):
DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
def test_valid_input_bytes(self):
test_string = b'deadbeef'
result = _urlsafe_b64encode(test_string)
self.assertEqual(result, self.DEADBEEF_ENCODED)
def test_valid_input_bytes(self):
test_string = b'deadbeef'
result = _urlsafe_b64encode(test_string)
self.assertEqual(result, self.DEADBEEF_ENCODED)
def test_valid_input_unicode(self):
test_string = u'deadbeef'
result = _urlsafe_b64encode(test_string)
self.assertEqual(result, self.DEADBEEF_ENCODED)
def test_valid_input_unicode(self):
test_string = u'deadbeef'
result = _urlsafe_b64encode(test_string)
self.assertEqual(result, self.DEADBEEF_ENCODED)
class Test__urlsafe_b64decode(unittest.TestCase):
def test_valid_input_bytes(self):
test_string = b'ZGVhZGJlZWY'
result = _urlsafe_b64decode(test_string)
self.assertEqual(result, b'deadbeef')
def test_valid_input_bytes(self):
test_string = b'ZGVhZGJlZWY'
result = _urlsafe_b64decode(test_string)
self.assertEqual(result, b'deadbeef')
def test_valid_input_unicode(self):
test_string = b'ZGVhZGJlZWY'
result = _urlsafe_b64decode(test_string)
self.assertEqual(result, b'deadbeef')
def test_valid_input_unicode(self):
test_string = b'ZGVhZGJlZWY'
result = _urlsafe_b64decode(test_string)
self.assertEqual(result, b'deadbeef')
def test_bad_input(self):
import binascii
bad_string = b'+'
self.assertRaises((TypeError, binascii.Error),
def test_bad_input(self):
import binascii
bad_string = b'+'
self.assertRaises((TypeError, binascii.Error),
_urlsafe_b64decode, bad_string)

View File

@@ -22,42 +22,42 @@ from oauth2client.crypt import PyCryptoVerifier
class TestPyCryptoVerifier(unittest.TestCase):
PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
PUBLIC_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
'data', 'publickey.pem')
PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
PRIVATE_KEY_FILENAME = os.path.join(os.path.dirname(__file__),
'data', 'privatekey.pem')
def _load_public_key_bytes(self):
with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
return fh.read()
def _load_public_key_bytes(self):
with open(self.PUBLIC_KEY_FILENAME, 'rb') as fh:
return fh.read()
def _load_private_key_bytes(self):
with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
return fh.read()
def _load_private_key_bytes(self):
with open(self.PRIVATE_KEY_FILENAME, 'rb') as fh:
return fh.read()
def test_verify_success(self):
to_sign = b'foo'
signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
actual_signature = signer.sign(to_sign)
def test_verify_success(self):
to_sign = b'foo'
signer = PyCryptoSigner.from_string(self._load_private_key_bytes())
actual_signature = signer.sign(to_sign)
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True)
self.assertTrue(verifier.verify(to_sign, actual_signature))
self.assertTrue(verifier.verify(to_sign, actual_signature))
def test_verify_failure(self):
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
def test_verify_failure(self):
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True)
bad_signature = b''
self.assertFalse(verifier.verify(b'foo', bad_signature))
bad_signature = b''
self.assertFalse(verifier.verify(b'foo', bad_signature))
def test_verify_bad_key(self):
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
def test_verify_bad_key(self):
verifier = PyCryptoVerifier.from_string(self._load_public_key_bytes(),
is_x509_cert=True)
bad_signature = b''
self.assertFalse(verifier.verify(b'foo', bad_signature))
bad_signature = b''
self.assertFalse(verifier.verify(b'foo', bad_signature))
def test_from_string_unicode_key(self):
public_key = self._load_public_key_bytes()
public_key = public_key.decode('utf-8')
verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
self.assertTrue(isinstance(verifier, PyCryptoVerifier))
def test_from_string_unicode_key(self):
public_key = self._load_public_key_bytes()
public_key = public_key.decode('utf-8')
verifier = PyCryptoVerifier.from_string(public_key, is_x509_cert=True)
self.assertTrue(isinstance(verifier, PyCryptoVerifier))

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,6 @@
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
import os
import unittest
from io import StringIO
@@ -25,22 +24,22 @@ import httplib2
from oauth2client import clientsecrets
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
VALID_FILE = os.path.join(DATA_DIR, 'client_secrets.json')
INVALID_FILE = os.path.join(DATA_DIR, 'unfilled_client_secrets.json')
NONEXISTENT_FILE = os.path.join(__file__, '..', 'afilethatisntthere.json')
class OAuth2CredentialsTests(unittest.TestCase):
def setUp(self):
pass
def setUp(self):
pass
def tearDown(self):
pass
def tearDown(self):
pass
def test_validate_error(self):
ERRORS = [
def test_validate_error(self):
ERRORS = [
('{}', 'Invalid'),
('{"foo": {}}', 'Unknown'),
('{"web": {}}', 'Missing'),
@@ -56,95 +55,95 @@ class OAuth2CredentialsTests(unittest.TestCase):
}
""", 'Property'),
]
for src, match in ERRORS:
# Ensure that it is unicode
try:
src = src.decode('utf-8')
except AttributeError:
pass
# Test load(s)
try:
clientsecrets.loads(src)
self.fail(src + ' should not be a valid client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith(match))
for src, match in ERRORS:
# Ensure that it is unicode
try:
src = src.decode('utf-8')
except AttributeError:
pass
# Test load(s)
try:
clientsecrets.loads(src)
self.fail(src + ' should not be a valid client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith(match))
# Test loads(fp)
try:
fp = StringIO(src)
clientsecrets.load(fp)
self.fail(src + ' should not be a valid client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith(match))
# Test loads(fp)
try:
fp = StringIO(src)
clientsecrets.load(fp)
self.fail(src + ' should not be a valid client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith(match))
def test_load_by_filename(self):
try:
clientsecrets._loadfile(NONEXISTENT_FILE)
self.fail('should fail to load a missing client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith('File'))
def test_load_by_filename(self):
try:
clientsecrets._loadfile(NONEXISTENT_FILE)
self.fail('should fail to load a missing client_secrets file.')
except clientsecrets.InvalidClientSecretsError as e:
self.assertTrue(str(e).startswith('File'))
class CachedClientsecretsTests(unittest.TestCase):
class CacheMock(object):
def __init__(self):
self.cache = {}
self.last_get_ns = None
self.last_set_ns = None
class CacheMock(object):
def __init__(self):
self.cache = {}
self.last_get_ns = None
self.last_set_ns = None
def get(self, key, namespace=''):
# ignoring namespace for easier testing
self.last_get_ns = namespace
return self.cache.get(key, None)
def get(self, key, namespace=''):
# ignoring namespace for easier testing
self.last_get_ns = namespace
return self.cache.get(key, None)
def set(self, key, value, namespace=''):
# ignoring namespace for easier testing
self.last_set_ns = namespace
self.cache[key] = value
def set(self, key, value, namespace=''):
# ignoring namespace for easier testing
self.last_set_ns = namespace
self.cache[key] = value
def setUp(self):
self.cache_mock = self.CacheMock()
def setUp(self):
self.cache_mock = self.CacheMock()
def test_cache_miss(self):
client_type, client_info = clientsecrets.loadfile(
def test_cache_miss(self):
client_type, client_info = clientsecrets.loadfile(
VALID_FILE, cache=self.cache_mock)
self.assertEqual('web', client_type)
self.assertEqual('foo_client_secret', client_info['client_secret'])
self.assertEqual('web', client_type)
self.assertEqual('foo_client_secret', client_info['client_secret'])
cached = self.cache_mock.cache[VALID_FILE]
self.assertEqual({client_type: client_info}, cached)
cached = self.cache_mock.cache[VALID_FILE]
self.assertEqual({client_type: client_info}, cached)
# make sure we're using non-empty namespace
ns = self.cache_mock.last_set_ns
self.assertTrue(bool(ns))
# make sure they're equal
self.assertEqual(ns, self.cache_mock.last_get_ns)
# make sure we're using non-empty namespace
ns = self.cache_mock.last_set_ns
self.assertTrue(bool(ns))
# make sure they're equal
self.assertEqual(ns, self.cache_mock.last_get_ns)
def test_cache_hit(self):
self.cache_mock.cache[NONEXISTENT_FILE] = { 'web': 'secret info' }
def test_cache_hit(self):
self.cache_mock.cache[NONEXISTENT_FILE] = {'web': 'secret info'}
client_type, client_info = clientsecrets.loadfile(
client_type, client_info = clientsecrets.loadfile(
NONEXISTENT_FILE, cache=self.cache_mock)
self.assertEqual('web', client_type)
self.assertEqual('secret info', client_info)
# make sure we didn't do any set() RPCs
self.assertEqual(None, self.cache_mock.last_set_ns)
self.assertEqual('web', client_type)
self.assertEqual('secret info', client_info)
# make sure we didn't do any set() RPCs
self.assertEqual(None, self.cache_mock.last_set_ns)
def test_validation(self):
try:
clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
self.fail('Expected InvalidClientSecretsError to be raised '
def test_validation(self):
try:
clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock)
self.fail('Expected InvalidClientSecretsError to be raised '
'while loading %s' % INVALID_FILE)
except clientsecrets.InvalidClientSecretsError:
pass
except clientsecrets.InvalidClientSecretsError:
pass
def test_without_cache(self):
# this also ensures loadfile() is backward compatible
client_type, client_info = clientsecrets.loadfile(VALID_FILE)
self.assertEqual('web', client_type)
self.assertEqual('foo_client_secret', client_info['client_secret'])
def test_without_cache(self):
# this also ensures loadfile() is backward compatible
client_type, client_info = clientsecrets.loadfile(VALID_FILE)
self.assertEqual('web', client_type)
self.assertEqual('foo_client_secret', client_info['client_secret'])
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -18,10 +18,10 @@ import sys
import unittest
try:
reload
reload
except NameError:
# For Python3 (though importlib should be used, silly 3.3).
from imp import reload
# For Python3 (though importlib should be used, silly 3.3).
from imp import reload
from oauth2client import _helpers
from oauth2client.client import HAS_OPENSSL
@@ -30,44 +30,44 @@ from oauth2client import crypt
def datafile(filename):
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read()
f.close()
return data
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read()
f.close()
return data
class Test_pkcs12_key_as_pem(unittest.TestCase):
def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
private_key=None):
private_key = private_key or datafile(private_key_file)
return SignedJwtAssertionCredentials(
private_key = private_key or datafile(private_key_file)
return SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
def _succeeds_helper(self, password=None):
self.assertEqual(True, HAS_OPENSSL)
def _succeeds_helper(self, password=None):
self.assertEqual(True, HAS_OPENSSL)
credentials = self._make_signed_jwt_creds()
if password is None:
password = credentials.private_key_password
pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
credentials = self._make_signed_jwt_creds()
if password is None:
password = credentials.private_key_password
pem_contents = crypt.pkcs12_key_as_pem(credentials.private_key, password)
pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem')
pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem)
alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem])
def test_succeeds(self):
self._succeeds_helper()
def test_succeeds(self):
self._succeeds_helper()
def test_succeeds_with_unicode_password(self):
password = u'notasecret'
self._succeeds_helper(password)
def test_succeeds_with_unicode_password(self):
password = u'notasecret'
self._succeeds_helper(password)
def test_with_nonsense_key(self):
from OpenSSL import crypto
credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
def test_with_nonsense_key(self):
from OpenSSL import crypto
credentials = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem,
credentials.private_key, credentials.private_key_password)

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for oauth2client.devshell."""
import os
@@ -30,110 +29,110 @@ from oauth2client.devshell import NoDevshellServer
class _AuthReferenceServer(threading.Thread):
def __init__(self, response=None):
super(_AuthReferenceServer, self).__init__(None)
self.response = (response or
def __init__(self, response=None):
super(_AuthReferenceServer, self).__init__(None)
self.response = (response or
'["joe@example.com", "fooproj", "sometoken"]')
def __enter__(self):
self.start_server()
def __enter__(self):
self.start_server()
def start_server(self):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.bind(('localhost', 0))
port = self._socket.getsockname()[1]
os.environ[DEVSHELL_ENV] = str(port)
self._socket.listen(0)
self.daemon = True
self.start()
return self
def start_server(self):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.bind(('localhost', 0))
port = self._socket.getsockname()[1]
os.environ[DEVSHELL_ENV] = str(port)
self._socket.listen(0)
self.daemon = True
self.start()
return self
def __exit__(self, e_type, value, traceback):
self.stop_server()
def __exit__(self, e_type, value, traceback):
self.stop_server()
def stop_server(self):
del os.environ[DEVSHELL_ENV]
self._socket.close()
def stop_server(self):
del os.environ[DEVSHELL_ENV]
self._socket.close()
def run(self):
s = None
try:
# Do not set the timeout on the socket, leave it in the blocking mode as
# setting the timeout seems to cause spurious EAGAIN errors on OSX.
self._socket.settimeout(None)
def run(self):
s = None
try:
# Do not set the timeout on the socket, leave it in the blocking mode as
# setting the timeout seems to cause spurious EAGAIN errors on OSX.
self._socket.settimeout(None)
s, unused_addr = self._socket.accept()
resp_buffer = ''
resp_1 = s.recv(6).decode()
if '\n' not in resp_1:
raise Exception('invalid request data')
nstr, extra = resp_1.split('\n', 1)
resp_buffer = extra
n = int(nstr)
to_read = n-len(extra)
if to_read > 0:
resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
raise Exception('bad request')
l = len(self.response)
s.sendall(('%d\n%s' % (l, self.response)).encode())
finally:
if s:
s.close()
s, unused_addr = self._socket.accept()
resp_buffer = ''
resp_1 = s.recv(6).decode()
if '\n' not in resp_1:
raise Exception('invalid request data')
nstr, extra = resp_1.split('\n', 1)
resp_buffer = extra
n = int(nstr)
to_read = n - len(extra)
if to_read > 0:
resp_buffer += s.recv(to_read, socket.MSG_WAITALL)
if resp_buffer != CREDENTIAL_INFO_REQUEST_JSON:
raise Exception('bad request')
l = len(self.response)
s.sendall(('%d\n%s' % (l, self.response)).encode())
finally:
if s:
s.close()
class DevshellCredentialsTests(unittest.TestCase):
def test_signals_no_server(self):
self.assertRaises(NoDevshellServer, DevshellCredentials)
def test_signals_no_server(self):
self.assertRaises(NoDevshellServer, DevshellCredentials)
def test_request_response(self):
with _AuthReferenceServer():
response = _SendRecv()
self.assertEqual(response.user_email, 'joe@example.com')
self.assertEqual(response.project_id, 'fooproj')
self.assertEqual(response.access_token, 'sometoken')
def test_request_response(self):
with _AuthReferenceServer():
response = _SendRecv()
self.assertEqual(response.user_email, 'joe@example.com')
self.assertEqual(response.project_id, 'fooproj')
self.assertEqual(response.access_token, 'sometoken')
def test_no_refresh_token(self):
with _AuthReferenceServer():
creds = DevshellCredentials()
self.assertEquals(None, creds.refresh_token)
def test_no_refresh_token(self):
with _AuthReferenceServer():
creds = DevshellCredentials()
self.assertEquals(None, creds.refresh_token)
def test_reads_credentials(self):
with _AuthReferenceServer():
creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual('fooproj', creds.project_id)
self.assertEqual('sometoken', creds.access_token)
def test_reads_credentials(self):
with _AuthReferenceServer():
creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual('fooproj', creds.project_id)
self.assertEqual('sometoken', creds.access_token)
def test_handles_skipped_fields(self):
with _AuthReferenceServer('["joe@example.com"]'):
creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual(None, creds.project_id)
self.assertEqual(None, creds.access_token)
def test_handles_skipped_fields(self):
with _AuthReferenceServer('["joe@example.com"]'):
creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual(None, creds.project_id)
self.assertEqual(None, creds.access_token)
def test_handles_tiny_response(self):
with _AuthReferenceServer('[]'):
creds = DevshellCredentials()
self.assertEqual(None, creds.user_email)
self.assertEqual(None, creds.project_id)
self.assertEqual(None, creds.access_token)
def test_handles_tiny_response(self):
with _AuthReferenceServer('[]'):
creds = DevshellCredentials()
self.assertEqual(None, creds.user_email)
self.assertEqual(None, creds.project_id)
self.assertEqual(None, creds.access_token)
def test_handles_ignores_extra_fields(self):
with _AuthReferenceServer(
def test_handles_ignores_extra_fields(self):
with _AuthReferenceServer(
'["joe@example.com", "fooproj", "sometoken", "extra"]'):
creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual('fooproj', creds.project_id)
self.assertEqual('sometoken', creds.access_token)
creds = DevshellCredentials()
self.assertEqual('joe@example.com', creds.user_email)
self.assertEqual('fooproj', creds.project_id)
self.assertEqual('sometoken', creds.access_token)
def test_refuses_to_save_to_well_known_file(self):
ORIGINAL_ISDIR = os.path.isdir
try:
os.path.isdir = lambda path: True
with _AuthReferenceServer():
creds = DevshellCredentials()
self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
finally:
os.path.isdir = ORIGINAL_ISDIR
def test_refuses_to_save_to_well_known_file(self):
ORIGINAL_ISDIR = os.path.isdir
try:
os.path.isdir = lambda path: True
with _AuthReferenceServer():
creds = DevshellCredentials()
self.assertRaises(NotImplementedError, save_to_well_known_file, creds)
finally:
os.path.isdir = ORIGINAL_ISDIR

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discovery document tests
Unit tests for objects created from discovery documents.
@@ -31,10 +30,10 @@ import unittest
# Ensure that if app engine is available, we use the correct django from it
try:
from google.appengine.dist import use_library
use_library('django', '1.5')
from google.appengine.dist import use_library
use_library('django', '1.5')
except ImportError:
pass
pass
from oauth2client.client import Credentials
from oauth2client.client import Flow
@@ -51,39 +50,39 @@ from oauth2client.django_orm import FlowField
class TestCredentialsField(unittest.TestCase):
def setUp(self):
self.field = CredentialsField()
self.credentials = Credentials()
self.pickle = base64.b64encode(pickle.dumps(self.credentials))
def setUp(self):
self.field = CredentialsField()
self.credentials = Credentials()
self.pickle = base64.b64encode(pickle.dumps(self.credentials))
def test_field_is_text(self):
self.assertEquals(self.field.get_internal_type(), 'TextField')
def test_field_is_text(self):
self.assertEquals(self.field.get_internal_type(), 'TextField')
def test_field_unpickled(self):
self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
def test_field_unpickled(self):
self.assertTrue(isinstance(self.field.to_python(self.pickle), Credentials))
def test_field_pickled(self):
prep_value = self.field.get_db_prep_value(self.credentials,
def test_field_pickled(self):
prep_value = self.field.get_db_prep_value(self.credentials,
connection=None)
self.assertEqual(prep_value, self.pickle)
self.assertEqual(prep_value, self.pickle)
class TestFlowField(unittest.TestCase):
def setUp(self):
self.field = FlowField()
self.flow = Flow()
self.pickle = base64.b64encode(pickle.dumps(self.flow))
def setUp(self):
self.field = FlowField()
self.flow = Flow()
self.pickle = base64.b64encode(pickle.dumps(self.flow))
def test_field_is_text(self):
self.assertEquals(self.field.get_internal_type(), 'TextField')
def test_field_is_text(self):
self.assertEquals(self.field.get_internal_type(), 'TextField')
def test_field_unpickled(self):
self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
def test_field_unpickled(self):
self.assertTrue(isinstance(self.field.to_python(self.pickle), Flow))
def test_field_pickled(self):
prep_value = self.field.get_db_prep_value(self.flow, connection=None)
self.assertEqual(prep_value, self.pickle)
def test_field_pickled(self):
prep_value = self.field.get_db_prep_value(self.flow, connection=None)
self.assertEqual(prep_value, self.pickle)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Oauth2client.file tests
Unit tests for oauth2client.file
@@ -42,363 +41,362 @@ from oauth2client.client import AccessTokenCredentials
from oauth2client.client import OAuth2Credentials
from six.moves import http_client
try:
# Python2
from future_builtins import oct
# Python2
from future_builtins import oct
except:
pass
pass
FILENAME = tempfile.mktemp('oauth2client_test.data')
class OAuth2ClientFileTests(unittest.TestCase):
def tearDown(self):
try:
os.unlink(FILENAME)
except OSError:
pass
def tearDown(self):
try:
os.unlink(FILENAME)
except OSError:
pass
def setUp(self):
try:
os.unlink(FILENAME)
except OSError:
pass
def setUp(self):
try:
os.unlink(FILENAME)
except OSError:
pass
def create_test_credentials(self, client_id='some_client_id',
def create_test_credentials(self, client_id='some_client_id',
expiration=None):
access_token = 'foo'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = expiration or datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0'
access_token = 'foo'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = expiration or datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0'
credentials = OAuth2Credentials(
credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent)
return credentials
return credentials
def test_non_existent_file_storage(self):
s = file.Storage(FILENAME)
credentials = s.get()
self.assertEquals(None, credentials)
def test_non_existent_file_storage(self):
s = file.Storage(FILENAME)
credentials = s.get()
self.assertEquals(None, credentials)
def test_no_sym_link_credentials(self):
if hasattr(os, 'symlink'):
SYMFILENAME = FILENAME + '.sym'
os.symlink(FILENAME, SYMFILENAME)
s = file.Storage(SYMFILENAME)
try:
s.get()
self.fail('Should have raised an exception.')
except file.CredentialsFileSymbolicLinkError:
pass
finally:
os.unlink(SYMFILENAME)
def test_no_sym_link_credentials(self):
if hasattr(os, 'symlink'):
SYMFILENAME = FILENAME + '.sym'
os.symlink(FILENAME, SYMFILENAME)
s = file.Storage(SYMFILENAME)
try:
s.get()
self.fail('Should have raised an exception.')
except file.CredentialsFileSymbolicLinkError:
pass
finally:
os.unlink(SYMFILENAME)
def test_pickle_and_json_interop(self):
# Write a file with a pickled OAuth2Credentials.
credentials = self.create_test_credentials()
def test_pickle_and_json_interop(self):
# Write a file with a pickled OAuth2Credentials.
credentials = self.create_test_credentials()
f = open(FILENAME, 'wb')
pickle.dump(credentials, f)
f.close()
f = open(FILENAME, 'wb')
pickle.dump(credentials, f)
f.close()
# Storage should be not be able to read that object, as the capability to
# read and write credentials as pickled objects has been removed.
s = file.Storage(FILENAME)
read_credentials = s.get()
self.assertEquals(None, read_credentials)
# Storage should be not be able to read that object, as the capability to
# read and write credentials as pickled objects has been removed.
s = file.Storage(FILENAME)
read_credentials = s.get()
self.assertEquals(None, read_credentials)
# Now write it back out and confirm it has been rewritten as JSON
s.put(credentials)
with open(FILENAME) as f:
data = json.load(f)
# Now write it back out and confirm it has been rewritten as JSON
s.put(credentials)
with open(FILENAME) as f:
data = json.load(f)
self.assertEquals(data['access_token'], 'foo')
self.assertEquals(data['_class'], 'OAuth2Credentials')
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
self.assertEquals(data['access_token'], 'foo')
self.assertEquals(data['_class'], 'OAuth2Credentials')
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
def test_token_refresh_store_expired(self):
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
def test_token_refresh_store_expired(self):
expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
])
])
credentials._refresh(http.request)
self.assertEquals(credentials.access_token, access_token)
credentials._refresh(http.request)
self.assertEquals(credentials.access_token, access_token)
def test_token_refresh_store_expires_soon(self):
# Tests the case where an access token that is valid when it is read from
# the store expires before the original request succeeds.
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
def test_token_refresh_store_expires_soon(self):
# Tests the case where an access token that is valid when it is read from
# the store expires before the original request succeeds.
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)},
b'Valid response to original request')
])
])
credentials.authorize(http)
http.request('https://example.com')
self.assertEqual(credentials.access_token, access_token)
credentials.authorize(http)
http.request('https://example.com')
self.assertEqual(credentials.access_token, access_token)
def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar')
credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar')
def test_token_refresh_stream_body(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
def test_token_refresh_stream_body(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
http = HttpMockSequence([
valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, 'echo_request_body')
])
])
body = six.StringIO('streaming body')
body = six.StringIO('streaming body')
credentials.authorize(http)
_, content = http.request('https://example.com', body=body)
self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token)
credentials.authorize(http)
_, content = http.request('https://example.com', body=body)
self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token)
def test_credentials_delete(self):
credentials = self.create_test_credentials()
def test_credentials_delete(self):
credentials = self.create_test_credentials()
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
self.assertNotEquals(None, credentials)
s.delete()
credentials = s.get()
self.assertEquals(None, credentials)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
self.assertNotEquals(None, credentials)
s.delete()
credentials = s.get()
self.assertEquals(None, credentials)
def test_access_token_credentials(self):
access_token = 'foo'
user_agent = 'refresh_checker/1.0'
def test_access_token_credentials(self):
access_token = 'foo'
user_agent = 'refresh_checker/1.0'
credentials = AccessTokenCredentials(access_token, user_agent)
credentials = AccessTokenCredentials(access_token, user_agent)
s = file.Storage(FILENAME)
credentials = s.put(credentials)
credentials = s.get()
s = file.Storage(FILENAME)
credentials = s.put(credentials)
credentials = s.get()
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
mode = os.stat(FILENAME).st_mode
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
mode = os.stat(FILENAME).st_mode
if os.name == 'posix':
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
if os.name == 'posix':
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
def test_read_only_file_fail_lock(self):
credentials = self.create_test_credentials()
def test_read_only_file_fail_lock(self):
credentials = self.create_test_credentials()
open(FILENAME, 'a+b').close()
os.chmod(FILENAME, 0o400)
open(FILENAME, 'a+b').close()
os.chmod(FILENAME, 0o400)
store = multistore_file.get_credential_storage(
store = multistore_file.get_credential_storage(
FILENAME,
credentials.client_id,
credentials.user_agent,
['some-scope', 'some-other-scope'])
store.put(credentials)
if os.name == 'posix':
self.assertTrue(store._multistore._read_only)
os.chmod(FILENAME, 0o600)
store.put(credentials)
if os.name == 'posix':
self.assertTrue(store._multistore._read_only)
os.chmod(FILENAME, 0o600)
def test_multistore_no_symbolic_link_files(self):
if hasattr(os, 'symlink'):
SYMFILENAME = FILENAME + 'sym'
os.symlink(FILENAME, SYMFILENAME)
store = multistore_file.get_credential_storage(
def test_multistore_no_symbolic_link_files(self):
if hasattr(os, 'symlink'):
SYMFILENAME = FILENAME + 'sym'
os.symlink(FILENAME, SYMFILENAME)
store = multistore_file.get_credential_storage(
SYMFILENAME,
'some_client_id',
'user-agent/1.0',
['some-scope', 'some-other-scope'])
try:
store.get()
self.fail('Should have raised an exception.')
except locked_file.CredentialsFileSymbolicLinkError:
pass
finally:
os.unlink(SYMFILENAME)
try:
store.get()
self.fail('Should have raised an exception.')
except locked_file.CredentialsFileSymbolicLinkError:
pass
finally:
os.unlink(SYMFILENAME)
def test_multistore_non_existent_file(self):
store = multistore_file.get_credential_storage(
def test_multistore_non_existent_file(self):
store = multistore_file.get_credential_storage(
FILENAME,
'some_client_id',
'user-agent/1.0',
['some-scope', 'some-other-scope'])
credentials = store.get()
self.assertEquals(None, credentials)
credentials = store.get()
self.assertEquals(None, credentials)
def test_multistore_file(self):
credentials = self.create_test_credentials()
def test_multistore_file(self):
credentials = self.create_test_credentials()
store = multistore_file.get_credential_storage(
store = multistore_file.get_credential_storage(
FILENAME,
credentials.client_id,
credentials.user_agent,
['some-scope', 'some-other-scope'])
store.put(credentials)
credentials = store.get()
store.put(credentials)
credentials = store.get()
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
store.delete()
credentials = store.get()
store.delete()
credentials = store.get()
self.assertEquals(None, credentials)
self.assertEquals(None, credentials)
if os.name == 'posix':
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
if os.name == 'posix':
self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
def test_multistore_file_custom_key(self):
credentials = self.create_test_credentials()
def test_multistore_file_custom_key(self):
credentials = self.create_test_credentials()
custom_key = {'myapp': 'testing', 'clientid': 'some client'}
store = multistore_file.get_credential_storage_custom_key(
custom_key = {'myapp': 'testing', 'clientid': 'some client'}
store = multistore_file.get_credential_storage_custom_key(
FILENAME, custom_key)
store.put(credentials)
stored_credentials = store.get()
store.put(credentials)
stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
store.delete()
stored_credentials = store.get()
store.delete()
stored_credentials = store.get()
self.assertEquals(None, stored_credentials)
self.assertEquals(None, stored_credentials)
def test_multistore_file_custom_string_key(self):
credentials = self.create_test_credentials()
def test_multistore_file_custom_string_key(self):
credentials = self.create_test_credentials()
# store with string key
store = multistore_file.get_credential_storage_custom_string_key(
# store with string key
store = multistore_file.get_credential_storage_custom_string_key(
FILENAME, 'mykey')
store.put(credentials)
stored_credentials = store.get()
store.put(credentials)
stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
# try retrieving with a dictionary
store_dict = multistore_file.get_credential_storage_custom_string_key(
# try retrieving with a dictionary
store_dict = multistore_file.get_credential_storage_custom_string_key(
FILENAME, {'key': 'mykey'})
stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
stored_credentials = store.get()
self.assertNotEquals(None, stored_credentials)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
store.delete()
stored_credentials = store.get()
store.delete()
stored_credentials = store.get()
self.assertEquals(None, stored_credentials)
self.assertEquals(None, stored_credentials)
def test_multistore_file_backwards_compatibility(self):
credentials = self.create_test_credentials()
scopes = ['scope1', 'scope2']
def test_multistore_file_backwards_compatibility(self):
credentials = self.create_test_credentials()
scopes = ['scope1', 'scope2']
# store the credentials using the legacy key method
store = multistore_file.get_credential_storage(
# store the credentials using the legacy key method
store = multistore_file.get_credential_storage(
FILENAME, 'client_id', 'user_agent', scopes)
store.put(credentials)
store.put(credentials)
# retrieve the credentials using a custom key that matches the legacy key
key = {'clientId': 'client_id', 'userAgent': 'user_agent',
# retrieve the credentials using a custom key that matches the legacy key
key = {'clientId': 'client_id', 'userAgent': 'user_agent',
'scope': util.scopes_to_string(scopes)}
store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
stored_credentials = store.get()
store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
stored_credentials = store.get()
self.assertEqual(credentials.access_token, stored_credentials.access_token)
self.assertEqual(credentials.access_token, stored_credentials.access_token)
def test_multistore_file_get_all_keys(self):
# start with no keys
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([], keys)
def test_multistore_file_get_all_keys(self):
# start with no keys
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([], keys)
# store credentials
credentials = self.create_test_credentials(client_id='client1')
custom_key = {'myapp': 'testing', 'clientid': 'client1'}
store1 = multistore_file.get_credential_storage_custom_key(
# store credentials
credentials = self.create_test_credentials(client_id='client1')
custom_key = {'myapp': 'testing', 'clientid': 'client1'}
store1 = multistore_file.get_credential_storage_custom_key(
FILENAME, custom_key)
store1.put(credentials)
store1.put(credentials)
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([custom_key], keys)
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([custom_key], keys)
# store more credentials
credentials = self.create_test_credentials(client_id='client2')
string_key = 'string_key'
store2 = multistore_file.get_credential_storage_custom_string_key(
# store more credentials
credentials = self.create_test_credentials(client_id='client2')
string_key = 'string_key'
store2 = multistore_file.get_credential_storage_custom_string_key(
FILENAME, string_key)
store2.put(credentials)
store2.put(credentials)
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals(2, len(keys))
self.assertTrue(custom_key in keys)
self.assertTrue({'key': string_key} in keys)
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals(2, len(keys))
self.assertTrue(custom_key in keys)
self.assertTrue({'key': string_key} in keys)
# back to no keys
store1.delete()
store2.delete()
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([], keys)
# back to no keys
store1.delete()
store2.delete()
keys = multistore_file.get_all_credential_keys(FILENAME)
self.assertEquals([], keys)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the Flask utilities"""
__author__ = 'jonwayne@google.com (Jon Wayne Parrott)'
@@ -35,6 +34,7 @@ from oauth2client.client import OAuth2Credentials
class Http2Mock(object):
"""Mock httplib2.Http for code exchange / refresh"""
def __init__(self, status=httplib.OK, **kwargs):
self.status = status
self.content = {

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for oauth2client.gce.
Unit tests for oauth2client.gce.
@@ -36,86 +35,86 @@ from oauth2client.gce import AppAssertionCredentials
class AssertionCredentialsTests(unittest.TestCase):
def _refresh_success_helper(self, bytes_response=False):
access_token = u'this-is-a-token'
return_val = json.dumps({u'accessToken': access_token})
if bytes_response:
return_val = _to_bytes(return_val)
http = mock.MagicMock()
http.request = mock.MagicMock(
def _refresh_success_helper(self, bytes_response=False):
access_token = u'this-is-a-token'
return_val = json.dumps({u'accessToken': access_token})
if bytes_response:
return_val = _to_bytes(return_val)
http = mock.MagicMock()
http.request = mock.MagicMock(
return_value=(mock.Mock(status=200), return_val))
scopes = ['http://example.com/a', 'http://example.com/b']
credentials = AppAssertionCredentials(scope=scopes)
self.assertEquals(None, credentials.access_token)
credentials.refresh(http)
self.assertEquals(access_token, credentials.access_token)
scopes = ['http://example.com/a', 'http://example.com/b']
credentials = AppAssertionCredentials(scope=scopes)
self.assertEquals(None, credentials.access_token)
credentials.refresh(http)
self.assertEquals(access_token, credentials.access_token)
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
'service-accounts/default/acquire')
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
http.request.assert_called_once_with(request_uri)
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
http.request.assert_called_once_with(request_uri)
def test_refresh_success(self):
self._refresh_success_helper(bytes_response=False)
def test_refresh_success(self):
self._refresh_success_helper(bytes_response=False)
def test_refresh_success_bytes(self):
self._refresh_success_helper(bytes_response=True)
def test_refresh_success_bytes(self):
self._refresh_success_helper(bytes_response=True)
def test_fail_refresh(self):
http = mock.MagicMock()
http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
def test_fail_refresh(self):
http = mock.MagicMock()
http.request = mock.MagicMock(return_value=(mock.Mock(status=400), '{}'))
c = AppAssertionCredentials(scope=['http://example.com/a',
c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b'])
self.assertRaises(AccessTokenRefreshError, c.refresh, http)
self.assertRaises(AccessTokenRefreshError, c.refresh, http)
def test_to_from_json(self):
c = AppAssertionCredentials(scope=['http://example.com/a',
def test_to_from_json(self):
c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b'])
json = c.to_json()
c2 = Credentials.new_from_json(json)
json = c.to_json()
c2 = Credentials.new_from_json(json)
self.assertEqual(c.access_token, c2.access_token)
self.assertEqual(c.access_token, c2.access_token)
def test_create_scoped_required_without_scopes(self):
credentials = AppAssertionCredentials([])
self.assertTrue(credentials.create_scoped_required())
def test_create_scoped_required_without_scopes(self):
credentials = AppAssertionCredentials([])
self.assertTrue(credentials.create_scoped_required())
def test_create_scoped_required_with_scopes(self):
credentials = AppAssertionCredentials(['dummy_scope'])
self.assertFalse(credentials.create_scoped_required())
def test_create_scoped_required_with_scopes(self):
credentials = AppAssertionCredentials(['dummy_scope'])
self.assertFalse(credentials.create_scoped_required())
def test_create_scoped(self):
credentials = AppAssertionCredentials([])
new_credentials = credentials.create_scoped(['dummy_scope'])
self.assertNotEqual(credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
self.assertEqual('dummy_scope', new_credentials.scope)
def test_create_scoped(self):
credentials = AppAssertionCredentials([])
new_credentials = credentials.create_scoped(['dummy_scope'])
self.assertNotEqual(credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, AppAssertionCredentials))
self.assertEqual('dummy_scope', new_credentials.scope)
def test_get_access_token(self):
http = mock.MagicMock()
http.request = mock.MagicMock(
def test_get_access_token(self):
http = mock.MagicMock()
http.request = mock.MagicMock(
return_value=(mock.Mock(status=200),
'{"accessToken": "this-is-a-token"}'))
credentials = AppAssertionCredentials(['dummy_scope'])
token = credentials.get_access_token(http=http)
self.assertEqual('this-is-a-token', token.access_token)
self.assertEqual(None, token.expires_in)
credentials = AppAssertionCredentials(['dummy_scope'])
token = credentials.get_access_token(http=http)
self.assertEqual('this-is-a-token', token.access_token)
self.assertEqual(None, token.expires_in)
http.request.assert_called_once_with(
http.request.assert_called_once_with(
'http://metadata.google.internal/0.1/meta-data/service-accounts/'
'default/acquire?scope=dummy_scope')
def test_save_to_well_known_file(self):
import os
ORIGINAL_ISDIR = os.path.isdir
try:
os.path.isdir = lambda path: True
credentials = AppAssertionCredentials([])
self.assertRaises(NotImplementedError, save_to_well_known_file,
def test_save_to_well_known_file(self):
import os
ORIGINAL_ISDIR = os.path.isdir
try:
os.path.isdir = lambda path: True
credentials = AppAssertionCredentials([])
self.assertRaises(NotImplementedError, save_to_well_known_file,
credentials)
finally:
os.path.isdir = ORIGINAL_ISDIR
finally:
os.path.isdir = ORIGINAL_ISDIR

View File

@@ -19,9 +19,9 @@ import unittest
class ImportTest(unittest.TestCase):
def test_tools_import(self):
import oauth2client.tools
def test_tools_import(self):
import oauth2client.tools
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Oauth2client tests
Unit tests for oauth2client.
@@ -42,295 +41,295 @@ from oauth2client.file import Storage
def datafile(filename):
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read()
f.close()
return data
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read()
f.close()
return data
class CryptTests(unittest.TestCase):
def setUp(self):
self.format = 'p12'
self.signer = crypt.OpenSSLSigner
self.verifier = crypt.OpenSSLVerifier
def setUp(self):
self.format = 'p12'
self.signer = crypt.OpenSSLSigner
self.verifier = crypt.OpenSSLVerifier
def test_sign_and_verify(self):
self._check_sign_and_verify('privatekey.%s' % self.format)
def test_sign_and_verify(self):
self._check_sign_and_verify('privatekey.%s' % self.format)
def test_sign_and_verify_from_converted_pkcs12(self):
# Tests that following instructions to convert from PKCS12 to PEM works.
if self.format == 'pem':
self._check_sign_and_verify('pem_from_pkcs12.pem')
def test_sign_and_verify_from_converted_pkcs12(self):
# Tests that following instructions to convert from PKCS12 to PEM works.
if self.format == 'pem':
self._check_sign_and_verify('pem_from_pkcs12.pem')
def _check_sign_and_verify(self, private_key_file):
private_key = datafile(private_key_file)
public_key = datafile('publickey.pem')
def _check_sign_and_verify(self, private_key_file):
private_key = datafile(private_key_file)
public_key = datafile('publickey.pem')
# We pass in a non-bytes password to make sure all branches
# are traversed in tests.
signer = self.signer.from_string(private_key,
# We pass in a non-bytes password to make sure all branches
# are traversed in tests.
signer = self.signer.from_string(private_key,
password=u'notasecret')
signature = signer.sign('foo')
signature = signer.sign('foo')
verifier = self.verifier.from_string(public_key, True)
self.assertTrue(verifier.verify(b'foo', signature))
verifier = self.verifier.from_string(public_key, True)
self.assertTrue(verifier.verify(b'foo', signature))
self.assertFalse(verifier.verify(b'bar', signature))
self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
self.assertFalse(verifier.verify(b'bar', signature))
self.assertFalse(verifier.verify(b'foo', b'bad signagure'))
self.assertFalse(verifier.verify(b'foo', u'bad signagure'))
def _check_jwt_failure(self, jwt, expected_error):
public_key = datafile('publickey.pem')
certs = {'foo': public_key}
audience = ('https://www.googleapis.com/auth/id?client_id='
def _check_jwt_failure(self, jwt, expected_error):
public_key = datafile('publickey.pem')
certs = {'foo': public_key}
audience = ('https://www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com')
try:
crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.fail()
except crypt.AppIdentityError as e:
self.assertTrue(expected_error in str(e))
try:
crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.fail()
except crypt.AppIdentityError as e:
self.assertTrue(expected_error in str(e))
def _create_signed_jwt(self):
private_key = datafile('privatekey.%s' % self.format)
signer = self.signer.from_string(private_key)
audience = 'some_audience_address@testing.gserviceaccount.com'
now = int(time.time())
def _create_signed_jwt(self):
private_key = datafile('privatekey.%s' % self.format)
signer = self.signer.from_string(private_key)
audience = 'some_audience_address@testing.gserviceaccount.com'
now = int(time.time())
return crypt.make_signed_jwt(signer, {
return crypt.make_signed_jwt(signer, {
'aud': audience,
'iat': now,
'exp': now + 300,
'user': 'billy bob',
'metadata': {'meta': 'data'},
})
})
def test_verify_id_token(self):
jwt = self._create_signed_jwt()
public_key = datafile('publickey.pem')
certs = {'foo': public_key}
audience = 'some_audience_address@testing.gserviceaccount.com'
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token(self):
jwt = self._create_signed_jwt()
public_key = datafile('publickey.pem')
certs = {'foo': public_key}
audience = 'some_audience_address@testing.gserviceaccount.com'
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri(self):
jwt = self._create_signed_jwt()
def test_verify_id_token_with_certs_uri(self):
jwt = self._create_signed_jwt()
http = HttpMockSequence([
http = HttpMockSequence([
({'status': '200'}, datafile('certs.json')),
])
])
contents = verify_id_token(
contents = verify_id_token(
jwt, 'some_audience_address@testing.gserviceaccount.com', http=http)
self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta'])
self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri_fails(self):
jwt = self._create_signed_jwt()
def test_verify_id_token_with_certs_uri_fails(self):
jwt = self._create_signed_jwt()
http = HttpMockSequence([
http = HttpMockSequence([
({'status': '404'}, datafile('certs.json')),
])
])
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
'some_audience_address@testing.gserviceaccount.com',
http=http)
def test_verify_id_token_bad_tokens(self):
private_key = datafile('privatekey.%s' % self.format)
def test_verify_id_token_bad_tokens(self):
private_key = datafile('privatekey.%s' % self.format)
# Wrong number of segments
self._check_jwt_failure('foo', 'Wrong number of segments')
# Wrong number of segments
self._check_jwt_failure('foo', 'Wrong number of segments')
# Not json
self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
# Not json
self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
# Bad signature
jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
self._check_jwt_failure(jwt, 'Invalid token signature')
# Bad signature
jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz'])
self._check_jwt_failure(jwt, 'Invalid token signature')
# No expiration
signer = self.signer.from_string(private_key)
audience = ('https:#www.googleapis.com/auth/id?client_id='
# No expiration
signer = self.signer.from_string(private_key)
audience = ('https:#www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com')
jwt = crypt.make_signed_jwt(signer, {
jwt = crypt.make_signed_jwt(signer, {
'aud': audience,
'iat': time.time(),
})
self._check_jwt_failure(jwt, 'No exp field in token')
})
self._check_jwt_failure(jwt, 'No exp field in token')
# No issued at
jwt = crypt.make_signed_jwt(signer, {
# No issued at
jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience',
'exp': time.time() + 400,
})
self._check_jwt_failure(jwt, 'No iat field in token')
})
self._check_jwt_failure(jwt, 'No iat field in token')
# Too early
jwt = crypt.make_signed_jwt(signer, {
# Too early
jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience',
'iat': time.time() + 301,
'exp': time.time() + 400,
})
self._check_jwt_failure(jwt, 'Token used too early')
})
self._check_jwt_failure(jwt, 'Token used too early')
# Too late
jwt = crypt.make_signed_jwt(signer, {
# Too late
jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience',
'iat': time.time() - 500,
'exp': time.time() - 301,
})
self._check_jwt_failure(jwt, 'Token used too late')
})
self._check_jwt_failure(jwt, 'Token used too late')
# Wrong target
jwt = crypt.make_signed_jwt(signer, {
# Wrong target
jwt = crypt.make_signed_jwt(signer, {
'aud': 'somebody else',
'iat': time.time(),
'exp': time.time() + 300,
})
self._check_jwt_failure(jwt, 'Wrong recipient')
})
self._check_jwt_failure(jwt, 'Wrong recipient')
def test_from_string_non_509_cert(self):
# Use a private key instead of a certificate to test the other branch
# of from_string().
public_key = datafile('privatekey.pem')
verifier = self.verifier.from_string(public_key, is_x509_cert=False)
self.assertTrue(isinstance(verifier, self.verifier))
def test_from_string_non_509_cert(self):
# Use a private key instead of a certificate to test the other branch
# of from_string().
public_key = datafile('privatekey.pem')
verifier = self.verifier.from_string(public_key, is_x509_cert=False)
self.assertTrue(isinstance(verifier, self.verifier))
class PEMCryptTestsPyCrypto(CryptTests):
def setUp(self):
self.format = 'pem'
self.signer = crypt.PyCryptoSigner
self.verifier = crypt.PyCryptoVerifier
def setUp(self):
self.format = 'pem'
self.signer = crypt.PyCryptoSigner
self.verifier = crypt.PyCryptoVerifier
class PEMCryptTestsOpenSSL(CryptTests):
def setUp(self):
self.format = 'pem'
self.signer = crypt.OpenSSLSigner
self.verifier = crypt.OpenSSLVerifier
def setUp(self):
self.format = 'pem'
self.signer = crypt.OpenSSLSigner
self.verifier = crypt.OpenSSLVerifier
class SignedJwtAssertionCredentialsTests(unittest.TestCase):
def setUp(self):
self.format = 'p12'
crypt.Signer = crypt.OpenSSLSigner
def setUp(self):
self.format = 'p12'
crypt.Signer = crypt.OpenSSLSigner
def test_credentials_good(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
def test_credentials_good(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
http = HttpMockSequence([
http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'),
])
http = credentials.authorize(http)
resp, content = http.request('http://example.org')
self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
])
http = credentials.authorize(http)
resp, content = http.request('http://example.org')
self.assertEqual(b'Bearer 1/3w', content[b'Authorization'])
def test_credentials_to_from_json(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
def test_credentials_to_from_json(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
json = credentials.to_json()
restored = Credentials.new_from_json(json)
self.assertEqual(credentials.private_key, restored.private_key)
self.assertEqual(credentials.private_key_password,
json = credentials.to_json()
restored = Credentials.new_from_json(json)
self.assertEqual(credentials.private_key, restored.private_key)
self.assertEqual(credentials.private_key_password,
restored.private_key_password)
self.assertEqual(credentials.kwargs, restored.kwargs)
self.assertEqual(credentials.kwargs, restored.kwargs)
def _credentials_refresh(self, credentials):
http = HttpMockSequence([
def _credentials_refresh(self, credentials):
http = HttpMockSequence([
({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'),
({'status': '401'}, b''),
({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'),
])
http = credentials.authorize(http)
_, content = http.request('http://example.org')
return content
])
http = credentials.authorize(http)
_, content = http.request('http://example.org')
return content
def test_credentials_refresh_without_storage(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
def test_credentials_refresh_without_storage(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
content = self._credentials_refresh(credentials)
content = self._credentials_refresh(credentials)
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
def test_credentials_refresh_with_storage(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
def test_credentials_refresh_with_storage(self):
private_key = datafile('privatekey.%s' % self.format)
credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
(filehandle, filename) = tempfile.mkstemp()
os.close(filehandle)
store = Storage(filename)
store.put(credentials)
credentials.set_store(store)
(filehandle, filename) = tempfile.mkstemp()
os.close(filehandle)
store = Storage(filename)
store.put(credentials)
credentials.set_store(store)
content = self._credentials_refresh(credentials)
content = self._credentials_refresh(credentials)
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
os.unlink(filename)
self.assertEqual(b'Bearer 3/3w', content[b'Authorization'])
os.unlink(filename)
class PEMSignedJwtAssertionCredentialsOpenSSLTests(
SignedJwtAssertionCredentialsTests):
def setUp(self):
self.format = 'pem'
crypt.Signer = crypt.OpenSSLSigner
def setUp(self):
self.format = 'pem'
crypt.Signer = crypt.OpenSSLSigner
class PEMSignedJwtAssertionCredentialsPyCryptoTests(
SignedJwtAssertionCredentialsTests):
def setUp(self):
self.format = 'pem'
crypt.Signer = crypt.PyCryptoSigner
def setUp(self):
self.format = 'pem'
crypt.Signer = crypt.PyCryptoSigner
class PKCSSignedJwtAssertionCredentialsPyCryptoTests(unittest.TestCase):
def test_for_failure(self):
crypt.Signer = crypt.PyCryptoSigner
private_key = datafile('privatekey.p12')
credentials = SignedJwtAssertionCredentials(
def test_for_failure(self):
crypt.Signer = crypt.PyCryptoSigner
private_key = datafile('privatekey.p12')
credentials = SignedJwtAssertionCredentials(
'some_account@example.com',
private_key,
scope='read+write',
sub='joe@example.org')
try:
credentials._generate_assertion()
self.fail()
except NotImplementedError:
pass
try:
credentials._generate_assertion()
self.fail()
except NotImplementedError:
pass
class TestHasOpenSSLFlag(unittest.TestCase):
def test_true(self):
self.assertEqual(True, HAS_OPENSSL)
self.assertEqual(True, HAS_CRYPTO)
def test_true(self):
self.assertEqual(True, HAS_OPENSSL)
self.assertEqual(True, HAS_CRYPTO)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for oauth2client.keyring_storage tests.
Unit tests for oauth2client.keyring_storage.
@@ -33,59 +32,59 @@ from oauth2client.keyring_storage import Storage
class OAuth2ClientKeyringTests(unittest.TestCase):
def test_non_existent_credentials_storage(self):
with mock.patch.object(keyring, 'get_password',
def test_non_existent_credentials_storage(self):
with mock.patch.object(keyring, 'get_password',
return_value=None,
autospec=True) as get_password:
s = Storage('my_unit_test', 'me')
credentials = s.get()
self.assertEquals(None, credentials)
get_password.assert_called_once_with('my_unit_test', 'me')
s = Storage('my_unit_test', 'me')
credentials = s.get()
self.assertEquals(None, credentials)
get_password.assert_called_once_with('my_unit_test', 'me')
def test_malformed_credentials_in_storage(self):
with mock.patch.object(keyring, 'get_password',
def test_malformed_credentials_in_storage(self):
with mock.patch.object(keyring, 'get_password',
return_value='{',
autospec=True) as get_password:
s = Storage('my_unit_test', 'me')
credentials = s.get()
self.assertEquals(None, credentials)
get_password.assert_called_once_with('my_unit_test', 'me')
s = Storage('my_unit_test', 'me')
credentials = s.get()
self.assertEquals(None, credentials)
get_password.assert_called_once_with('my_unit_test', 'me')
def test_json_credentials_storage(self):
access_token = 'foo'
client_id = 'some_client_id'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow()
user_agent = 'refresh_checker/1.0'
def test_json_credentials_storage(self):
access_token = 'foo'
client_id = 'some_client_id'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow()
user_agent = 'refresh_checker/1.0'
credentials = OAuth2Credentials(
credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
user_agent)
# Setting autospec on a mock with an iterable side_effect is
# currently broken (http://bugs.python.org/issue17826), so instead
# we patch twice.
with mock.patch.object(keyring, 'get_password',
# Setting autospec on a mock with an iterable side_effect is
# currently broken (http://bugs.python.org/issue17826), so instead
# we patch twice.
with mock.patch.object(keyring, 'get_password',
return_value=None,
autospec=True) as get_password:
with mock.patch.object(keyring, 'set_password',
with mock.patch.object(keyring, 'set_password',
return_value=None,
autospec=True) as set_password:
s = Storage('my_unit_test', 'me')
self.assertEquals(None, s.get())
s = Storage('my_unit_test', 'me')
self.assertEquals(None, s.get())
s.put(credentials)
s.put(credentials)
set_password.assert_called_once_with(
set_password.assert_called_once_with(
'my_unit_test', 'me', credentials.to_json())
get_password.assert_called_once_with('my_unit_test', 'me')
get_password.assert_called_once_with('my_unit_test', 'me')
with mock.patch.object(keyring, 'get_password',
with mock.patch.object(keyring, 'get_password',
return_value=credentials.to_json(),
autospec=True) as get_password:
restored = s.get()
self.assertEqual('foo', restored.access_token)
self.assertEqual('some_client_id', restored.client_id)
get_password.assert_called_once_with('my_unit_test', 'me')
restored = s.get()
self.assertEqual('foo', restored.access_token)
self.assertEqual('some_client_id', restored.client_id)
get_password.assert_called_once_with('my_unit_test', 'me')

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Oauth2client tests.
Unit tests for service account credentials implemented using RSA.
@@ -31,94 +30,94 @@ from oauth2client.service_account import _ServiceAccountCredentials
def datafile(filename):
# TODO(orestica): Refactor this using pkgutil.get_data
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read()
f.close()
return data
# TODO(orestica): Refactor this using pkgutil.get_data
f = open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb')
data = f.read()
f.close()
return data
class ServiceAccountCredentialsTests(unittest.TestCase):
def setUp(self):
self.service_account_id = '123'
self.service_account_email = 'dummy@google.com'
self.private_key_id = 'ABCDEF'
self.private_key = datafile('pem_from_pkcs12.pem')
self.scopes = ['dummy_scope']
self.credentials = _ServiceAccountCredentials(self.service_account_id,
def setUp(self):
self.service_account_id = '123'
self.service_account_email = 'dummy@google.com'
self.private_key_id = 'ABCDEF'
self.private_key = datafile('pem_from_pkcs12.pem')
self.scopes = ['dummy_scope']
self.credentials = _ServiceAccountCredentials(self.service_account_id,
self.service_account_email,
self.private_key_id,
self.private_key,
[])
def test_sign_blob(self):
private_key_id, signature = self.credentials.sign_blob('Google')
self.assertEqual( self.private_key_id, private_key_id)
def test_sign_blob(self):
private_key_id, signature = self.credentials.sign_blob('Google')
self.assertEqual(self.private_key_id, private_key_id)
pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
datafile('publickey_openssl.pem'))
self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))
try:
rsa.pkcs1.verify(b'Orest', signature, pub_key)
self.fail('Verification should have failed!')
except rsa.pkcs1.VerificationError:
pass # Expected
try:
rsa.pkcs1.verify(b'Orest', signature, pub_key)
self.fail('Verification should have failed!')
except rsa.pkcs1.VerificationError:
pass # Expected
try:
rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
self.fail('Verification should have failed!')
except rsa.pkcs1.VerificationError:
pass # Expected
try:
rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)
self.fail('Verification should have failed!')
except rsa.pkcs1.VerificationError:
pass # Expected
def test_service_account_email(self):
self.assertEqual(self.service_account_email,
def test_service_account_email(self):
self.assertEqual(self.service_account_email,
self.credentials.service_account_email)
def test_create_scoped_required_without_scopes(self):
self.assertTrue(self.credentials.create_scoped_required())
def test_create_scoped_required_without_scopes(self):
self.assertTrue(self.credentials.create_scoped_required())
def test_create_scoped_required_with_scopes(self):
self.credentials = _ServiceAccountCredentials(self.service_account_id,
def test_create_scoped_required_with_scopes(self):
self.credentials = _ServiceAccountCredentials(self.service_account_id,
self.service_account_email,
self.private_key_id,
self.private_key,
self.scopes)
self.assertFalse(self.credentials.create_scoped_required())
self.assertFalse(self.credentials.create_scoped_required())
def test_create_scoped(self):
new_credentials = self.credentials.create_scoped(self.scopes)
self.assertNotEqual(self.credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
self.assertEqual('dummy_scope', new_credentials._scopes)
def test_create_scoped(self):
new_credentials = self.credentials.create_scoped(self.scopes)
self.assertNotEqual(self.credentials, new_credentials)
self.assertTrue(isinstance(new_credentials, _ServiceAccountCredentials))
self.assertEqual('dummy_scope', new_credentials._scopes)
def test_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token', 'expires_in': S}
http = HttpMockSequence([
def test_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token', 'expires_in': S}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode('utf-8')),
({'status': '200'}, json.dumps(token_response_second).encode('utf-8')),
])
])
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)
time.sleep(S + 0.5) # some margin to avoid flakiness
self.assertTrue(self.credentials.access_token_expired)
time.sleep(S + 0.5) # some margin to avoid flakiness
self.assertTrue(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response)
token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response)

View File

@@ -5,6 +5,7 @@ from oauth2client import tools
from six.moves.urllib import request
import threading
class TestClientRedirectServer(unittest.TestCase):
"""Test the ClientRedirectServer and ClientRedirectHandler classes."""
@@ -15,16 +16,15 @@ class TestClientRedirectServer(unittest.TestCase):
httpd = tools.ClientRedirectServer(('localhost', 0), tools.ClientRedirectHandler)
code = 'foo'
url = 'http://localhost:%i?code=%s' % (httpd.server_address[1], code)
t = threading.Thread(target = httpd.handle_request)
t = threading.Thread(target=httpd.handle_request)
t.setDaemon(True)
t.start()
f = request.urlopen( url )
f = request.urlopen(url)
self.assertTrue(f.read())
t.join()
httpd.server_close()
self.assertEqual(httpd.query_params.get('code'),code)
self.assertEqual(httpd.query_params.get('code'), code)
if __name__ == '__main__':
unittest.main()

View File

@@ -9,8 +9,8 @@ from oauth2client import util
class ScopeToStringTests(unittest.TestCase):
def test_iterables(self):
cases = [
def test_iterables(self):
cases = [
('', ''),
('', ()),
('', []),
@@ -22,36 +22,37 @@ class ScopeToStringTests(unittest.TestCase):
('a b', ('a', 'b')),
('a b', 'a b'),
('a b', (s for s in ['a', 'b'])),
]
for expected, case in cases:
self.assertEqual(expected, util.scopes_to_string(case))
]
for expected, case in cases:
self.assertEqual(expected, util.scopes_to_string(case))
class StringToScopeTests(unittest.TestCase):
def test_conversion(self):
cases = [
def test_conversion(self):
cases = [
(['a', 'b'], ['a', 'b']),
('', []),
('a', ['a']),
('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']),
]
]
for case, expected in cases:
self.assertEqual(expected, util.string_to_scopes(case))
for case, expected in cases:
self.assertEqual(expected, util.string_to_scopes(case))
class KeyConversionTests(unittest.TestCase):
def test_key_conversions(self):
d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
tuple_key = util.dict_to_tuple_key(d)
def test_key_conversions(self):
d = {'somekey': 'some value', 'another': 'something else', 'onemore': 'foo'}
tuple_key = util.dict_to_tuple_key(d)
# the resulting key should be naturally sorted
self.assertEqual(
# the resulting key should be naturally sorted
self.assertEqual(
(('another', 'something else'),
('onemore', 'foo'),
('somekey', 'some value')),
tuple_key)
# check we get the original dictionary back
self.assertEqual(d, dict(tuple_key))
# check we get the original dictionary back
self.assertEqual(d, dict(tuple_key))

View File

@@ -34,78 +34,78 @@ TEST_EXTRA_INFO_2 = 'more_extra_info'
class XsrfUtilTests(unittest.TestCase):
"""Test xsrfutil functions."""
"""Test xsrfutil functions."""
def testGenerateAndValidateToken(self):
"""Test generating and validating a token."""
token = xsrfutil.generate_token(TEST_KEY,
def testGenerateAndValidateToken(self):
"""Test generating and validating a token."""
token = xsrfutil.generate_token(TEST_KEY,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
when=TEST_TIME)
# Check that the token is considered valid when it should be.
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
# Check that the token is considered valid when it should be.
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=TEST_TIME))
# Should still be valid 15 minutes later.
later15mins = TEST_TIME + 15*60
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
# Should still be valid 15 minutes later.
later15mins = TEST_TIME + 15 * 60
self.assertTrue(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
# But not if beyond the timeout.
later2hours = TEST_TIME + 2*60*60
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
# But not if beyond the timeout.
later2hours = TEST_TIME + 2 * 60 * 60
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later2hours))
# Or if the key is different.
self.assertFalse(xsrfutil.validate_token('another key',
# Or if the key is different.
self.assertFalse(xsrfutil.validate_token('another key',
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
# Or the user ID....
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
# Or the user ID....
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_2,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
# Or the action ID...
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
# Or the action ID...
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_2,
current_time=later15mins))
# Invalid when truncated
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
# Invalid when truncated
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token[:-1],
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
# Invalid with extra garbage
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
# Invalid with extra garbage
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token + b'x',
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))
# Invalid with token of None
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
# Invalid with token of None
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
None,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1))
if __name__ == '__main__':
unittest.main()
unittest.main()