Expose the full token response in OAuth2Client and OAuth2Decorator.

Reviewed in https://codereview.appspot.com/7301099/.
This commit is contained in:
Joe Gregorio
2013-02-22 16:22:48 -05:00
parent e7bbbb94d0
commit cda875206a
4 changed files with 51 additions and 6 deletions

View File

@@ -26,6 +26,8 @@ import logging
import os import os
import pickle import pickle
import time import time
import urllib
import urlparse
from google.appengine.api import app_identity from google.appengine.api import app_identity
from google.appengine.api import memcache from google.appengine.api import memcache
@@ -34,6 +36,7 @@ from google.appengine.ext import db
from google.appengine.ext import webapp from google.appengine.ext import webapp
from google.appengine.ext.webapp.util import login_required from google.appengine.ext.webapp.util import login_required
from google.appengine.ext.webapp.util import run_wsgi_app from google.appengine.ext.webapp.util import run_wsgi_app
from apiclient import discovery
from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_AUTH_URI
from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_URI
@@ -55,6 +58,11 @@ try:
except ImportError: except ImportError:
ndb = None ndb = None
try:
from urlparse import parse_qsl
except ImportError:
from cgi import parse_qsl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OAUTH2CLIENT_NAMESPACE = 'oauth2client#ns' OAUTH2CLIENT_NAMESPACE = 'oauth2client#ns'
@@ -570,6 +578,7 @@ class OAuth2Decorator(object):
user_agent=None, user_agent=None,
message=None, message=None,
callback_path='/oauth2callback', callback_path='/oauth2callback',
token_response_param=None,
**kwargs): **kwargs):
"""Constructor for OAuth2Decorator """Constructor for OAuth2Decorator
@@ -592,6 +601,10 @@ class OAuth2Decorator(object):
callback_path: string, The absolute path to use as the callback URI. Note callback_path: string, The absolute path to use as the callback URI. Note
that this must match up with the URI given when registering the that this must match up with the URI given when registering the
application in the APIs Console. application in the APIs Console.
token_response_param: string. If provided, the full JSON response
to the access token request will be encoded and included in this query
parameter in the callback URI. This is useful with providers (e.g.
wordpress.com) that include extra fields that the client may want.
**kwargs: dict, Keyword arguments are be passed along as kwargs to the **kwargs: dict, Keyword arguments are be passed along as kwargs to the
OAuth2WebServerFlow constructor. OAuth2WebServerFlow constructor.
""" """
@@ -608,6 +621,7 @@ class OAuth2Decorator(object):
self._message = message self._message = message
self._in_error = False self._in_error = False
self._callback_path = callback_path self._callback_path = callback_path
self._token_response_param = token_response_param
def _display_error_message(self, request_handler): def _display_error_message(self, request_handler):
request_handler.response.out.write('<html><body>') request_handler.response.out.write('<html><body>')
@@ -782,6 +796,12 @@ class OAuth2Decorator(object):
CredentialsModel, user.user_id(), 'credentials').put(credentials) CredentialsModel, user.user_id(), 'credentials').put(credentials)
redirect_uri = _parse_state_value(str(self.request.get('state')), redirect_uri = _parse_state_value(str(self.request.get('state')),
user) user)
if decorator._token_response_param and credentials.token_response:
resp_json = simplejson.dumps(credentials.token_response)
redirect_uri = discovery._add_query_parameter(
redirect_uri, decorator._token_response_param, resp_json)
self.redirect(redirect_uri) self.redirect(redirect_uri)
return OAuth2Handler return OAuth2Handler

View File

@@ -393,7 +393,7 @@ class OAuth2Credentials(Credentials):
@util.positional(8) @util.positional(8)
def __init__(self, access_token, client_id, client_secret, refresh_token, def __init__(self, access_token, client_id, client_secret, refresh_token,
token_expiry, token_uri, user_agent, revoke_uri=None, token_expiry, token_uri, user_agent, revoke_uri=None,
id_token=None): id_token=None, token_response=None):
"""Create an instance of OAuth2Credentials. """Create an instance of OAuth2Credentials.
This constructor is not usually called by the user, instead This constructor is not usually called by the user, instead
@@ -410,6 +410,9 @@ class OAuth2Credentials(Credentials):
revoke_uri: string, URI for revoke endpoint. Defaults to None; a token revoke_uri: string, URI for revoke endpoint. Defaults to None; a token
can't be revoked if this is None. can't be revoked if this is None.
id_token: object, The identity of the resource owner. id_token: object, The identity of the resource owner.
token_response: dict, the decoded response to the token request. None
if a token hasn't been requested yet. Stored because some providers
(e.g. wordpress.com) include extra fields that clients may want.
Notes: Notes:
store: callable, A callable that when passed a Credential store: callable, A callable that when passed a Credential
@@ -427,6 +430,7 @@ class OAuth2Credentials(Credentials):
self.user_agent = user_agent self.user_agent = user_agent
self.revoke_uri = revoke_uri self.revoke_uri = revoke_uri
self.id_token = id_token self.id_token = id_token
self.token_response = token_response
# True if the credentials have been revoked or expired and can't be # True if the credentials have been revoked or expired and can't be
# refreshed. # refreshed.
@@ -559,7 +563,8 @@ class OAuth2Credentials(Credentials):
data['token_uri'], data['token_uri'],
data['user_agent'], data['user_agent'],
revoke_uri=data.get('revoke_uri', None), revoke_uri=data.get('revoke_uri', None),
id_token=data.get('id_token', None)) id_token=data.get('id_token', None),
token_response=data.get('token_response', None))
retval.invalid = data['invalid'] retval.invalid = data['invalid']
return retval return retval
@@ -678,6 +683,7 @@ class OAuth2Credentials(Credentials):
if resp.status == 200: if resp.status == 200:
# TODO(jcgregorio) Raise an error if loads fails? # TODO(jcgregorio) Raise an error if loads fails?
d = simplejson.loads(content) d = simplejson.loads(content)
self.token_response = d
self.access_token = d['access_token'] self.access_token = d['access_token']
self.refresh_token = d.get('refresh_token', self.refresh_token) self.refresh_token = d.get('refresh_token', self.refresh_token)
if 'expires_in' in d: if 'expires_in' in d:
@@ -1292,7 +1298,8 @@ class OAuth2WebServerFlow(Flow):
self.client_secret, refresh_token, token_expiry, self.client_secret, refresh_token, token_expiry,
self.token_uri, self.user_agent, self.token_uri, self.user_agent,
revoke_uri=self.revoke_uri, revoke_uri=self.revoke_uri,
id_token=d.get('id_token', None)) id_token=d.get('id_token', None),
token_response=d)
else: else:
logger.info('Failed to retrieve access token: %s' % content) logger.info('Failed to retrieve access token: %s' % content)
if 'error' in d: if 'error' in d:

View File

@@ -142,15 +142,17 @@ class BasicCredentialsTests(unittest.TestCase):
def test_token_refresh_success(self): def test_token_refresh_success(self):
for status_code in REFRESH_STATUS_CODES: for status_code in REFRESH_STATUS_CODES:
token_response = {'access_token': '1/3w', 'expires_in': 3600}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': status_code}, ''), ({'status': status_code}, ''),
({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, simplejson.dumps(token_response)),
({'status': '200'}, 'echo_request_headers'), ({'status': '200'}, 'echo_request_headers'),
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = http.request('http://example.com')
self.assertEqual('Bearer 1/3w', content['Authorization']) self.assertEqual('Bearer 1/3w', content['Authorization'])
self.assertFalse(self.credentials.access_token_expired) self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response, self.credentials.token_response)
def test_token_refresh_failure(self): def test_token_refresh_failure(self):
for status_code in REFRESH_STATUS_CODES: for status_code in REFRESH_STATUS_CODES:
@@ -165,6 +167,7 @@ class BasicCredentialsTests(unittest.TestCase):
except AccessTokenRefreshError: except AccessTokenRefreshError:
pass pass
self.assertTrue(self.credentials.access_token_expired) self.assertTrue(self.credentials.access_token_expired)
self.assertEqual(None, self.credentials.token_response)
def test_token_revoke_success(self): def test_token_revoke_success(self):
_token_revoke_test_helper( _token_revoke_test_helper(
@@ -183,6 +186,7 @@ class BasicCredentialsTests(unittest.TestCase):
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = http.request('http://example.com')
self.assertEqual(400, resp.status) self.assertEqual(400, resp.status)
self.assertEqual(None, self.credentials.token_response)
def test_to_from_json(self): def test_to_from_json(self):
json = self.credentials.to_json() json = self.credentials.to_json()
@@ -221,6 +225,10 @@ class BasicCredentialsTests(unittest.TestCase):
except NonAsciiHeaderError: except NonAsciiHeaderError:
pass pass
self.credentials.token_response = 'foobar'
instance = OAuth2Credentials.from_json(self.credentials.to_json())
self.assertEqual('foobar', instance.token_response)
class AccessTokenCredentialsTests(unittest.TestCase): class AccessTokenCredentialsTests(unittest.TestCase):

View File

@@ -29,7 +29,7 @@ import mox
import os import os
import time import time
import unittest import unittest
import urlparse import urllib
try: try:
from urlparse import parse_qs from urlparse import parse_qs
@@ -121,6 +121,7 @@ class Http2Mock(object):
'access_token': 'foo_access_token', 'access_token': 'foo_access_token',
'refresh_token': 'foo_refresh_token', 'refresh_token': 'foo_refresh_token',
'expires_in': 3600, 'expires_in': 3600,
'extra': 'value',
} }
def request(self, token_uri, method, body, headers, *args, **kwargs): def request(self, token_uri, method, body, headers, *args, **kwargs):
@@ -499,8 +500,13 @@ class DecoratorTests(unittest.TestCase):
'code': 'foo_access_code', 'code': 'foo_access_code',
'state': 'foo_path:xsrfkey123', 'state': 'foo_path:xsrfkey123',
}) })
self.assertEqual('http://localhost/foo_path', response.headers['Location']) parts = response.headers['Location'].split('?', 1)
self.assertEqual('http://localhost/foo_path', parts[0])
self.assertEqual(None, self.decorator.credentials) self.assertEqual(None, self.decorator.credentials)
if self.decorator._token_response_param:
response = parse_qs(parts[1])[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content,
simplejson.loads(urllib.unquote(response)))
m.UnsetStubs() m.UnsetStubs()
m.VerifyAll() m.VerifyAll()
@@ -629,6 +635,10 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri) self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
self.assertEqual(None, decorator.flow.params.get('user_agent', None)) self.assertEqual(None, decorator.flow.params.get('user_agent', None))
def test_token_response_param(self):
self.decorator._token_response_param = 'foobar'
self.test_required()
def test_decorator_from_client_secrets(self): def test_decorator_from_client_secrets(self):
decorator = oauth2decorator_from_clientsecrets( decorator = oauth2decorator_from_clientsecrets(
datafile('client_secrets.json'), datafile('client_secrets.json'),