Modify the OAuth2Credentials.get_access_token() method to return the expiration information, in addition to the access token.

This commit is contained in:
Orest Bolohan
2014-06-02 15:04:34 -07:00
parent a3e52c0eee
commit 50bcae88b7
5 changed files with 37 additions and 21 deletions

View File

@@ -600,18 +600,16 @@ class OAuth2Credentials(Credentials):
return False return False
def get_access_token(self, http=None): def get_access_token(self, http=None):
"""Return the access token. """Return the access token and its expiration information.
If the token does not exist, get one. If the token does not exist, get one.
If the token expired, refresh it. If the token expired, refresh it.
""" """
if self.access_token and not self.access_token_expired: if not self.access_token or self.access_token_expired:
return self.access_token
else:
if not http: if not http:
http = httplib2.Http() http = httplib2.Http()
self.refresh(http) self.refresh(http)
return self.access_token return {'access_token': self.access_token, 'expires_in': self._expires_in()}
def set_store(self, store): def set_store(self, store):
"""Set the Storage for the credential. """Set the Storage for the credential.
@@ -625,6 +623,16 @@ class OAuth2Credentials(Credentials):
""" """
self.store = store self.store = store
def _expires_in(self):
"""Get in how many seconds does the token expire."""
if self.token_expiry:
now = datetime.datetime.utcnow()
if self.token_expiry > now:
time_delta = self.token_expiry - now
return int(round(time_delta.days * 86400.0 +
time_delta.seconds +
time_delta.microseconds * 0.000001))
def _updateFromCredential(self, other): def _updateFromCredential(self, other):
"""Update this Credential from another instance.""" """Update this Credential from another instance."""
self.__dict__.update(other.__getstate__()) self.__dict__.update(other.__getstate__())

View File

@@ -245,7 +245,8 @@ class TestAppAssertionCredentials(unittest.TestCase):
credentials = AppAssertionCredentials(['dummy_scope']) credentials = AppAssertionCredentials(['dummy_scope'])
token = credentials.get_access_token() token = credentials.get_access_token()
self.assertEqual('a_token_123', token) self.assertEqual('a_token_123', token['access_token'])
self.assertEqual(None, token['expires_in'])
class TestFlowModel(db.Model): class TestFlowModel(db.Model):

View File

@@ -125,8 +125,9 @@ class AssertionCredentialsTests(unittest.TestCase):
http = httplib2.Http() http = httplib2.Http()
http.request = httplib2_request http.request = httplib2_request
self.assertEquals('this-is-a-token', token = credentials.get_access_token(http=http)
credentials.get_access_token(http=http)) self.assertEqual('this-is-a-token', token['access_token'])
self.assertEqual(None, token['expires_in'])
m.UnsetStubs() m.UnsetStubs()
m.VerifyAll() m.VerifyAll()

View File

@@ -590,21 +590,24 @@ class BasicCredentialsTests(unittest.TestCase):
({'status': '200'}, simplejson.dumps(token_response_second)), ({'status': '200'}, simplejson.dumps(token_response_second)),
]) ])
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token['access_token'])
self.assertEqual(1, token['expires_in'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token['access_token'])
self.assertEqual(1, token['expires_in'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
time.sleep(1) time.sleep(1)
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
self.assertEqual('second_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('second_token', token['access_token'])
self.assertEqual(1, token['expires_in'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response) self.assertEqual(token_response_second, self.credentials.token_response)

View File

@@ -101,20 +101,23 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
({'status': '200'}, simplejson.dumps(token_response_second)), ({'status': '200'}, simplejson.dumps(token_response_second)),
]) ])
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token['access_token'])
self.assertEqual(1, token['expires_in'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token['access_token'])
self.assertEqual(1, token['expires_in'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
time.sleep(1) time.sleep(1)
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
self.assertEqual('second_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('second_token', token['access_token'])
self.assertEqual(1, token['expires_in'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response) self.assertEqual(token_response_second, self.credentials.token_response)