Merge pull request #37 from craigcitro/imports2

More import/lint cleanup.
This commit is contained in:
Craig Citro
2014-08-18 10:52:38 -07:00
14 changed files with 114 additions and 154 deletions

View File

@@ -19,7 +19,6 @@ Utilities for making it easier to use OAuth 2.0 on Google App Engine.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import base64
import cgi import cgi
import httplib2 import httplib2
import json import json
@@ -27,7 +26,6 @@ import logging
import os import os
import pickle import pickle
import threading import threading
import time
from google.appengine.api import app_identity from google.appengine.api import app_identity
from google.appengine.api import memcache from google.appengine.api import memcache

View File

@@ -20,10 +20,9 @@ Tools for interacting with OAuth 2.0 protected resources.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import base64 import base64
import clientsecrets import collections
import copy import copy
import datetime import datetime
import httplib2
import json import json
import logging import logging
import os import os
@@ -32,10 +31,11 @@ import time
import urllib import urllib
import urlparse import urlparse
from collections import namedtuple import httplib2
from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_AUTH_URI
from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import clientsecrets
from oauth2client import util from oauth2client import util
HAS_OPENSSL = False HAS_OPENSSL = False
@@ -48,11 +48,6 @@ try:
except ImportError: except ImportError:
pass pass
try:
from urlparse import parse_qsl
except ImportError:
from cgi import parse_qsl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Expiry is stored in RFC3339 UTC format # Expiry is stored in RFC3339 UTC format
@@ -81,7 +76,8 @@ SERVICE_ACCOUNT = 'service_account'
GOOGLE_APPLICATION_CREDENTIALS = 'GOOGLE_APPLICATION_CREDENTIALS' GOOGLE_APPLICATION_CREDENTIALS = 'GOOGLE_APPLICATION_CREDENTIALS'
# The access token along with the seconds in which it expires. # The access token along with the seconds in which it expires.
AccessTokenInfo = namedtuple('AccessTokenInfo', ['access_token', 'expires_in']) AccessTokenInfo = collections.namedtuple(
'AccessTokenInfo', ['access_token', 'expires_in'])
class Error(Exception): class Error(Exception):
"""Base error for this module.""" """Base error for this module."""
@@ -394,11 +390,11 @@ def _update_query_params(uri, params):
Returns: Returns:
The same URI but with the new query parameters added. The same URI but with the new query parameters added.
""" """
parts = list(urlparse.urlparse(uri)) parts = urlparse.urlparse(uri)
query_params = dict(parse_qsl(parts[4])) # 4 is the index of the query part query_params = dict(urlparse.parse_qsl(parts.query))
query_params.update(params) query_params.update(params)
parts[4] = urllib.urlencode(query_params) new_parts = parts._replace(query=urllib.urlencode(query_params))
return urlparse.urlunparse(parts) return urlparse.urlunparse(new_parts)
class OAuth2Credentials(Credentials): class OAuth2Credentials(Credentials):
@@ -1457,7 +1453,7 @@ def _parse_exchange_token_response(content):
except StandardError: except StandardError:
# different JSON libs raise different exceptions, # different JSON libs raise different exceptions,
# so we just do a catch-all here # so we just do a catch-all here
resp = dict(parse_qsl(content)) resp = dict(urlparse.parse_qsl(content))
# some providers respond with 'expires', others with 'expires_in' # some providers respond with 'expires', others with 'expires_in'
if resp and 'expires' in resp: if resp and 'expires' in resp:

View File

@@ -336,9 +336,8 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
""" """
segments = jwt.split('.') segments = jwt.split('.')
if (len(segments) != 3): if len(segments) != 3:
raise AppIdentityError( raise AppIdentityError('Wrong number of segments in token: %s' % jwt)
'Wrong number of segments in token: %s' % jwt)
signed = '%s.%s' % (segments[0], segments[1]) signed = '%s.%s' % (segments[0], segments[1])
signature = _urlsafe_b64decode(segments[2]) signature = _urlsafe_b64decode(segments[2])
@@ -352,9 +351,9 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
# Check signature. # Check signature.
verified = False verified = False
for (keyname, pem) in certs.items(): for _, pem in certs.items():
verifier = Verifier.from_string(pem, True) verifier = Verifier.from_string(pem, True)
if (verifier.verify(signed, signature)): if verifier.verify(signed, signature):
verified = True verified = True
break break
if not verified: if not verified:
@@ -372,16 +371,15 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
if exp is None: if exp is None:
raise AppIdentityError('No exp field in token: %s' % json_body) raise AppIdentityError('No exp field in token: %s' % json_body)
if exp >= now + MAX_TOKEN_LIFETIME_SECS: if exp >= now + MAX_TOKEN_LIFETIME_SECS:
raise AppIdentityError( raise AppIdentityError('exp field too far in future: %s' % json_body)
'exp field too far in future: %s' % json_body)
latest = exp + CLOCK_SKEW_SECS latest = exp + CLOCK_SKEW_SECS
if now < earliest: if now < earliest:
raise AppIdentityError('Token used too early, %d < %d: %s' % raise AppIdentityError('Token used too early, %d < %d: %s' %
(now, earliest, json_body)) (now, earliest, json_body))
if now > latest: if now > latest:
raise AppIdentityError('Token used too late, %d > %d: %s' % raise AppIdentityError('Token used too late, %d > %d: %s' %
(now, latest, json_body)) (now, latest, json_body))
# Check audience. # Check audience.
if audience is not None: if audience is not None:
@@ -390,6 +388,6 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
raise AppIdentityError('No aud field in token: %s' % json_body) raise AppIdentityError('No aud field in token: %s' % json_body)
if aud != audience: if aud != audience:
raise AppIdentityError('Wrong recipient, %s != %s: %s' % raise AppIdentityError('Wrong recipient, %s != %s: %s' %
(aud, audience, json_body)) (aud, audience, json_body))
return parsed return parsed

View File

@@ -21,11 +21,10 @@ credentials.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import os import os
import stat
import threading import threading
from client import Storage as BaseStorage from oauth2client.client import Storage as BaseStorage
from client import Credentials from oauth2client.client import Credentials
class CredentialsFileSymbolicLinkError(Exception): class CredentialsFileSymbolicLinkError(Exception):

View File

@@ -19,7 +19,6 @@ Utilities for making it easier to use OAuth 2.0 on Google Compute Engine.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import httplib2
import json import json
import logging import logging
import uritemplate import uritemplate
@@ -90,7 +89,7 @@ class AppAssertionCredentials(AssertionCredentials):
else: else:
if response.status == 404: if response.status == 404:
content = content + (' This can occur if a VM was created' content = content + (' This can occur if a VM was created'
' with no service account or scopes.') ' with no service account or scopes.')
raise AccessTokenRefreshError(content) raise AccessTokenRefreshError(content)
@property @property

View File

@@ -22,8 +22,8 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import keyring import keyring
import threading import threading
from client import Storage as BaseStorage from oauth2client.client import Storage as BaseStorage
from client import Credentials from oauth2client.client import Credentials
class Storage(BaseStorage): class Storage(BaseStorage):

View File

@@ -70,6 +70,7 @@ class _Opener(object):
self._mode = mode self._mode = mode
self._fallback_mode = fallback_mode self._fallback_mode = fallback_mode
self._fh = None self._fh = None
self._lock_fd = None
def is_locked(self): def is_locked(self):
"""Was the file locked.""" """Was the file locked."""
@@ -141,8 +142,8 @@ class _PosixOpener(_Opener):
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if (time.time() - start_time) >= timeout: if (time.time() - start_time) >= timeout:
logger.warn('Could not acquire lock %s in %s seconds' % ( logger.warn('Could not acquire lock %s in %s seconds',
lock_filename, timeout)) lock_filename, timeout)
# Close the file and open in fallback_mode. # Close the file and open in fallback_mode.
if self._fh: if self._fh:
self._fh.close() self._fh.close()
@@ -194,7 +195,7 @@ try:
self._fh = open(self._filename, self._mode) self._fh = open(self._filename, self._mode)
except IOError as e: except IOError as e:
# If we can't access with _mode, try _fallback_mode and don't lock. # If we can't access with _mode, try _fallback_mode and don't lock.
if e.errno in ( errno.EPERM, errno.EACCES ): if e.errno in (errno.EPERM, errno.EACCES):
self._fh = open(self._filename, self._fallback_mode) self._fh = open(self._filename, self._fallback_mode)
return return
@@ -212,8 +213,8 @@ try:
raise e raise e
# We could not acquire the lock. Try again. # We could not acquire the lock. Try again.
if (time.time() - start_time) >= timeout: if (time.time() - start_time) >= timeout:
logger.warn('Could not lock %s in %s seconds' % ( logger.warn('Could not lock %s in %s seconds',
self._filename, timeout)) self._filename, timeout)
if self._fh: if self._fh:
self._fh.close() self._fh.close()
self._fh = open(self._filename, self._fallback_mode) self._fh = open(self._filename, self._fallback_mode)

View File

@@ -43,8 +43,6 @@ The format of the stored data is like so:
__author__ = 'jbeda@google.com (Joe Beda)' __author__ = 'jbeda@google.com (Joe Beda)'
import base64
import errno
import json import json
import logging import logging
import os import os
@@ -53,7 +51,7 @@ import threading
from oauth2client.client import Storage as BaseStorage from oauth2client.client import Storage as BaseStorage
from oauth2client.client import Credentials from oauth2client.client import Credentials
from oauth2client import util from oauth2client import util
from locked_file import LockedFile from oauth2client.locked_file import LockedFile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -286,7 +284,7 @@ class _MultiStore(object):
if self._warn_on_readonly: if self._warn_on_readonly:
logger.warn('The credentials file (%s) is not writable. Opening in ' logger.warn('The credentials file (%s) is not writable. Opening in '
'read-only mode. Any refreshed credentials will only be ' 'read-only mode. Any refreshed credentials will only be '
'valid for this run.' % self._file.filename()) 'valid for this run.', self._file.filename())
if os.path.getsize(self._file.filename()) == 0: if os.path.getsize(self._file.filename()) == 0:
logger.debug('Initializing empty multistore file') logger.debug('Initializing empty multistore file')
# The multistore is empty so write out an empty file. # The multistore is empty so write out an empty file.

View File

@@ -19,9 +19,7 @@ This credentials class is implemented on top of rsa library.
import base64 import base64
import json import json
import rsa
import time import time
import types
from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_URI
@@ -30,6 +28,7 @@ from oauth2client.client import AssertionCredentials
from pyasn1.codec.ber import decoder from pyasn1.codec.ber import decoder
from pyasn1_modules.rfc5208 import PrivateKeyInfo from pyasn1_modules.rfc5208 import PrivateKeyInfo
import rsa
class _ServiceAccountCredentials(AssertionCredentials): class _ServiceAccountCredentials(AssertionCredentials):
@@ -38,8 +37,9 @@ class _ServiceAccountCredentials(AssertionCredentials):
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
def __init__(self, service_account_id, service_account_email, private_key_id, def __init__(self, service_account_id, service_account_email, private_key_id,
private_key_pkcs8_text, scopes, user_agent=None, private_key_pkcs8_text, scopes, user_agent=None,
token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI, **kwargs): token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI,
**kwargs):
super(_ServiceAccountCredentials, self).__init__( super(_ServiceAccountCredentials, self).__init__(
None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri) None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri)

View File

@@ -25,22 +25,15 @@ __all__ = ['argparser', 'run_flow', 'run', 'message_if_missing']
import BaseHTTPServer import BaseHTTPServer
import argparse import argparse
import httplib2
import logging import logging
import os
import socket import socket
import sys import sys
import urlparse
import webbrowser import webbrowser
from oauth2client import client from oauth2client import client
from oauth2client import file
from oauth2client import util from oauth2client import util
try:
from urlparse import parse_qsl
except ImportError:
from cgi import parse_qsl
_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0 _CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0
To make this sample run you will need to populate the client_secrets.json file To make this sample run you will need to populate the client_secrets.json file
@@ -57,15 +50,15 @@ with information from the APIs Console <https://code.google.com/apis/console>.
# ArgumentParser. # ArgumentParser.
argparser = argparse.ArgumentParser(add_help=False) argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument('--auth_host_name', default='localhost', argparser.add_argument('--auth_host_name', default='localhost',
help='Hostname when running a local web server.') help='Hostname when running a local web server.')
argparser.add_argument('--noauth_local_webserver', action='store_true', argparser.add_argument('--noauth_local_webserver', action='store_true',
default=False, help='Do not run a local web server.') default=False, help='Do not run a local web server.')
argparser.add_argument('--auth_host_port', default=[8080, 8090], type=int, argparser.add_argument('--auth_host_port', default=[8080, 8090], type=int,
nargs='*', help='Port web server should listen on.') nargs='*', help='Port web server should listen on.')
argparser.add_argument('--logging_level', default='ERROR', argparser.add_argument('--logging_level', default='ERROR',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR',
'CRITICAL'], 'CRITICAL'],
help='Set the logging level of detail.') help='Set the logging level of detail.')
class ClientRedirectServer(BaseHTTPServer.HTTPServer): class ClientRedirectServer(BaseHTTPServer.HTTPServer):
@@ -84,26 +77,25 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
into the servers query_params and then stops serving. into the servers query_params and then stops serving.
""" """
def do_GET(s): def do_GET(self):
"""Handle a GET request. """Handle a GET request.
Parses the query parameters and prints a message Parses the query parameters and prints a message
if the flow has completed. Note that we can't detect if the flow has completed. Note that we can't detect
if an error occurred. if an error occurred.
""" """
s.send_response(200) self.send_response(200)
s.send_header("Content-type", "text/html") self.send_header("Content-type", "text/html")
s.end_headers() self.end_headers()
query = s.path.split('?', 1)[-1] query = self.path.split('?', 1)[-1]
query = dict(parse_qsl(query)) query = dict(urlparse.parse_qsl(query))
s.server.query_params = query self.server.query_params = query
s.wfile.write("<html><head><title>Authentication Status</title></head>") self.wfile.write("<html><head><title>Authentication Status</title></head>")
s.wfile.write("<body><p>The authentication flow has completed.</p>") self.wfile.write("<body><p>The authentication flow has completed.</p>")
s.wfile.write("</body></html>") self.wfile.write("</body></html>")
def log_message(self, format, *args): def log_message(self, format, *args):
"""Do not log messages to stdout while running as command line program.""" """Do not log messages to stdout while running as command line program."""
pass
@util.positional(3) @util.positional(3)
@@ -233,8 +225,8 @@ def message_if_missing(filename):
return _CLIENT_SECRETS_MESSAGE % filename return _CLIENT_SECRETS_MESSAGE % filename
try: try:
from old_run import run from oauth2client.old_run import run
from old_run import FLAGS from oauth2client.old_run import FLAGS
except ImportError: except ImportError:
def run(*args, **kwargs): def run(*args, **kwargs):
raise NotImplementedError( raise NotImplementedError(

View File

@@ -17,14 +17,16 @@
"""Common utility library.""" """Common utility library."""
__author__ = ['rafek@google.com (Rafe Kaplan)', __author__ = [
'guido@google.com (Guido van Rossum)', 'rafek@google.com (Rafe Kaplan)',
'guido@google.com (Guido van Rossum)',
] ]
__all__ = [ __all__ = [
'positional', 'positional',
'POSITIONAL_WARNING', 'POSITIONAL_WARNING',
'POSITIONAL_EXCEPTION', 'POSITIONAL_EXCEPTION',
'POSITIONAL_IGNORE', 'POSITIONAL_IGNORE',
] ]
import inspect import inspect
@@ -33,11 +35,6 @@ import types
import urllib import urllib
import urlparse import urlparse
try:
from urlparse import parse_qsl
except ImportError:
from cgi import parse_qsl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
POSITIONAL_WARNING = 'WARNING' POSITIONAL_WARNING = 'WARNING'
@@ -190,7 +187,7 @@ def _add_query_parameter(url, name, value):
return url return url
else: else:
parsed = list(urlparse.urlparse(url)) parsed = list(urlparse.urlparse(url))
q = dict(parse_qsl(parsed[4])) q = dict(urlparse.parse_qsl(parsed[4]))
q[name] = value q[name] = value
parsed[4] = urllib.urlencode(q) parsed[4] = urllib.urlencode(q)
return urlparse.urlunparse(parsed) return urlparse.urlunparse(parsed)

View File

@@ -17,14 +17,13 @@
"""Helper methods for creating & verifying XSRF tokens.""" """Helper methods for creating & verifying XSRF tokens."""
__authors__ = [ __authors__ = [
'"Doug Coker" <dcoker@google.com>', '"Doug Coker" <dcoker@google.com>',
'"Joe Gregorio" <jcgregorio@google.com>', '"Joe Gregorio" <jcgregorio@google.com>',
] ]
import base64 import base64
import hmac import hmac
import os # for urandom
import time import time
from oauth2client import util from oauth2client import util

View File

@@ -22,23 +22,18 @@ Unit tests for objects created from discovery documents.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import base64
import datetime import datetime
import httplib2 import httplib2
import json import json
import mox
import os import os
import time import time
import unittest import unittest
import urllib import urllib
import urlparse
try:
from urlparse import parse_qs
except ImportError:
from cgi import parse_qs
import dev_appserver import dev_appserver
dev_appserver.fix_sys_path() dev_appserver.fix_sys_path()
import mox
import webapp2 import webapp2
from google.appengine.api import apiproxy_stub from google.appengine.api import apiproxy_stub
@@ -559,7 +554,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual(self.decorator.credentials, None) self.assertEqual(self.decorator.credentials, None)
response = self.app.get('http://localhost/foo_path') response = self.app.get('http://localhost/foo_path')
self.assertTrue(response.status.startswith('302')) self.assertTrue(response.status.startswith('302'))
q = parse_qs(response.headers['Location'].split('?', 1)[1]) q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0]) self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0]) self.assertEqual('foo_scope bar_scope', q['scope'][0])
@@ -583,7 +578,8 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('http://localhost/foo_path', parts[0]) self.assertEqual('http://localhost/foo_path', parts[0])
self.assertEqual(None, self.decorator.credentials) self.assertEqual(None, self.decorator.credentials)
if self.decorator._token_response_param: if self.decorator._token_response_param:
response = parse_qs(parts[1])[self.decorator._token_response_param][0] response = urlparse.parse_qs(
parts[1])[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content, json.loads(urllib.unquote(response))) self.assertEqual(Http2Mock.content, json.loads(urllib.unquote(response)))
self.assertEqual(self.decorator.flow, self.decorator._tls.flow) self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
self.assertEqual(self.decorator.credentials, self.assertEqual(self.decorator.credentials,
@@ -619,7 +615,7 @@ class DecoratorTests(unittest.TestCase):
# Invalid Credentials should start the OAuth dance again. # Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path') response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302')) self.assertTrue(response.status.startswith('302'))
q = parse_qs(response.headers['Location'].split('?', 1)[1]) q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
def test_storage_delete(self): def test_storage_delete(self):
@@ -667,7 +663,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('200 OK', response.status) self.assertEqual('200 OK', response.status)
self.assertEqual(False, self.decorator.has_credentials()) self.assertEqual(False, self.decorator.has_credentials())
url = self.decorator.authorize_url() url = self.decorator.authorize_url()
q = parse_qs(url.split('?', 1)[1]) q = urlparse.parse_qs(url.split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0]) self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0]) self.assertEqual('foo_scope bar_scope', q['scope'][0])

View File

@@ -22,20 +22,12 @@ Unit tests for oauth2client.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import httplib2
import os import os
import sys
import tempfile import tempfile
import time import time
import unittest import unittest
import urlparse
try: from tests.http_mock import HttpMockSequence
from urlparse import parse_qs
except ImportError:
from cgi import parse_qs
from http_mock import HttpMockSequence
from oauth2client import crypt from oauth2client import crypt
from oauth2client.client import Credentials from oauth2client.client import Credentials
from oauth2client.client import SignedJwtAssertionCredentials from oauth2client.client import SignedJwtAssertionCredentials
@@ -79,20 +71,18 @@ class CryptTests(unittest.TestCase):
self.assertTrue(verifier.verify('foo', signature)) self.assertTrue(verifier.verify('foo', signature))
self.assertFalse(verifier.verify('bar', signature)) self.assertFalse(verifier.verify('bar', signature))
self.assertFalse(verifier.verify('foo', 'bad signagure')) self.assertFalse(verifier.verify('foo', 'bad signature'))
def _check_jwt_failure(self, jwt, expected_error): def _check_jwt_failure(self, jwt, expected_error):
try: public_key = datafile('publickey.pem')
public_key = datafile('publickey.pem') certs = {'foo': public_key}
certs = {'foo': public_key} audience = ('https://www.googleapis.com/auth/id?client_id='
audience = 'https://www.googleapis.com/auth/id?client_id=' + \ 'external_public_key@testing.gserviceaccount.com')
'external_public_key@testing.gserviceaccount.com' try:
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience) crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.fail('Should have thrown for %s' % jwt) self.fail()
except: except crypt.AppIdentityError as e:
e = sys.exc_info()[1] self.assertTrue(expected_error in str(e))
msg = e.args[0]
self.assertTrue(expected_error in msg)
def _create_signed_jwt(self): def _create_signed_jwt(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
@@ -100,20 +90,18 @@ class CryptTests(unittest.TestCase):
audience = 'some_audience_address@testing.gserviceaccount.com' audience = 'some_audience_address@testing.gserviceaccount.com'
now = long(time.time()) now = long(time.time())
return crypt.make_signed_jwt( return crypt.make_signed_jwt(signer, {
signer,
{
'aud': audience, 'aud': audience,
'iat': now, 'iat': now,
'exp': now + 300, 'exp': now + 300,
'user': 'billy bob', 'user': 'billy bob',
'metadata': {'meta': 'data'}, 'metadata': {'meta': 'data'},
}) })
def test_verify_id_token(self): def test_verify_id_token(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
public_key = datafile('publickey.pem') public_key = datafile('publickey.pem')
certs = {'foo': public_key } certs = {'foo': public_key}
audience = 'some_audience_address@testing.gserviceaccount.com' audience = 'some_audience_address@testing.gserviceaccount.com'
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience) contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
@@ -123,11 +111,11 @@ class CryptTests(unittest.TestCase):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, datafile('certs.json')), ({'status': '200'}, datafile('certs.json')),
]) ])
contents = verify_id_token(jwt, contents = verify_id_token(
'some_audience_address@testing.gserviceaccount.com', http=http) jwt, 'some_audience_address@testing.gserviceaccount.com', http=http)
self.assertEqual('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEqual('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
@@ -135,11 +123,12 @@ class CryptTests(unittest.TestCase):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '404'}, datafile('certs.json')), ({'status': '404'}, datafile('certs.json')),
]) ])
self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt, self.assertRaises(VerifyJwtTokenError, verify_id_token, jwt,
'some_audience_address@testing.gserviceaccount.com', http=http) 'some_audience_address@testing.gserviceaccount.com',
http=http)
def test_verify_id_token_bad_tokens(self): def test_verify_id_token_bad_tokens(self):
private_key = datafile('privatekey.%s' % self.format) private_key = datafile('privatekey.%s' % self.format)
@@ -148,8 +137,7 @@ class CryptTests(unittest.TestCase):
self._check_jwt_failure('foo', 'Wrong number of segments') self._check_jwt_failure('foo', 'Wrong number of segments')
# Not json # Not json
self._check_jwt_failure('foo.bar.baz', self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token')
'Can\'t parse token')
# Bad signature # Bad signature
jwt = 'foo.%s.baz' % crypt._urlsafe_b64encode('{"a":"b"}') jwt = 'foo.%s.baz' % crypt._urlsafe_b64encode('{"a":"b"}')
@@ -157,21 +145,19 @@ class CryptTests(unittest.TestCase):
# No expiration # No expiration
signer = self.signer.from_string(private_key) signer = self.signer.from_string(private_key)
audience = 'https:#www.googleapis.com/auth/id?client_id=' + \ audience = ('https:#www.googleapis.com/auth/id?client_id='
'external_public_key@testing.gserviceaccount.com' 'external_public_key@testing.gserviceaccount.com')
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience', 'aud': audience,
'iat': time.time(), 'iat': time.time(),
} })
)
self._check_jwt_failure(jwt, 'No exp field in token') self._check_jwt_failure(jwt, 'No exp field in token')
# No issued at # No issued at
jwt = crypt.make_signed_jwt(signer, { jwt = crypt.make_signed_jwt(signer, {
'aud': 'audience', 'aud': 'audience',
'exp': time.time() + 400, 'exp': time.time() + 400,
} })
)
self._check_jwt_failure(jwt, 'No iat field in token') self._check_jwt_failure(jwt, 'No iat field in token')
# Too early # Too early
@@ -226,11 +212,11 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
scope='read+write', scope='read+write',
sub='joe@example.org') sub='joe@example.org')
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
resp, content = http.request('http://example.org') _, content = http.request('http://example.org')
self.assertEqual('Bearer 1/3w', content['Authorization']) self.assertEqual('Bearer 1/3w', content['Authorization'])
def test_credentials_to_from_json(self): def test_credentials_to_from_json(self):
@@ -249,13 +235,13 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase):
def _credentials_refresh(self, credentials): def _credentials_refresh(self, credentials):
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
({'status': '401'}, ''), ({'status': '401'}, ''),
({'status': '200'}, '{"access_token":"3/3w","expires_in":3600}'), ({'status': '200'}, '{"access_token":"3/3w","expires_in":3600}'),
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
resp, content = http.request('http://example.org') _, content = http.request('http://example.org')
return content return content
def test_credentials_refresh_without_storage(self): def test_credentials_refresh_without_storage(self):
@@ -319,6 +305,7 @@ class PKCSSignedJwtAssertionCredentialsPyCryptoTests(unittest.TestCase):
except NotImplementedError: except NotImplementedError:
pass pass
class TestHasOpenSSLFlag(unittest.TestCase): class TestHasOpenSSLFlag(unittest.TestCase):
def test_true(self): def test_true(self):
self.assertEqual(True, HAS_OPENSSL) self.assertEqual(True, HAS_OPENSSL)