Add .delete() to all Storages.

Reviewed in http://codereview.appspot.com/5608049/.
This commit is contained in:
Joe Gregorio
2012-02-06 13:40:42 -05:00
parent 9f2f38f009
commit ec75dc109a
6 changed files with 119 additions and 2 deletions

View File

@@ -263,6 +263,16 @@ class StorageByKeyName(Storage):
if self._cache:
self._cache.set(self._key_name, credentials.to_json())
def locked_delete(self):
"""Delete Credential from datastore."""
if self._cache:
self._cache.delete(self._key_name)
entity = self._model.get_by_key_name(self._key_name)
if entity is not None:
entity.delete()
class CredentialsModel(db.Model):
"""Storage for OAuth 2.0 Credentials

View File

@@ -224,6 +224,13 @@ class Storage(object):
"""
_abstract()
def locked_delete(self):
"""Delete a credential.
The Storage lock must be held when this is called.
"""
_abstract()
def get(self):
"""Retrieve credential.
@@ -252,6 +259,21 @@ class Storage(object):
finally:
self.release_lock()
def delete(self):
"""Delete credential.
Frees any resources associated with storing the credential.
The Storage lock must *not* be held when this is called.
Returns:
None
"""
self.acquire_lock()
try:
return self.locked_delete()
finally:
self.release_lock()
class OAuth2Credentials(Credentials):
"""Credentials object for OAuth 2.0.

View File

@@ -85,7 +85,6 @@ class Storage(BaseStorage):
finally:
os.umask(old_umask)
def locked_put(self, credentials):
"""Write Credentials to file.
@@ -97,3 +96,12 @@ class Storage(BaseStorage):
f = open(self._filename, 'wb')
f.write(credentials.to_json())
f.close()
def locked_delete(self):
"""Delete Credentials file.
Args:
credentials: Credentials, the credentials to store.
"""
os.unlink(self._filename)

View File

@@ -159,6 +159,17 @@ class _MultiStore(object):
"""
self._multistore._update_credential(credentials, self._scope)
def locked_delete(self):
"""Delete a credential.
The Storage lock must be held when this is called.
Args:
credentials: Credentials, the credentials to store.
"""
self._multistore._delete_credential(self._client_id, self._user_agent,
self._scope)
def _create_file_if_needed(self):
"""Create an empty file if necessary.
@@ -344,6 +355,23 @@ class _MultiStore(object):
self._data[key] = cred
self._write()
def _delete_credential(self, client_id, user_agent, scope):
"""Delete a credential and write the multistore.
This must be called when the multistore is locked.
Args:
client_id: The client_id for the credential
user_agent: The user agent for the credential
scope: The scope(s) that this credential covers
"""
key = (client_id, user_agent, scope)
try:
del self._data[key]
except KeyError:
pass
self._write()
def _get_storage(self, client_id, user_agent, scope):
"""Get a Storage object to get/set a credential.

View File

@@ -203,6 +203,30 @@ class DecoratorTests(unittest.TestCase):
q = parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
def test_storage_delete(self):
# An initial request to an oauth_required decorated path should be a
# redirect to start the OAuth dance.
response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302'))
# Now simulate the callback to /oauth2callback.
response = self.app.get('/oauth2callback', {
'code': 'foo_access_code',
'state': 'foo_path',
})
self.assertEqual('http://localhost/foo_path', response.headers['Location'])
self.assertEqual(None, self.decorator.credentials)
# Now requesting the decorated path should work.
response = self.app.get('/foo_path')
# Invalidate the stored Credentials.
self.decorator.credentials.store.delete()
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302'))
def test_aware(self):
# An initial request to an oauth_aware decorated path should not redirect.
response = self.app.get('/bar_path')

View File

@@ -98,7 +98,6 @@ class OAuth2ClientFileTests(unittest.TestCase):
self.assertEquals(data['_module'], OAuth2Credentials.__module__)
def test_token_refresh(self):
# Write a file with a pickled OAuth2Credentials.
access_token = 'foo'
client_id = 'some_client_id'
client_secret = 'cOuDdkfjxxnv+'
@@ -122,6 +121,27 @@ class OAuth2ClientFileTests(unittest.TestCase):
credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar')
def test_credentials_delete(self):
access_token = 'foo'
client_id = 'some_client_id'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0'
credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent)
s = Storage(FILENAME)
s.put(credentials)
credentials = s.get()
self.assertNotEquals(None, credentials)
s.delete()
credentials = s.get()
self.assertEquals(None, credentials)
def test_access_token_credentials(self):
access_token = 'foo'
@@ -205,6 +225,11 @@ class OAuth2ClientFileTests(unittest.TestCase):
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
store.delete()
credentials = store.get()
self.assertEquals(None, credentials)
if os.name == 'posix':
self.assertEquals('0600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))