diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py index ca2ff15..12ba04d 100644 --- a/oauth2client/appengine.py +++ b/oauth2client/appengine.py @@ -446,7 +446,7 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): # in API calls """ - def __init__(self, filename, scope, message=None): + def __init__(self, filename, scope, message=None, cache=None): """Constructor Args: @@ -457,9 +457,11 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): clientsecrets file is missing or invalid. The message may contain HTML and will be presented on the web interface for any method that uses the decorator. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. """ try: - client_type, client_info = clientsecrets.loadfile(filename) + client_type, client_info = clientsecrets.loadfile(filename, cache=cache) if client_type not in [clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED]: raise InvalidClientSecretsError('OAuth2Decorator doesn\'t support this OAuth 2.0 flow.') super(OAuth2DecoratorFromClientSecrets, @@ -478,7 +480,8 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): self._message = "Please configure your application for OAuth 2.0" -def oauth2decorator_from_clientsecrets(filename, scope, message=None): +def oauth2decorator_from_clientsecrets(filename, scope, + message=None, cache=None): """Creates an OAuth2Decorator populated from a clientsecrets file. Args: @@ -489,11 +492,14 @@ def oauth2decorator_from_clientsecrets(filename, scope, message=None): clientsecrets file is missing or invalid. The message may contain HTML and will be presented on the web interface for any method that uses the decorator. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. Returns: An OAuth2Decorator """ - return OAuth2DecoratorFromClientSecrets(filename, scope, message) + return OAuth2DecoratorFromClientSecrets(filename, scope, + message=message, cache=cache) class OAuth2Handler(webapp.RequestHandler): diff --git a/oauth2client/client.py b/oauth2client/client.py index 5ccaccd..e701f02 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -955,7 +955,8 @@ def credentials_from_code(client_id, client_secret, scope, code, def credentials_from_clientsecrets_and_code(filename, scope, code, message = None, redirect_uri = 'postmessage', - http=None): + http=None, + cache=None): """Returns OAuth2Credentials from a clientsecrets file and an auth code. Will create the right kind of Flow based on the contents of the clientsecrets @@ -973,6 +974,8 @@ def credentials_from_clientsecrets_and_code(filename, scope, code, redirect_uri: string, this is generally set to 'postmessage' to match the redirect_uri that the client specified http: httplib2.Http, optional http instance to use to do the fetch + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. Returns: An OAuth2Credentials object. @@ -984,7 +987,7 @@ def credentials_from_clientsecrets_and_code(filename, scope, code, clientsecrets.InvalidClientSecretsError if the clientsecrets file is invalid. """ - flow = flow_from_clientsecrets(filename, scope, message) + flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache) # We primarily make this call to set up the redirect_uri in the flow object uriThatWeDontReallyUse = flow.step1_get_authorize_url(redirect_uri) credentials = flow.step2_exchange(code, http) @@ -1130,7 +1133,7 @@ class OAuth2WebServerFlow(Flow): error_msg = 'Invalid response: %s.' % str(resp.status) raise FlowExchangeError(error_msg) -def flow_from_clientsecrets(filename, scope, message=None): +def flow_from_clientsecrets(filename, scope, message=None, cache=None): """Create a Flow from a clientsecrets file. Will create the right kind of Flow based on the contents of the clientsecrets @@ -1143,6 +1146,8 @@ def flow_from_clientsecrets(filename, scope, message=None): clientsecrets file is missing or invalid. If message is provided then sys.exit will be called in the case of an error. If message in not provided then clientsecrets.InvalidClientSecretsError will be raised. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. Returns: A Flow object. @@ -1153,7 +1158,7 @@ def flow_from_clientsecrets(filename, scope, message=None): invalid. """ try: - client_type, client_info = clientsecrets.loadfile(filename) + client_type, client_info = clientsecrets.loadfile(filename, cache=cache) if client_type in [clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED]: return OAuth2WebServerFlow( client_info['client_id'], diff --git a/oauth2client/clientsecrets.py b/oauth2client/clientsecrets.py index 1327a2a..428c5ec 100644 --- a/oauth2client/clientsecrets.py +++ b/oauth2client/clientsecrets.py @@ -93,7 +93,7 @@ def loads(s): return _validate_clientsecrets(obj) -def loadfile(filename): +def _loadfile(filename): try: fp = file(filename, 'r') try: @@ -103,3 +103,48 @@ def loadfile(filename): except IOError: raise InvalidClientSecretsError('File not found: "%s"' % filename) return _validate_clientsecrets(obj) + + +def loadfile(filename, cache=None): + """Loading of client_secrets JSON file, optionally backed by a cache. + + Typical cache storage would be App Engine memcache service, + but you can pass in any other cache client that implements + these methods: + - get(key, namespace=ns) + - set(key, value, namespace=ns) + + Usage: + # without caching + client_type, client_info = loadfile('secrets.json') + # using App Engine memcache service + from google.appengine.api import memcache + client_type, client_info = loadfile('secrets.json', cache=memcache) + + Args: + filename: string, Path to a client_secrets.json file on a filesystem. + cache: An optional cache service client that implements get() and set() + methods. If not specified, the file is always being loaded from + a filesystem. + + Raises: + InvalidClientSecretsError: In case of a validation error or some + I/O failure. Can happen only on cache miss. + + Returns: + (client_type, client_info) tuple, as _loadfile() normally would. + JSON contents is validated only during first load. Cache hits are not + validated. + """ + _SECRET_NAMESPACE = 'oauth2client:secrets#ns' + + if not cache: + return _loadfile(filename) + + obj = cache.get(filename, namespace=_SECRET_NAMESPACE) + if obj is None: + client_type, client_info = _loadfile(filename) + obj = { client_type: client_info } + cache.set(filename, obj, namespace=_SECRET_NAMESPACE) + + return obj.iteritems().next() diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index 49433df..faccccf 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -36,6 +36,7 @@ except ImportError: from apiclient.http import HttpMockSequence from oauth2client.anyjson import simplejson +from oauth2client.clientsecrets import _loadfile from oauth2client.client import AccessTokenCredentials from oauth2client.client import AccessTokenCredentialsError from oauth2client.client import AccessTokenRefreshError @@ -50,12 +51,29 @@ from oauth2client.client import VerifyJwtTokenError from oauth2client.client import _extract_id_token from oauth2client.client import credentials_from_code from oauth2client.client import credentials_from_clientsecrets_and_code +from oauth2client.client import flow_from_clientsecrets DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') def datafile(filename): return os.path.join(DATA_DIR, filename) +def load_and_cache(existing_file, fakename, cache_mock): + client_type, client_info = _loadfile(datafile(existing_file)) + cache_mock.cache[fakename] = {client_type: client_info} + +class CacheMock(object): + def __init__(self): + self.cache = {} + + def get(self, key, namespace=''): + # ignoring namespace for easier testing + return self.cache.get(key, None) + + def set(self, key, value, namespace=''): + # ignoring namespace for easier testing + self.cache[key] = value + class CredentialsTests(unittest.TestCase): @@ -375,6 +393,16 @@ class OAuth2WebServerFlowTest(unittest.TestCase): credentials = self.flow.step2_exchange('some random code', http) self.assertEqual(credentials.id_token, body) +class FlowFromCachedClientsecrets(unittest.TestCase): + + def test_flow_from_clientsecrets_cached(self): + cache_mock = CacheMock() + load_and_cache('client_secrets.json', 'some_secrets', cache_mock) + + # flow_from_clientsecrets(filename, scope, message=None, cache=None) + flow = flow_from_clientsecrets('some_secrets', '', cache=cache_mock) + self.assertEquals('foo_client_secret', flow.client_secret) + class CredentialsFromCodeTests(unittest.TestCase): def setUp(self): self.client_id = 'client_id_abc' @@ -421,6 +449,18 @@ class CredentialsFromCodeTests(unittest.TestCase): self.assertEquals(credentials.access_token, 'asdfghjkl') self.assertNotEqual(None, credentials.token_expiry) + def test_exchange_code_and_cached_file_for_token(self): + http = HttpMockSequence([ + ({'status': '200'}, '{ "access_token":"asdfghjkl"}'), + ]) + cache_mock = CacheMock() + load_and_cache('client_secrets.json', 'some_secrets', cache_mock) + + credentials = credentials_from_clientsecrets_and_code( + 'some_secrets', self.scope, + self.code, http=http, cache=cache_mock) + self.assertEquals(credentials.access_token, 'asdfghjkl') + def test_exchange_code_and_file_for_token_fail(self): http = HttpMockSequence([ ({'status': '400'}, '{"error":"invalid_request"}'), diff --git a/tests/test_oauth2client_appengine.py b/tests/test_oauth2client_appengine.py index edc5996..58f3e55 100644 --- a/tests/test_oauth2client_appengine.py +++ b/tests/test_oauth2client_appengine.py @@ -50,6 +50,7 @@ from google.appengine.ext import db from google.appengine.ext import testbed from google.appengine.runtime import apiproxy_errors from oauth2client.anyjson import simplejson +from oauth2client.clientsecrets import _loadfile from oauth2client.appengine import AppAssertionCredentials from oauth2client.appengine import CredentialsModel from oauth2client.appengine import FlowProperty @@ -72,6 +73,24 @@ def datafile(filename): return os.path.join(DATA_DIR, filename) +def load_and_cache(existing_file, fakename, cache_mock): + client_type, client_info = _loadfile(datafile(existing_file)) + cache_mock.cache[fakename] = {client_type: client_info} + + +class CacheMock(object): + def __init__(self): + self.cache = {} + + def get(self, key, namespace=''): + # ignoring namespace for easier testing + return self.cache.get(key, None) + + def set(self, key, value, namespace=''): + # ignoring namespace for easier testing + self.cache[key] = value + + class UserMock(object): """Mock the app engine user service""" @@ -439,6 +458,14 @@ class DecoratorTests(unittest.TestCase): http = self.decorator.http() self.assertEquals('foo_access_token', http.request.credentials.access_token) + def test_decorator_from_cached_client_secrets(self): + cache_mock = CacheMock() + load_and_cache('client_secrets.json', 'secret', cache_mock) + decorator = oauth2decorator_from_clientsecrets( + # filename, scope, message=None, cache=None + 'secret', '', cache=cache_mock) + self.assertFalse(decorator._in_error) + def test_decorator_from_client_secrets_not_logged_in_required(self): decorator = oauth2decorator_from_clientsecrets( datafile('client_secrets.json'), diff --git a/tests/test_oauth2client_clientsecrets.py b/tests/test_oauth2client_clientsecrets.py index 20cdf14..f69fb36 100644 --- a/tests/test_oauth2client_clientsecrets.py +++ b/tests/test_oauth2client_clientsecrets.py @@ -26,6 +26,11 @@ import httplib2 from oauth2client import clientsecrets +DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') +VALID_FILE = os.path.join(DATA_DIR, 'client_secrets.json') +INVALID_FILE = os.path.join(DATA_DIR, 'unfilled_client_secrets.json') +NONEXISTENT_FILE = os.path.join(__file__, '..', 'afilethatisntthere.json') + class OAuth2CredentialsTests(unittest.TestCase): def setUp(self): @@ -69,12 +74,72 @@ class OAuth2CredentialsTests(unittest.TestCase): def test_load_by_filename(self): try: - clientsecrets.loadfile(os.path.join(__file__, '..', - 'afilethatisntthere.json')) + clientsecrets._loadfile(NONEXISTENT_FILE) self.fail('should fail to load a missing client_secrets file.') except clientsecrets.InvalidClientSecretsError, e: self.assertTrue(str(e).startswith('File')) +class CachedClientsecretsTests(unittest.TestCase): + + class CacheMock(object): + def __init__(self): + self.cache = {} + self.last_get_ns = None + self.last_set_ns = None + + def get(self, key, namespace=''): + # ignoring namespace for easier testing + self.last_get_ns = namespace + return self.cache.get(key, None) + + def set(self, key, value, namespace=''): + # ignoring namespace for easier testing + self.last_set_ns = namespace + self.cache[key] = value + + def setUp(self): + self.cache_mock = self.CacheMock() + + def test_cache_miss(self): + client_type, client_info = clientsecrets.loadfile( + VALID_FILE, cache=self.cache_mock) + self.assertEquals('web', client_type) + self.assertEquals('foo_client_secret', client_info['client_secret']) + + cached = self.cache_mock.cache[VALID_FILE] + self.assertEquals({client_type: client_info}, cached) + + # make sure we're using non-empty namespace + ns = self.cache_mock.last_set_ns + self.assertTrue(bool(ns)) + # make sure they're equal + self.assertEquals(ns, self.cache_mock.last_get_ns) + + def test_cache_hit(self): + self.cache_mock.cache[NONEXISTENT_FILE] = { 'web': 'secret info' } + + client_type, client_info = clientsecrets.loadfile( + NONEXISTENT_FILE, cache=self.cache_mock) + self.assertEquals('web', client_type) + self.assertEquals('secret info', client_info) + # make sure we didn't do any set() RPCs + self.assertEqual(None, self.cache_mock.last_set_ns) + + def test_validation(self): + try: + clientsecrets.loadfile(INVALID_FILE, cache=self.cache_mock) + self.fail('Expected InvalidClientSecretsError to be raised ' + 'while loading %s' % INVALID_FILE) + except clientsecrets.InvalidClientSecretsError: + pass + + def test_without_cache(self): + # this also ensures loadfile() is backward compatible + client_type, client_info = clientsecrets.loadfile(VALID_FILE) + self.assertEquals('web', client_type) + self.assertEquals('foo_client_secret', client_info['client_secret']) + + if __name__ == '__main__': unittest.main()