Merge pull request #15 from orestica/get_access_token

Modify the OAuth2Credentials.get_access_token() method.
This commit is contained in:
Craig Citro
2014-06-03 21:11:40 -07:00
5 changed files with 59 additions and 27 deletions

View File

@@ -31,6 +31,7 @@ import time
import urllib import urllib
import urlparse import urlparse
from collections import namedtuple
from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_AUTH_URI
from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_URI
@@ -75,6 +76,9 @@ SERVICE_ACCOUNT = 'service_account'
# The environment variable pointing the file with local Default Credentials. # The environment variable pointing the file with local Default Credentials.
GOOGLE_CREDENTIALS_DEFAULT = 'GOOGLE_CREDENTIALS_DEFAULT' GOOGLE_CREDENTIALS_DEFAULT = 'GOOGLE_CREDENTIALS_DEFAULT'
# The access token along with the seconds in which it expires.
AccessTokenInfo = namedtuple('AccessTokenInfo', ['access_token', 'expires_in'])
class Error(Exception): class Error(Exception):
"""Base error for this module.""" """Base error for this module."""
@@ -600,18 +604,17 @@ class OAuth2Credentials(Credentials):
return False return False
def get_access_token(self, http=None): def get_access_token(self, http=None):
"""Return the access token. """Return the access token and its expiration information.
If the token does not exist, get one. If the token does not exist, get one.
If the token expired, refresh it. If the token expired, refresh it.
""" """
if self.access_token and not self.access_token_expired: if not self.access_token or self.access_token_expired:
return self.access_token
else:
if not http: if not http:
http = httplib2.Http() http = httplib2.Http()
self.refresh(http) self.refresh(http)
return self.access_token return AccessTokenInfo(access_token=self.access_token,
expires_in=self._expires_in())
def set_store(self, store): def set_store(self, store):
"""Set the Storage for the credential. """Set the Storage for the credential.
@@ -625,6 +628,25 @@ class OAuth2Credentials(Credentials):
""" """
self.store = store self.store = store
def _expires_in(self):
"""Return the number of seconds until this token expires.
If token_expiry is in the past, this method will return 0, meaning the
token has already expired.
If token_expiry is None, this method will return None. Note that returning
0 in such a case would not be fair: the token may still be valid;
we just don't know anything about it.
"""
if self.token_expiry:
now = datetime.datetime.utcnow()
if self.token_expiry > now:
time_delta = self.token_expiry - now
# TODO(orestica): return time_delta.total_seconds()
# once dropping support for Python 2.6
return time_delta.days * 86400 + time_delta.seconds
else:
return 0
def _updateFromCredential(self, other): def _updateFromCredential(self, other):
"""Update this Credential from another instance.""" """Update this Credential from another instance."""
self.__dict__.update(other.__getstate__()) self.__dict__.update(other.__getstate__())

View File

@@ -245,7 +245,8 @@ class TestAppAssertionCredentials(unittest.TestCase):
credentials = AppAssertionCredentials(['dummy_scope']) credentials = AppAssertionCredentials(['dummy_scope'])
token = credentials.get_access_token() token = credentials.get_access_token()
self.assertEqual('a_token_123', token) self.assertEqual('a_token_123', token.access_token)
self.assertEqual(None, token.expires_in)
class TestFlowModel(db.Model): class TestFlowModel(db.Model):

View File

@@ -125,8 +125,9 @@ class AssertionCredentialsTests(unittest.TestCase):
http = httplib2.Http() http = httplib2.Http()
http.request = httplib2_request http.request = httplib2_request
self.assertEquals('this-is-a-token', token = credentials.get_access_token(http=http)
credentials.get_access_token(http=http)) self.assertEqual('this-is-a-token', token.access_token)
self.assertEqual(None, token.expires_in)
m.UnsetStubs() m.UnsetStubs()
m.VerifyAll() m.VerifyAll()

View File

@@ -583,28 +583,32 @@ class BasicCredentialsTests(unittest.TestCase):
self.assertEqual('foobar', instance.token_response) self.assertEqual('foobar', instance.token_response)
def test_get_access_token(self): def test_get_access_token(self):
token_response_first = {'access_token': 'first_token', 'expires_in': 1} S = 2 # number of seconds in which the token expires
token_response_second = {'access_token': 'second_token', 'expires_in': 1} token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token', 'expires_in': S}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, simplejson.dumps(token_response_first)), ({'status': '200'}, simplejson.dumps(token_response_first)),
({'status': '200'}, simplejson.dumps(token_response_second)), ({'status': '200'}, simplejson.dumps(token_response_second)),
]) ])
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
time.sleep(1) time.sleep(S)
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
self.assertEqual('second_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response) self.assertEqual(token_response_second, self.credentials.token_response)

View File

@@ -94,27 +94,31 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
self.assertEqual('dummy_scope', new_credentials._scopes) self.assertEqual('dummy_scope', new_credentials._scopes)
def test_access_token(self): def test_access_token(self):
token_response_first = {'access_token': 'first_token', 'expires_in': 1} S = 2 # number of seconds in which the token expires
token_response_second = {'access_token': 'second_token', 'expires_in': 1} token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token', 'expires_in': S}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '200'}, simplejson.dumps(token_response_first)), ({'status': '200'}, simplejson.dumps(token_response_first)),
({'status': '200'}, simplejson.dumps(token_response_second)), ({'status': '200'}, simplejson.dumps(token_response_second)),
]) ])
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
self.assertEqual('first_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response) self.assertEqual(token_response_first, self.credentials.token_response)
time.sleep(1) time.sleep(S)
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
self.assertEqual('second_token', token = self.credentials.get_access_token(http=http)
self.credentials.get_access_token(http=http)) self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response) self.assertEqual(token_response_second, self.credentials.token_response)