diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py index e4169e9..f97ce69 100644 --- a/oauth2client/appengine.py +++ b/oauth2client/appengine.py @@ -263,6 +263,16 @@ class StorageByKeyName(Storage): if self._cache: self._cache.set(self._key_name, credentials.to_json()) + def locked_delete(self): + """Delete Credential from datastore.""" + + if self._cache: + self._cache.delete(self._key_name) + + entity = self._model.get_by_key_name(self._key_name) + if entity is not None: + entity.delete() + class CredentialsModel(db.Model): """Storage for OAuth 2.0 Credentials diff --git a/oauth2client/client.py b/oauth2client/client.py index 2d60f5a..f492563 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -224,6 +224,13 @@ class Storage(object): """ _abstract() + def locked_delete(self): + """Delete a credential. + + The Storage lock must be held when this is called. + """ + _abstract() + def get(self): """Retrieve credential. @@ -252,6 +259,21 @@ class Storage(object): finally: self.release_lock() + def delete(self): + """Delete credential. + + Frees any resources associated with storing the credential. + The Storage lock must *not* be held when this is called. + + Returns: + None + """ + self.acquire_lock() + try: + return self.locked_delete() + finally: + self.release_lock() + class OAuth2Credentials(Credentials): """Credentials object for OAuth 2.0. diff --git a/oauth2client/file.py b/oauth2client/file.py index d71e888..1abc6d2 100644 --- a/oauth2client/file.py +++ b/oauth2client/file.py @@ -85,7 +85,6 @@ class Storage(BaseStorage): finally: os.umask(old_umask) - def locked_put(self, credentials): """Write Credentials to file. @@ -97,3 +96,12 @@ class Storage(BaseStorage): f = open(self._filename, 'wb') f.write(credentials.to_json()) f.close() + + def locked_delete(self): + """Delete Credentials file. + + Args: + credentials: Credentials, the credentials to store. + """ + + os.unlink(self._filename) diff --git a/oauth2client/multistore_file.py b/oauth2client/multistore_file.py index d9b89c8..1f756c7 100644 --- a/oauth2client/multistore_file.py +++ b/oauth2client/multistore_file.py @@ -159,6 +159,17 @@ class _MultiStore(object): """ self._multistore._update_credential(credentials, self._scope) + def locked_delete(self): + """Delete a credential. + + The Storage lock must be held when this is called. + + Args: + credentials: Credentials, the credentials to store. + """ + self._multistore._delete_credential(self._client_id, self._user_agent, + self._scope) + def _create_file_if_needed(self): """Create an empty file if necessary. @@ -344,6 +355,23 @@ class _MultiStore(object): self._data[key] = cred self._write() + def _delete_credential(self, client_id, user_agent, scope): + """Delete a credential and write the multistore. + + This must be called when the multistore is locked. + + Args: + client_id: The client_id for the credential + user_agent: The user agent for the credential + scope: The scope(s) that this credential covers + """ + key = (client_id, user_agent, scope) + try: + del self._data[key] + except KeyError: + pass + self._write() + def _get_storage(self, client_id, user_agent, scope): """Get a Storage object to get/set a credential. diff --git a/tests/test_oauth2client_appengine.py b/tests/test_oauth2client_appengine.py index f2f470f..90ee335 100644 --- a/tests/test_oauth2client_appengine.py +++ b/tests/test_oauth2client_appengine.py @@ -203,6 +203,30 @@ class DecoratorTests(unittest.TestCase): q = parse_qs(response.headers['Location'].split('?', 1)[1]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) + def test_storage_delete(self): + # An initial request to an oauth_required decorated path should be a + # redirect to start the OAuth dance. + response = self.app.get('/foo_path') + self.assertTrue(response.status.startswith('302')) + + # Now simulate the callback to /oauth2callback. + response = self.app.get('/oauth2callback', { + 'code': 'foo_access_code', + 'state': 'foo_path', + }) + self.assertEqual('http://localhost/foo_path', response.headers['Location']) + self.assertEqual(None, self.decorator.credentials) + + # Now requesting the decorated path should work. + response = self.app.get('/foo_path') + + # Invalidate the stored Credentials. + self.decorator.credentials.store.delete() + + # Invalid Credentials should start the OAuth dance again. + response = self.app.get('/foo_path') + self.assertTrue(response.status.startswith('302')) + def test_aware(self): # An initial request to an oauth_aware decorated path should not redirect. response = self.app.get('/bar_path') diff --git a/tests/test_oauth2client_file.py b/tests/test_oauth2client_file.py index 7ff1eeb..596a4b6 100644 --- a/tests/test_oauth2client_file.py +++ b/tests/test_oauth2client_file.py @@ -98,7 +98,6 @@ class OAuth2ClientFileTests(unittest.TestCase): self.assertEquals(data['_module'], OAuth2Credentials.__module__) def test_token_refresh(self): - # Write a file with a pickled OAuth2Credentials. access_token = 'foo' client_id = 'some_client_id' client_secret = 'cOuDdkfjxxnv+' @@ -122,6 +121,27 @@ class OAuth2ClientFileTests(unittest.TestCase): credentials._refresh(lambda x: x) self.assertEquals(credentials.access_token, 'bar') + def test_credentials_delete(self): + access_token = 'foo' + client_id = 'some_client_id' + client_secret = 'cOuDdkfjxxnv+' + refresh_token = '1/0/a.df219fjls0' + token_expiry = datetime.datetime.utcnow() + token_uri = 'https://www.google.com/accounts/o8/oauth2/token' + user_agent = 'refresh_checker/1.0' + + credentials = OAuth2Credentials( + access_token, client_id, client_secret, + refresh_token, token_expiry, token_uri, + user_agent) + + s = Storage(FILENAME) + s.put(credentials) + credentials = s.get() + self.assertNotEquals(None, credentials) + s.delete() + credentials = s.get() + self.assertEquals(None, credentials) def test_access_token_credentials(self): access_token = 'foo' @@ -205,6 +225,11 @@ class OAuth2ClientFileTests(unittest.TestCase): self.assertNotEquals(None, credentials) self.assertEquals('foo', credentials.access_token) + store.delete() + credentials = store.get() + + self.assertEquals(None, credentials) + if os.name == 'posix': self.assertEquals('0600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))