diff --git a/docs/source/oauth2client.contrib.multiprocess_file_storage.rst b/docs/source/oauth2client.contrib.multiprocess_file_storage.rst new file mode 100644 index 0000000..6f683a0 --- /dev/null +++ b/docs/source/oauth2client.contrib.multiprocess_file_storage.rst @@ -0,0 +1,7 @@ +oauth2client.contrib.multiprocess_file_storage module +===================================================== + +.. automodule:: oauth2client.contrib.multiprocess_file_storage + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/oauth2client.contrib.rst b/docs/source/oauth2client.contrib.rst index d926c76..a06b48f 100644 --- a/docs/source/oauth2client.contrib.rst +++ b/docs/source/oauth2client.contrib.rst @@ -21,6 +21,7 @@ Submodules oauth2client.contrib.gce oauth2client.contrib.keyring_storage oauth2client.contrib.locked_file + oauth2client.contrib.multiprocess_file_storage oauth2client.contrib.multistore_file oauth2client.contrib.sqlalchemy oauth2client.contrib.xsrfutil diff --git a/oauth2client/contrib/multiprocess_file_storage.py b/oauth2client/contrib/multiprocess_file_storage.py new file mode 100644 index 0000000..14e5fc3 --- /dev/null +++ b/oauth2client/contrib/multiprocess_file_storage.py @@ -0,0 +1,356 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiprocess file credential storage. + +This module provides file-based storage that supports multiple credentials and +cross-thread and process access. + +This module supersedes the functionality previously found in `multistore_file`. + +This module provides :class:`MultiprocessFileStorage` which: + * Is tied to a single credential via a user-specified key. This key can be + used to distinguish between multiple users, client ids, and/or scopes. + * Can be safely accessed and refreshed across threads and processes. + +Process & thread safety guarantees the following behavior: + * If one thread or process refreshes a credential, subsequent refreshes + from other processes will re-fetch the credentials from the file instead + of performing an http request. + * If two processes or threads attempt to refresh concurrently, only one + will be able to acquire the lock and refresh, with the deadlock caveat + below. + * The interprocess lock will not deadlock, instead, the if a process can + not acquire the interprocess lock within ``INTERPROCESS_LOCK_DEADLINE`` + it will allow refreshing the credential but will not write the updated + credential to disk, This logic happens during every lock cycle - if the + credentials are refreshed again it will retry locking and writing as + normal. + +Usage +===== + +Before using the storage, you need to decide how you want to key the +credentials. A few common strategies include: + + * If you're storing credentials for multiple users in a single file, use + a unique identifier for each user as the key. + * If you're storing credentials for multiple client IDs in a single file, + use the client ID as the key. + * If you're storing multiple credentials for one user, use the scopes as + the key. + * If you have a complicated setup, use a compound key. For example, you + can use a combination of the client ID and scopes as the key. + +Create an instance of :class:`MultiprocessFileStorage` for each credential you +want to store, for example:: + + filename = 'credentials' + key = '{}-{}'.format(client_id, user_id) + storage = MultiprocessFileStorage(filename, key) + +To store the credentials:: + + storage.put(credentials) + +If you're going to continue to use the credentials after storing them, be sure +to call :func:`set_store`:: + + credentials.set_store(storage) + +To retrieve the credentials:: + + storage.get(credentials) + +""" + +import base64 +import json +import logging +import os +import threading + +import fasteners +from six import iteritems + +from oauth2client import _helpers +from oauth2client.client import Credentials +from oauth2client.client import Storage as BaseStorage + + +#: The maximum amount of time, in seconds, to wait when acquire the +#: interprocess lock before falling back to read-only mode. +INTERPROCESS_LOCK_DEADLINE = 1 + +logger = logging.getLogger(__name__) +_backends = {} +_backends_lock = threading.Lock() + + +def _create_file_if_needed(filename): + """Creates the an empty file if it does not already exist. + + Returns: + True if the file was created, False otherwise. + """ + if os.path.exists(filename): + return False + else: + # Equivalent to "touch". + open(filename, 'a+b').close() + logger.info('Credential file {0} created'.format(filename)) + return True + + +def _load_credentials_file(credentials_file): + """Load credentials from the given file handle. + + The file is expected to be in this format: + + { + "file_version": 2, + "credentials": { + "key": "base64 encoded json representation of credentials." + } + } + + This function will warn and return empty credentials instead of raising + exceptions. + + Args: + credentials_file: An open file handle. + + Returns: + A dictionary mapping user-defined keys to an instance of + :class:`oauth2client.client.Credentials`. + """ + try: + credentials_file.seek(0) + data = json.load(credentials_file) + except Exception: + logger.warning( + 'Credentials file could not be loaded, will ignore and ' + 'overwrite.') + return {} + + if data.get('file_version') != 2: + logger.warning( + 'Credentials file is not version 2, will ignore and ' + 'overwrite.') + return {} + + credentials = {} + + for key, encoded_credential in iteritems(data.get('credentials', {})): + try: + credential_json = base64.b64decode(encoded_credential) + credential = Credentials.new_from_json(credential_json) + credentials[key] = credential + except: + logger.warning( + 'Invalid credential {0} in file, ignoring.'.format(key)) + + return credentials + + +def _write_credentials_file(credentials_file, credentials): + """Writes credentials to a file. + + Refer to :func:`_load_credentials_file` for the format. + + Args: + credentials_file: An open file handle, must be read/write. + credentials: A dictionary mapping user-defined keys to an instance of + :class:`oauth2client.client.Credentials`. + """ + data = {'file_version': 2, 'credentials': {}} + + for key, credential in iteritems(credentials): + credential_json = credential.to_json() + encoded_credential = _helpers._from_bytes(base64.b64encode( + _helpers._to_bytes(credential_json))) + data['credentials'][key] = encoded_credential + + credentials_file.seek(0) + json.dump(data, credentials_file) + credentials_file.truncate() + + +class _MultiprocessStorageBackend(object): + """Thread-local backend for multiprocess storage. + + Each process has only one instance of this backend per file. All threads + share a single instance of this backend. This ensures that all threads + use the same thread lock and process lock when accessing the file. + """ + + def __init__(self, filename): + self._file = None + self._filename = filename + self._process_lock = fasteners.InterProcessLock( + '{0}.lock'.format(filename)) + self._thread_lock = threading.Lock() + self._read_only = False + self._credentials = {} + + def _load_credentials(self): + """(Re-)loads the credentials from the file.""" + if not self._file: + return + + loaded_credentials = _load_credentials_file(self._file) + self._credentials.update(loaded_credentials) + + logger.debug('Read credential file') + + def _write_credentials(self): + if self._read_only: + logger.debug('In read-only mode, not writing credentials.') + return + + _write_credentials_file(self._file, self._credentials) + logger.debug('Wrote credential file {0}.'.format(self._filename)) + + def acquire_lock(self): + self._thread_lock.acquire() + locked = self._process_lock.acquire(timeout=INTERPROCESS_LOCK_DEADLINE) + + if locked: + _create_file_if_needed(self._filename) + self._file = open(self._filename, 'r+') + self._read_only = False + + else: + logger.warn( + 'Failed to obtain interprocess lock for credentials. ' + 'If a credential is being refreshed, other processes may ' + 'not see the updated access token and refresh as well.') + if os.path.exists(self._filename): + self._file = open(self._filename, 'r') + else: + self._file = None + self._read_only = True + + self._load_credentials() + + def release_lock(self): + if self._file is not None: + self._file.close() + self._file = None + + if not self._read_only: + self._process_lock.release() + + self._thread_lock.release() + + def _refresh_predicate(self, credentials): + if credentials is None: + return True + elif credentials.invalid: + return True + elif credentials.access_token_expired: + return True + else: + return False + + def locked_get(self, key): + # Check if the credential is already in memory. + credentials = self._credentials.get(key, None) + + # Use the refresh predicate to determine if the entire store should be + # reloaded. This basically checks if the credentials are invalid + # or expired. This covers the situation where another process has + # refreshed the credentials and this process doesn't know about it yet. + # In that case, this process won't needlessly refresh the credentials. + if self._refresh_predicate(credentials): + self._load_credentials() + credentials = self._credentials.get(key, None) + + return credentials + + def locked_put(self, key, credentials): + self._load_credentials() + self._credentials[key] = credentials + self._write_credentials() + + def locked_delete(self, key): + self._load_credentials() + self._credentials.pop(key, None) + self._write_credentials() + + +def _get_backend(filename): + """A helper method to get or create a backend with thread locking. + + This ensures that only one backend is used per-file per-process, so that + thread and process locks are appropriately shared. + + Args: + filename: The full path to the credential storage file. + + Returns: + An instance of :class:`_MultiprocessStorageBackend`. + """ + filename = os.path.abspath(filename) + + with _backends_lock: + if filename not in _backends: + _backends[filename] = _MultiprocessStorageBackend(filename) + return _backends[filename] + + +class MultiprocessFileStorage(BaseStorage): + """Multiprocess file credential storage. + + Args: + filename: The path to the file where credentials will be stored. + key: An arbitrary string used to uniquely identify this set of + credentials. For example, you may use the user's ID as the key or + a combination of the client ID and user ID. + """ + def __init__(self, filename, key): + self._key = key + self._backend = _get_backend(filename) + + def acquire_lock(self): + self._backend.acquire_lock() + + def release_lock(self): + self._backend.release_lock() + + def locked_get(self): + """Retrieves the current credentials from the store. + + Returns: + An instance of :class:`oauth2client.client.Credentials` or `None`. + """ + credential = self._backend.locked_get(self._key) + + if credential is not None: + credential.set_store(self) + + return credential + + def locked_put(self, credentials): + """Writes the given credentials to the store. + + Args: + credentials: an instance of + :class:`oauth2client.client.Credentials`. + """ + return self._backend.locked_put(self._key, credentials) + + def locked_delete(self): + """Deletes the current credentials from the store.""" + return self._backend.locked_delete(self._key) diff --git a/tests/contrib/test_multiprocess_file_storage.py b/tests/contrib/test_multiprocess_file_storage.py new file mode 100644 index 0000000..a59d7f6 --- /dev/null +++ b/tests/contrib/test_multiprocess_file_storage.py @@ -0,0 +1,313 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for oauth2client.multistore_file.""" + +import contextlib +import datetime +import json +import multiprocessing +import os +import tempfile + +import fasteners +import mock +from six import StringIO +import unittest2 + +from oauth2client.client import OAuth2Credentials +from oauth2client.contrib import multiprocess_file_storage + +from ..http_mock import HttpMockSequence + + +@contextlib.contextmanager +def scoped_child_process(target, **kwargs): + die_event = multiprocessing.Event() + ready_event = multiprocessing.Event() + process = multiprocessing.Process( + target=target, args=(die_event, ready_event), kwargs=kwargs) + process.start() + try: + ready_event.wait() + yield + finally: + die_event.set() + process.join(5) + + +def _create_test_credentials(expiration=None): + access_token = 'foo' + client_secret = 'cOuDdkfjxxnv+' + refresh_token = '1/0/a.df219fjls0' + token_expiry = expiration or ( + datetime.datetime.utcnow() + datetime.timedelta(seconds=3600)) + token_uri = 'https://www.google.com/accounts/o8/oauth2/token' + user_agent = 'refresh_checker/1.0' + + credentials = OAuth2Credentials( + access_token, 'test-client-id', client_secret, + refresh_token, token_expiry, token_uri, + user_agent) + return credentials + + +def _generate_token_response_http(new_token='new_token'): + token_response = json.dumps({ + 'access_token': new_token, + 'expires_in': '3600', + }) + http = HttpMockSequence([ + ({'status': '200'}, token_response), + ]) + + return http + + +class MultiprocessStorageBehaviorTests(unittest2.TestCase): + + def setUp(self): + filehandle, self.filename = tempfile.mkstemp( + 'oauth2client_test.data') + os.close(filehandle) + + def tearDown(self): + try: + os.unlink(self.filename) + os.unlink('{0}.lock'.format(self.filename)) + except OSError: # pragma: NO COVER + pass + + def test_basic_operations(self): + credentials = _create_test_credentials() + + store = multiprocess_file_storage.MultiprocessFileStorage( + self.filename, 'basic') + + # Save credentials + store.put(credentials) + credentials = store.get() + + self.assertIsNotNone(credentials) + self.assertEqual('foo', credentials.access_token) + + # Reset internal cache, ensure credentials were saved. + store._backend._credentials = {} + credentials = store.get() + + self.assertIsNotNone(credentials) + self.assertEqual('foo', credentials.access_token) + + # Delete credentials + store.delete() + credentials = store.get() + + self.assertIsNone(credentials) + + def test_single_process_refresh(self): + store = multiprocess_file_storage.MultiprocessFileStorage( + self.filename, 'single-process') + credentials = _create_test_credentials() + credentials.set_store(store) + + http = _generate_token_response_http() + credentials.refresh(http) + self.assertEqual(credentials.access_token, 'new_token') + + retrieved = store.get() + self.assertEqual(retrieved.access_token, 'new_token') + + def test_multi_process_refresh(self): + # This will test that two processes attempting to refresh credentials + # will only refresh once. + store = multiprocess_file_storage.MultiprocessFileStorage( + self.filename, 'multi-process') + credentials = _create_test_credentials() + credentials.set_store(store) + store.put(credentials) + + def child_process_func( + die_event, ready_event, check_event): # pragma: NO COVER + store = multiprocess_file_storage.MultiprocessFileStorage( + self.filename, 'multi-process') + + credentials = store.get() + self.assertIsNotNone(credentials) + + # Make sure this thread gets to refresh first. + original_acquire_lock = store.acquire_lock + + def replacement_acquire_lock(*args, **kwargs): + result = original_acquire_lock(*args, **kwargs) + ready_event.set() + check_event.wait() + return result + + credentials.store.acquire_lock = replacement_acquire_lock + + http = _generate_token_response_http('b') + credentials.refresh(http) + + self.assertEqual(credentials.access_token, 'b') + + check_event = multiprocessing.Event() + with scoped_child_process(child_process_func, check_event=check_event): + # The lock should be currently held by the child process. + self.assertFalse( + store._backend._process_lock.acquire(blocking=False)) + check_event.set() + + # The child process will refresh first, so we should end up + # with 'b' as the token. + http = mock.Mock() + credentials.refresh(http=http) + self.assertEqual(credentials.access_token, 'b') + self.assertFalse(http.request.called) + + retrieved = store.get() + self.assertEqual(retrieved.access_token, 'b') + + def test_read_only_file_fail_lock(self): + credentials = _create_test_credentials() + + # Grab the lock in another process, preventing this process from + # acquiring the lock. + def child_process(die_event, ready_event): # pragma: NO COVER + lock = fasteners.InterProcessLock( + '{0}.lock'.format(self.filename)) + with lock: + ready_event.set() + die_event.wait() + + with scoped_child_process(child_process): + store = multiprocess_file_storage.MultiprocessFileStorage( + self.filename, 'fail-lock') + store.put(credentials) + self.assertTrue(store._backend._read_only) + + # These credentials should still be in the store's memory-only cache. + self.assertIsNotNone(store.get()) + + +class MultiprocessStorageUnitTests(unittest2.TestCase): + + def setUp(self): + filehandle, self.filename = tempfile.mkstemp( + 'oauth2client_test.data') + os.close(filehandle) + + def tearDown(self): + try: + os.unlink(self.filename) + os.unlink('{0}.lock'.format(self.filename)) + except OSError: # pragma: NO COVER + pass + + def test__create_file_if_needed(self): + self.assertFalse( + multiprocess_file_storage._create_file_if_needed(self.filename)) + os.unlink(self.filename) + self.assertTrue( + multiprocess_file_storage._create_file_if_needed(self.filename)) + self.assertTrue( + os.path.exists(self.filename)) + + def test__get_backend(self): + backend_one = multiprocess_file_storage._get_backend('file_a') + backend_two = multiprocess_file_storage._get_backend('file_a') + backend_three = multiprocess_file_storage._get_backend('file_b') + + self.assertIs(backend_one, backend_two) + self.assertIsNot(backend_one, backend_three) + + def test__read_write_credentials_file(self): + credentials = _create_test_credentials() + contents = StringIO() + + multiprocess_file_storage._write_credentials_file( + contents, {'key': credentials}) + + contents.seek(0) + data = json.load(contents) + self.assertEqual(data['file_version'], 2) + self.assertTrue(data['credentials']['key']) + + # Read it back. + contents.seek(0) + results = multiprocess_file_storage._load_credentials_file(contents) + self.assertEqual( + results['key'].access_token, credentials.access_token) + + # Add an invalid credential and try reading it back. It should ignore + # the invalid one but still load the valid one. + data['credentials']['invalid'] = '123' + results = multiprocess_file_storage._load_credentials_file( + StringIO(json.dumps(data))) + self.assertNotIn('invalid', results) + self.assertEqual( + results['key'].access_token, credentials.access_token) + + def test__load_credentials_file_invalid_json(self): + contents = StringIO('{[') + self.assertEqual( + multiprocess_file_storage._load_credentials_file(contents), {}) + + def test__load_credentials_file_no_file_version(self): + contents = StringIO('{}') + self.assertEqual( + multiprocess_file_storage._load_credentials_file(contents), {}) + + def test__load_credentials_file_bad_file_version(self): + contents = StringIO(json.dumps({'file_version': 1})) + self.assertEqual( + multiprocess_file_storage._load_credentials_file(contents), {}) + + def test__load_credentials_no_open_file(self): + backend = multiprocess_file_storage._get_backend(self.filename) + backend._credentials = mock.Mock() + backend._credentials.update.side_effect = AssertionError() + backend._load_credentials() + + def test_acquire_lock_nonexistent_file(self): + backend = multiprocess_file_storage._get_backend(self.filename) + os.unlink(self.filename) + backend._process_lock = mock.Mock() + backend._process_lock.acquire.return_value = False + backend.acquire_lock() + self.assertIsNone(backend._file) + + def test_release_lock_with_no_file(self): + backend = multiprocess_file_storage._get_backend(self.filename) + backend._file = None + backend._read_only = True + backend._thread_lock.acquire() + backend.release_lock() + + def test__refresh_predicate(self): + backend = multiprocess_file_storage._get_backend(self.filename) + + credentials = _create_test_credentials() + self.assertFalse(backend._refresh_predicate(credentials)) + + credentials.invalid = True + self.assertTrue(backend._refresh_predicate(credentials)) + + credentials = _create_test_credentials( + expiration=( + datetime.datetime.utcnow() - datetime.timedelta(seconds=3600))) + self.assertTrue(backend._refresh_predicate(credentials)) + + +if __name__ == '__main__': # pragma: NO COVER + unittest2.main() diff --git a/tox.ini b/tox.ini index 1c9d92b..1237a3b 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ basedeps = mock>=1.3.0 flask unittest2 sqlalchemy + fasteners deps = {[testenv]basedeps} django keyring