Expose the full token response in OAuth2Client and OAuth2Decorator.
Reviewed in https://codereview.appspot.com/7301099/.
This commit is contained in:
@@ -26,6 +26,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
import urllib
|
||||||
|
import urlparse
|
||||||
|
|
||||||
from google.appengine.api import app_identity
|
from google.appengine.api import app_identity
|
||||||
from google.appengine.api import memcache
|
from google.appengine.api import memcache
|
||||||
@@ -34,6 +36,7 @@ from google.appengine.ext import db
|
|||||||
from google.appengine.ext import webapp
|
from google.appengine.ext import webapp
|
||||||
from google.appengine.ext.webapp.util import login_required
|
from google.appengine.ext.webapp.util import login_required
|
||||||
from google.appengine.ext.webapp.util import run_wsgi_app
|
from google.appengine.ext.webapp.util import run_wsgi_app
|
||||||
|
from apiclient import discovery
|
||||||
from oauth2client import GOOGLE_AUTH_URI
|
from oauth2client import GOOGLE_AUTH_URI
|
||||||
from oauth2client import GOOGLE_REVOKE_URI
|
from oauth2client import GOOGLE_REVOKE_URI
|
||||||
from oauth2client import GOOGLE_TOKEN_URI
|
from oauth2client import GOOGLE_TOKEN_URI
|
||||||
@@ -55,6 +58,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
ndb = None
|
ndb = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urlparse import parse_qsl
|
||||||
|
except ImportError:
|
||||||
|
from cgi import parse_qsl
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
OAUTH2CLIENT_NAMESPACE = 'oauth2client#ns'
|
OAUTH2CLIENT_NAMESPACE = 'oauth2client#ns'
|
||||||
@@ -570,6 +578,7 @@ class OAuth2Decorator(object):
|
|||||||
user_agent=None,
|
user_agent=None,
|
||||||
message=None,
|
message=None,
|
||||||
callback_path='/oauth2callback',
|
callback_path='/oauth2callback',
|
||||||
|
token_response_param=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
"""Constructor for OAuth2Decorator
|
"""Constructor for OAuth2Decorator
|
||||||
@@ -592,6 +601,10 @@ class OAuth2Decorator(object):
|
|||||||
callback_path: string, The absolute path to use as the callback URI. Note
|
callback_path: string, The absolute path to use as the callback URI. Note
|
||||||
that this must match up with the URI given when registering the
|
that this must match up with the URI given when registering the
|
||||||
application in the APIs Console.
|
application in the APIs Console.
|
||||||
|
token_response_param: string. If provided, the full JSON response
|
||||||
|
to the access token request will be encoded and included in this query
|
||||||
|
parameter in the callback URI. This is useful with providers (e.g.
|
||||||
|
wordpress.com) that include extra fields that the client may want.
|
||||||
**kwargs: dict, Keyword arguments are be passed along as kwargs to the
|
**kwargs: dict, Keyword arguments are be passed along as kwargs to the
|
||||||
OAuth2WebServerFlow constructor.
|
OAuth2WebServerFlow constructor.
|
||||||
"""
|
"""
|
||||||
@@ -608,6 +621,7 @@ class OAuth2Decorator(object):
|
|||||||
self._message = message
|
self._message = message
|
||||||
self._in_error = False
|
self._in_error = False
|
||||||
self._callback_path = callback_path
|
self._callback_path = callback_path
|
||||||
|
self._token_response_param = token_response_param
|
||||||
|
|
||||||
def _display_error_message(self, request_handler):
|
def _display_error_message(self, request_handler):
|
||||||
request_handler.response.out.write('<html><body>')
|
request_handler.response.out.write('<html><body>')
|
||||||
@@ -782,6 +796,12 @@ class OAuth2Decorator(object):
|
|||||||
CredentialsModel, user.user_id(), 'credentials').put(credentials)
|
CredentialsModel, user.user_id(), 'credentials').put(credentials)
|
||||||
redirect_uri = _parse_state_value(str(self.request.get('state')),
|
redirect_uri = _parse_state_value(str(self.request.get('state')),
|
||||||
user)
|
user)
|
||||||
|
|
||||||
|
if decorator._token_response_param and credentials.token_response:
|
||||||
|
resp_json = simplejson.dumps(credentials.token_response)
|
||||||
|
redirect_uri = discovery._add_query_parameter(
|
||||||
|
redirect_uri, decorator._token_response_param, resp_json)
|
||||||
|
|
||||||
self.redirect(redirect_uri)
|
self.redirect(redirect_uri)
|
||||||
|
|
||||||
return OAuth2Handler
|
return OAuth2Handler
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ class OAuth2Credentials(Credentials):
|
|||||||
@util.positional(8)
|
@util.positional(8)
|
||||||
def __init__(self, access_token, client_id, client_secret, refresh_token,
|
def __init__(self, access_token, client_id, client_secret, refresh_token,
|
||||||
token_expiry, token_uri, user_agent, revoke_uri=None,
|
token_expiry, token_uri, user_agent, revoke_uri=None,
|
||||||
id_token=None):
|
id_token=None, token_response=None):
|
||||||
"""Create an instance of OAuth2Credentials.
|
"""Create an instance of OAuth2Credentials.
|
||||||
|
|
||||||
This constructor is not usually called by the user, instead
|
This constructor is not usually called by the user, instead
|
||||||
@@ -410,6 +410,9 @@ class OAuth2Credentials(Credentials):
|
|||||||
revoke_uri: string, URI for revoke endpoint. Defaults to None; a token
|
revoke_uri: string, URI for revoke endpoint. Defaults to None; a token
|
||||||
can't be revoked if this is None.
|
can't be revoked if this is None.
|
||||||
id_token: object, The identity of the resource owner.
|
id_token: object, The identity of the resource owner.
|
||||||
|
token_response: dict, the decoded response to the token request. None
|
||||||
|
if a token hasn't been requested yet. Stored because some providers
|
||||||
|
(e.g. wordpress.com) include extra fields that clients may want.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
store: callable, A callable that when passed a Credential
|
store: callable, A callable that when passed a Credential
|
||||||
@@ -427,6 +430,7 @@ class OAuth2Credentials(Credentials):
|
|||||||
self.user_agent = user_agent
|
self.user_agent = user_agent
|
||||||
self.revoke_uri = revoke_uri
|
self.revoke_uri = revoke_uri
|
||||||
self.id_token = id_token
|
self.id_token = id_token
|
||||||
|
self.token_response = token_response
|
||||||
|
|
||||||
# True if the credentials have been revoked or expired and can't be
|
# True if the credentials have been revoked or expired and can't be
|
||||||
# refreshed.
|
# refreshed.
|
||||||
@@ -559,7 +563,8 @@ class OAuth2Credentials(Credentials):
|
|||||||
data['token_uri'],
|
data['token_uri'],
|
||||||
data['user_agent'],
|
data['user_agent'],
|
||||||
revoke_uri=data.get('revoke_uri', None),
|
revoke_uri=data.get('revoke_uri', None),
|
||||||
id_token=data.get('id_token', None))
|
id_token=data.get('id_token', None),
|
||||||
|
token_response=data.get('token_response', None))
|
||||||
retval.invalid = data['invalid']
|
retval.invalid = data['invalid']
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
@@ -678,6 +683,7 @@ class OAuth2Credentials(Credentials):
|
|||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
# TODO(jcgregorio) Raise an error if loads fails?
|
# TODO(jcgregorio) Raise an error if loads fails?
|
||||||
d = simplejson.loads(content)
|
d = simplejson.loads(content)
|
||||||
|
self.token_response = d
|
||||||
self.access_token = d['access_token']
|
self.access_token = d['access_token']
|
||||||
self.refresh_token = d.get('refresh_token', self.refresh_token)
|
self.refresh_token = d.get('refresh_token', self.refresh_token)
|
||||||
if 'expires_in' in d:
|
if 'expires_in' in d:
|
||||||
@@ -1292,7 +1298,8 @@ class OAuth2WebServerFlow(Flow):
|
|||||||
self.client_secret, refresh_token, token_expiry,
|
self.client_secret, refresh_token, token_expiry,
|
||||||
self.token_uri, self.user_agent,
|
self.token_uri, self.user_agent,
|
||||||
revoke_uri=self.revoke_uri,
|
revoke_uri=self.revoke_uri,
|
||||||
id_token=d.get('id_token', None))
|
id_token=d.get('id_token', None),
|
||||||
|
token_response=d)
|
||||||
else:
|
else:
|
||||||
logger.info('Failed to retrieve access token: %s' % content)
|
logger.info('Failed to retrieve access token: %s' % content)
|
||||||
if 'error' in d:
|
if 'error' in d:
|
||||||
|
|||||||
@@ -142,15 +142,17 @@ class BasicCredentialsTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_token_refresh_success(self):
|
def test_token_refresh_success(self):
|
||||||
for status_code in REFRESH_STATUS_CODES:
|
for status_code in REFRESH_STATUS_CODES:
|
||||||
|
token_response = {'access_token': '1/3w', 'expires_in': 3600}
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
({'status': status_code}, ''),
|
({'status': status_code}, ''),
|
||||||
({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
|
({'status': '200'}, simplejson.dumps(token_response)),
|
||||||
({'status': '200'}, 'echo_request_headers'),
|
({'status': '200'}, 'echo_request_headers'),
|
||||||
])
|
])
|
||||||
http = self.credentials.authorize(http)
|
http = self.credentials.authorize(http)
|
||||||
resp, content = http.request('http://example.com')
|
resp, content = http.request('http://example.com')
|
||||||
self.assertEqual('Bearer 1/3w', content['Authorization'])
|
self.assertEqual('Bearer 1/3w', content['Authorization'])
|
||||||
self.assertFalse(self.credentials.access_token_expired)
|
self.assertFalse(self.credentials.access_token_expired)
|
||||||
|
self.assertEqual(token_response, self.credentials.token_response)
|
||||||
|
|
||||||
def test_token_refresh_failure(self):
|
def test_token_refresh_failure(self):
|
||||||
for status_code in REFRESH_STATUS_CODES:
|
for status_code in REFRESH_STATUS_CODES:
|
||||||
@@ -165,6 +167,7 @@ class BasicCredentialsTests(unittest.TestCase):
|
|||||||
except AccessTokenRefreshError:
|
except AccessTokenRefreshError:
|
||||||
pass
|
pass
|
||||||
self.assertTrue(self.credentials.access_token_expired)
|
self.assertTrue(self.credentials.access_token_expired)
|
||||||
|
self.assertEqual(None, self.credentials.token_response)
|
||||||
|
|
||||||
def test_token_revoke_success(self):
|
def test_token_revoke_success(self):
|
||||||
_token_revoke_test_helper(
|
_token_revoke_test_helper(
|
||||||
@@ -183,6 +186,7 @@ class BasicCredentialsTests(unittest.TestCase):
|
|||||||
http = self.credentials.authorize(http)
|
http = self.credentials.authorize(http)
|
||||||
resp, content = http.request('http://example.com')
|
resp, content = http.request('http://example.com')
|
||||||
self.assertEqual(400, resp.status)
|
self.assertEqual(400, resp.status)
|
||||||
|
self.assertEqual(None, self.credentials.token_response)
|
||||||
|
|
||||||
def test_to_from_json(self):
|
def test_to_from_json(self):
|
||||||
json = self.credentials.to_json()
|
json = self.credentials.to_json()
|
||||||
@@ -221,6 +225,10 @@ class BasicCredentialsTests(unittest.TestCase):
|
|||||||
except NonAsciiHeaderError:
|
except NonAsciiHeaderError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
self.credentials.token_response = 'foobar'
|
||||||
|
instance = OAuth2Credentials.from_json(self.credentials.to_json())
|
||||||
|
self.assertEqual('foobar', instance.token_response)
|
||||||
|
|
||||||
|
|
||||||
class AccessTokenCredentialsTests(unittest.TestCase):
|
class AccessTokenCredentialsTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import mox
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
import urlparse
|
import urllib
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from urlparse import parse_qs
|
from urlparse import parse_qs
|
||||||
@@ -121,6 +121,7 @@ class Http2Mock(object):
|
|||||||
'access_token': 'foo_access_token',
|
'access_token': 'foo_access_token',
|
||||||
'refresh_token': 'foo_refresh_token',
|
'refresh_token': 'foo_refresh_token',
|
||||||
'expires_in': 3600,
|
'expires_in': 3600,
|
||||||
|
'extra': 'value',
|
||||||
}
|
}
|
||||||
|
|
||||||
def request(self, token_uri, method, body, headers, *args, **kwargs):
|
def request(self, token_uri, method, body, headers, *args, **kwargs):
|
||||||
@@ -499,8 +500,13 @@ class DecoratorTests(unittest.TestCase):
|
|||||||
'code': 'foo_access_code',
|
'code': 'foo_access_code',
|
||||||
'state': 'foo_path:xsrfkey123',
|
'state': 'foo_path:xsrfkey123',
|
||||||
})
|
})
|
||||||
self.assertEqual('http://localhost/foo_path', response.headers['Location'])
|
parts = response.headers['Location'].split('?', 1)
|
||||||
|
self.assertEqual('http://localhost/foo_path', parts[0])
|
||||||
self.assertEqual(None, self.decorator.credentials)
|
self.assertEqual(None, self.decorator.credentials)
|
||||||
|
if self.decorator._token_response_param:
|
||||||
|
response = parse_qs(parts[1])[self.decorator._token_response_param][0]
|
||||||
|
self.assertEqual(Http2Mock.content,
|
||||||
|
simplejson.loads(urllib.unquote(response)))
|
||||||
|
|
||||||
m.UnsetStubs()
|
m.UnsetStubs()
|
||||||
m.VerifyAll()
|
m.VerifyAll()
|
||||||
@@ -629,6 +635,10 @@ class DecoratorTests(unittest.TestCase):
|
|||||||
self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
|
self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
|
||||||
self.assertEqual(None, decorator.flow.params.get('user_agent', None))
|
self.assertEqual(None, decorator.flow.params.get('user_agent', None))
|
||||||
|
|
||||||
|
def test_token_response_param(self):
|
||||||
|
self.decorator._token_response_param = 'foobar'
|
||||||
|
self.test_required()
|
||||||
|
|
||||||
def test_decorator_from_client_secrets(self):
|
def test_decorator_from_client_secrets(self):
|
||||||
decorator = oauth2decorator_from_clientsecrets(
|
decorator = oauth2decorator_from_clientsecrets(
|
||||||
datafile('client_secrets.json'),
|
datafile('client_secrets.json'),
|
||||||
|
|||||||
Reference in New Issue
Block a user