Fixing incremental auth in flask_util.

* Flow is now stored in the session, ensuring that the scopes survive the round trip to the auth server.
* The base credentials are now checked in `required`, solving an issue where it's possible to pass the check with the incremental scopes but without the base scopes.

Fixes #320
This commit is contained in:
Jon Wayne Parrott
2015-10-12 11:20:32 -07:00
parent 55c2bcc857
commit 07ea7c5475
2 changed files with 191 additions and 114 deletions

View File

@@ -23,7 +23,7 @@ available.
Configuration Configuration
============= =============
To configure, you'll need a set of OAuth2 client ID from the To configure, you'll need a set of OAuth2 web application credentials from the
`Google Developer's Console <https://console.developers.google.com/project/_/\ `Google Developer's Console <https://console.developers.google.com/project/_/\
apiui/credential>`__. apiui/credential>`__.
@@ -164,6 +164,7 @@ available outside of a request context, you will need to implement your own
import hashlib import hashlib
import json import json
import os import os
import pickle
from functools import wraps from functools import wraps
import six.moves.http_client as httplib import six.moves.http_client as httplib
@@ -185,12 +186,26 @@ from oauth2client.client import OAuth2Credentials
from oauth2client.client import OAuth2WebServerFlow from oauth2client.client import OAuth2WebServerFlow
from oauth2client.client import Storage from oauth2client.client import Storage
from oauth2client import clientsecrets from oauth2client import clientsecrets
from oauth2client import util
__author__ = 'jonwayne@google.com (Jon Wayne Parrott)' __author__ = 'jonwayne@google.com (Jon Wayne Parrott)'
DEFAULT_SCOPES = ('email',) _DEFAULT_SCOPES = ('email',)
_CREDENTIALS_KEY = 'google_oauth2_credentials'
_FLOW_KEY = 'google_oauth2_flow_{0}'
_CSRF_KEY = 'google_oauth2_csrf_token'
def _get_flow_for_token(csrf_token):
"""Retrieves the flow instance associated with a given CSRF token from
the Flask session."""
flow_pickle = session.get(
_FLOW_KEY.format(csrf_token), None)
if flow_pickle is None:
return None
else:
return pickle.loads(flow_pickle)
class UserOAuth2(object): class UserOAuth2(object):
@@ -250,7 +265,7 @@ class UserOAuth2(object):
self.storage = storage self.storage = storage
if scopes is None: if scopes is None:
scopes = app.config.get('GOOGLE_OAUTH2_SCOPES', DEFAULT_SCOPES) scopes = app.config.get('GOOGLE_OAUTH2_SCOPES', _DEFAULT_SCOPES)
self.scopes = scopes self.scopes = scopes
self._load_config(client_secrets_file, client_id, client_secret) self._load_config(client_secrets_file, client_id, client_secret)
@@ -300,7 +315,8 @@ class UserOAuth2(object):
client_type, client_info = clientsecrets.loadfile(filename) client_type, client_info = clientsecrets.loadfile(filename)
if client_type != clientsecrets.TYPE_WEB: if client_type != clientsecrets.TYPE_WEB:
raise ValueError( raise ValueError(
'The flow specified in %s is not supported.' % client_type) 'The flow specified in {0} is not supported.'.format(
client_type))
self.client_id = client_info['client_id'] self.client_id = client_info['client_id']
self.client_secret = client_info['client_secret'] self.client_secret = client_info['client_secret']
@@ -310,7 +326,7 @@ class UserOAuth2(object):
# Generate a CSRF token to prevent malicious requests. # Generate a CSRF token to prevent malicious requests.
csrf_token = hashlib.sha256(os.urandom(1024)).hexdigest() csrf_token = hashlib.sha256(os.urandom(1024)).hexdigest()
session['google_oauth2_csrf_token'] = csrf_token session[_CSRF_KEY] = csrf_token
state = json.dumps({ state = json.dumps({
'csrf_token': csrf_token, 'csrf_token': csrf_token,
@@ -320,10 +336,10 @@ class UserOAuth2(object):
kw = self.flow_kwargs.copy() kw = self.flow_kwargs.copy()
kw.update(kwargs) kw.update(kwargs)
extra_scopes = util.scopes_to_string(kw.pop('scopes', '')) extra_scopes = kw.pop('scopes', [])
scopes = ' '.join([util.scopes_to_string(self.scopes), extra_scopes]) scopes = set(self.scopes).union(set(extra_scopes))
return OAuth2WebServerFlow( flow = OAuth2WebServerFlow(
client_id=self.client_id, client_id=self.client_id,
client_secret=self.client_secret, client_secret=self.client_secret,
scope=scopes, scope=scopes,
@@ -331,6 +347,11 @@ class UserOAuth2(object):
redirect_uri=url_for('oauth2.callback', _external=True), redirect_uri=url_for('oauth2.callback', _external=True),
**kw) **kw)
flow_key = _FLOW_KEY.format(csrf_token)
session[flow_key] = pickle.dumps(flow)
return flow
def _create_blueprint(self): def _create_blueprint(self):
bp = Blueprint('oauth2', __name__) bp = Blueprint('oauth2', __name__)
bp.add_url_rule('/oauth2authorize', 'authorize', self.authorize_view) bp.add_url_rule('/oauth2authorize', 'authorize', self.authorize_view)
@@ -367,11 +388,12 @@ class UserOAuth2(object):
if 'error' in request.args: if 'error' in request.args:
reason = request.args.get( reason = request.args.get(
'error_description', request.args.get('error', '')) 'error_description', request.args.get('error', ''))
return 'Authorization failed: %s' % reason, httplib.BAD_REQUEST return ('Authorization failed: {0}'.format(reason),
httplib.BAD_REQUEST)
try: try:
encoded_state = request.args['state'] encoded_state = request.args['state']
server_csrf = session['google_oauth2_csrf_token'] server_csrf = session[_CSRF_KEY]
code = request.args['code'] code = request.args['code']
except KeyError: except KeyError:
return 'Invalid request', httplib.BAD_REQUEST return 'Invalid request', httplib.BAD_REQUEST
@@ -386,14 +408,17 @@ class UserOAuth2(object):
if client_csrf != server_csrf: if client_csrf != server_csrf:
return 'Invalid request state', httplib.BAD_REQUEST return 'Invalid request state', httplib.BAD_REQUEST
flow = self._make_flow() flow = _get_flow_for_token(server_csrf)
if flow is None:
return 'Invalid request state', httplib.BAD_REQUEST
# Exchange the auth code for credentials. # Exchange the auth code for credentials.
try: try:
credentials = flow.step2_exchange(code) credentials = flow.step2_exchange(code)
except FlowExchangeError as exchange_error: except FlowExchangeError as exchange_error:
current_app.logger.exception(exchange_error) current_app.logger.exception(exchange_error)
content = 'An error occurred: %s' % (exchange_error,) content = 'An error occurred: {0}'.format(exchange_error)
return content, httplib.BAD_REQUEST return content, httplib.BAD_REQUEST
# Save the credentials to the storage. # Save the credentials to the storage.
@@ -409,7 +434,7 @@ class UserOAuth2(object):
"""The credentials for the current user or None if unavailable.""" """The credentials for the current user or None if unavailable."""
ctx = _app_ctx_stack.top ctx = _app_ctx_stack.top
if not hasattr(ctx, 'google_oauth2_credentials'): if not hasattr(ctx, _CREDENTIALS_KEY):
ctx.google_oauth2_credentials = self.storage.get() ctx.google_oauth2_credentials = self.storage.get()
return ctx.google_oauth2_credentials return ctx.google_oauth2_credentials
@@ -432,7 +457,7 @@ class UserOAuth2(object):
return self.credentials.id_token['email'] return self.credentials.id_token['email']
except KeyError: except KeyError:
current_app.logger.error( current_app.logger.error(
'Invalid id_token %s', self.credentials.id_token) 'Invalid id_token {0}'.format(self.credentials.id_token))
@property @property
def user_id(self): def user_id(self):
@@ -448,7 +473,7 @@ class UserOAuth2(object):
return self.credentials.id_token['sub'] return self.credentials.id_token['sub']
except KeyError: except KeyError:
current_app.logger.error( current_app.logger.error(
'Invalid id_token %s', self.credentials.id_token) 'Invalid id_token {0}'.format(self.credentials.id_token))
def authorize_url(self, return_url, **kwargs): def authorize_url(self, return_url, **kwargs):
"""Creates a URL that can be used to start the authorization flow. """Creates a URL that can be used to start the authorization flow.
@@ -473,28 +498,30 @@ class UserOAuth2(object):
def curry_wrapper(wrapped_function): def curry_wrapper(wrapped_function):
@wraps(wrapped_function) @wraps(wrapped_function)
def required_wrapper(*args, **kwargs): def required_wrapper(*args, **kwargs):
return_url = decorator_kwargs.pop('return_url', request.url) return_url = decorator_kwargs.pop('return_url', request.url)
# No credentials, redirect for new authorization. requested_scopes = set(self.scopes)
if not self.has_credentials(): if scopes is not None:
requested_scopes |= set(scopes)
if self.has_credentials():
requested_scopes |= self.credentials.scopes
requested_scopes = list(requested_scopes)
# Does the user have credentials and does the credentials have
# all of the needed scopes?
if (self.has_credentials() and
self.credentials.has_scopes(requested_scopes)):
return wrapped_function(*args, **kwargs)
# Otherwise, redirect to authorization
else:
auth_url = self.authorize_url( auth_url = self.authorize_url(
return_url, return_url,
scopes=scopes, scopes=requested_scopes,
**decorator_kwargs) **decorator_kwargs)
return redirect(auth_url)
# Existing credentials but mismatching scopes, redirect for
# incremental authorization.
if scopes and not self.credentials.has_scopes(scopes):
auth_url = self.authorize_url(
return_url,
scopes=list(self.credentials.scopes) + scopes,
**decorator_kwargs)
return redirect(auth_url) return redirect(auth_url)
return wrapped_function(*args, **kwargs)
return required_wrapper return required_wrapper
if decorated_function: if decorated_function:
@@ -530,7 +557,7 @@ class FlaskSessionStorage(Storage):
""" """
def locked_get(self): def locked_get(self):
serialized = session.get('google_oauth2_credentials') serialized = session.get(_CREDENTIALS_KEY)
if serialized is None: if serialized is None:
return None return None
@@ -541,8 +568,8 @@ class FlaskSessionStorage(Storage):
return credentials return credentials
def locked_put(self, credentials): def locked_put(self, credentials):
session['google_oauth2_credentials'] = credentials.to_json() session[_CREDENTIALS_KEY] = credentials.to_json()
def locked_delete(self): def locked_delete(self):
if 'google_oauth2_credentials' in session: if _CREDENTIALS_KEY in session:
del session['google_oauth2_credentials'] del session[_CREDENTIALS_KEY]

View File

@@ -26,6 +26,7 @@ import six.moves.urllib.parse as urlparse
from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_AUTH_URI
from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import clientsecrets from oauth2client import clientsecrets
from oauth2client.flask_util import _get_flow_for_token
from oauth2client.flask_util import UserOAuth2 as FlaskOAuth2 from oauth2client.flask_util import UserOAuth2 as FlaskOAuth2
from oauth2client.client import OAuth2Credentials from oauth2client.client import OAuth2Credentials
@@ -201,9 +202,9 @@ class FlaskOAuth2Tests(unittest.TestCase):
self.assertEqual(flow.params['extra_arg'], 'test') self.assertEqual(flow.params['extra_arg'], 'test')
def test_authorize_view(self): def test_authorize_view(self):
with self.app.test_client() as c: with self.app.test_client() as client:
rv = c.get('/oauth2authorize') response = client.get('/oauth2authorize')
location = rv.headers['Location'] location = response.headers['Location']
q = urlparse.parse_qs(location.split('?', 1)[1]) q = urlparse.parse_qs(location.split('?', 1)[1])
state = json.loads(q['state'][0]) state = json.loads(q['state'][0])
@@ -214,35 +215,47 @@ class FlaskOAuth2Tests(unittest.TestCase):
flask.session['google_oauth2_csrf_token'], state['csrf_token']) flask.session['google_oauth2_csrf_token'], state['csrf_token'])
self.assertEqual(state['return_url'], '/') self.assertEqual(state['return_url'], '/')
with self.app.test_client() as c: with self.app.test_client() as client:
rv = c.get('/oauth2authorize?return_url=/test') response = client.get('/oauth2authorize?return_url=/test')
location = rv.headers['Location'] location = response.headers['Location']
q = urlparse.parse_qs(location.split('?', 1)[1]) q = urlparse.parse_qs(location.split('?', 1)[1])
state = json.loads(q['state'][0]) state = json.loads(q['state'][0])
self.assertEqual(state['return_url'], '/test') self.assertEqual(state['return_url'], '/test')
with self.app.test_client() as c: with self.app.test_client() as client:
rv = c.get('/oauth2authorize?extra_param=test') response = client.get('/oauth2authorize?extra_param=test')
location = rv.headers['Location'] location = response.headers['Location']
self.assertTrue('extra_param=test' in location) self.assertTrue('extra_param=test' in location)
def _setup_callback_state(self, client, **kwargs):
with self.app.test_request_context():
# Flask doesn't create a request context with a session
# transaction for some reason, so, set up the flow here,
# then apply it to the session in the transaction.
if not kwargs:
self.oauth2._make_flow(return_url='/return_url')
else:
self.oauth2._make_flow(**kwargs)
with client.session_transaction() as session:
session.update(flask.session)
csrf_token = session['google_oauth2_csrf_token']
flow = _get_flow_for_token(csrf_token)
state = flow.params['state']
return state
def test_callback_view(self): def test_callback_view(self):
self.oauth2.storage = mock.Mock() self.oauth2.storage = mock.Mock()
with self.app.test_client() as client:
with self.app.test_client() as c:
with Http2Mock() as http: with Http2Mock() as http:
with c.session_transaction() as session: state = self._setup_callback_state(client)
session['google_oauth2_csrf_token'] = 'tokenz'
state = json.dumps({ response = client.get(
'csrf_token': 'tokenz', '/oauth2callback?state={0}&code=codez'.format(state))
'return_url': '/return_url'
})
rv = c.get('/oauth2callback?state=%s&code=codez' % state) self.assertEqual(response.status_code, httplib.FOUND)
self.assertTrue('/return_url' in response.headers['Location'])
self.assertEqual(rv.status_code, httplib.FOUND)
self.assertTrue('/return_url' in rv.headers['Location'])
self.assertTrue(self.oauth2.client_secret in http.body) self.assertTrue(self.oauth2.client_secret in http.body)
self.assertTrue('codez' in http.body) self.assertTrue('codez' in http.body)
self.assertTrue(self.oauth2.storage.put.called) self.assertTrue(self.oauth2.storage.put.called)
@@ -254,17 +267,17 @@ class FlaskOAuth2Tests(unittest.TestCase):
def test_callback_view_errors(self): def test_callback_view_errors(self):
# Error supplied to callback # Error supplied to callback
with self.app.test_client() as c: with self.app.test_client() as client:
with c.session_transaction() as session: with client.session_transaction() as session:
session['google_oauth2_csrf_token'] = 'tokenz' session['google_oauth2_csrf_token'] = 'tokenz'
rv = c.get('/oauth2callback?state={}&error=something') response = client.get('/oauth2callback?state={}&error=something')
self.assertEqual(rv.status_code, httplib.BAD_REQUEST) self.assertEqual(response.status_code, httplib.BAD_REQUEST)
self.assertTrue('something' in rv.data.decode('utf-8')) self.assertTrue('something' in response.data.decode('utf-8'))
# CSRF mismatch # CSRF mismatch
with self.app.test_client() as c: with self.app.test_client() as client:
with c.session_transaction() as session: with client.session_transaction() as session:
session['google_oauth2_csrf_token'] = 'goodstate' session['google_oauth2_csrf_token'] = 'goodstate'
state = json.dumps({ state = json.dumps({
@@ -272,36 +285,47 @@ class FlaskOAuth2Tests(unittest.TestCase):
'return_url': '/return_url' 'return_url': '/return_url'
}) })
rv = c.get('/oauth2callback?state=%s&code=codez' % state) response = client.get(
self.assertEqual(rv.status_code, httplib.BAD_REQUEST) '/oauth2callback?state={0}&code=codez'.format(state))
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
# KeyError, no CSRF state. # KeyError, no CSRF state.
with self.app.test_client() as c: with self.app.test_client() as client:
rv = c.get('/oauth2callback?state={}&code=codez') response = client.get('/oauth2callback?state={}&code=codez')
self.assertEqual(rv.status_code, httplib.BAD_REQUEST) self.assertEqual(response.status_code, httplib.BAD_REQUEST)
# Code exchange error # Code exchange error
with self.app.test_client() as c: with self.app.test_client() as client:
with Http2Mock(status=500): state = self._setup_callback_state(client)
with c.session_transaction() as session:
session['google_oauth2_csrf_token'] = 'tokenz'
state = json.dumps({ with Http2Mock(status=httplib.INTERNAL_SERVER_ERROR):
'csrf_token': 'tokenz', response = client.get(
'return_url': '/return_url' '/oauth2callback?state={0}&code=codez'.format(state))
}) self.assertEqual(response.status_code, httplib.BAD_REQUEST)
rv = c.get('/oauth2callback?state=%s&code=codez' % state)
self.assertEqual(rv.status_code, httplib.BAD_REQUEST)
# Invalid state json # Invalid state json
with self.app.test_client() as c: with self.app.test_client() as client:
with c.session_transaction() as session: with client.session_transaction() as session:
session['google_oauth2_csrf_token'] = 'tokenz' session['google_oauth2_csrf_token'] = 'tokenz'
state = '[{' state = '[{'
rv = c.get('/oauth2callback?state=%s&code=codez' % state) response = client.get(
self.assertEqual(rv.status_code, httplib.BAD_REQUEST) '/oauth2callback?state={0}&code=codez'.format(state))
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
# Missing flow.
with self.app.test_client() as client:
with client.session_transaction() as session:
session['google_oauth2_csrf_token'] = 'tokenz'
state = json.dumps({
'csrf_token': 'tokenz',
'return_url': '/return_url'
})
response = client.get(
'/oauth2callback?state={0}&code=codez'.format(state))
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
def test_no_credentials(self): def test_no_credentials(self):
with self.app.test_request_context(): with self.app.test_request_context():
@@ -343,24 +367,24 @@ class FlaskOAuth2Tests(unittest.TestCase):
return 'Hello' return 'Hello'
# No credentials, should redirect # No credentials, should redirect
with self.app.test_client() as c: with self.app.test_client() as client:
rv = c.get('/protected') response = client.get('/protected')
self.assertEqual(rv.status_code, httplib.FOUND) self.assertEqual(response.status_code, httplib.FOUND)
self.assertTrue('oauth2authorize' in rv.headers['Location']) self.assertTrue('oauth2authorize' in response.headers['Location'])
self.assertTrue('protected' in rv.headers['Location']) self.assertTrue('protected' in response.headers['Location'])
credentials = self._generate_credentials() credentials = self._generate_credentials(scopes=self.oauth2.scopes)
# With credentials, should allow # With credentials, should allow
with self.app.test_client() as c: with self.app.test_client() as client:
with c.session_transaction() as session: with client.session_transaction() as session:
session['google_oauth2_credentials'] = credentials.to_json() session['google_oauth2_credentials'] = credentials.to_json()
rv = c.get('/protected') response = client.get('/protected')
self.assertEqual(rv.status_code, httplib.OK) self.assertEqual(response.status_code, httplib.OK)
self.assertTrue('Hello' in rv.data.decode('utf-8')) self.assertTrue('Hello' in response.data.decode('utf-8'))
def test_incremental_auth(self): def _create_incremental_auth_app(self):
self.app = flask.Flask(__name__) self.app = flask.Flask(__name__)
self.app.testing = True self.app.testing = True
self.app.config['SECRET_KEY'] = 'notasecert' self.app.config['SECRET_KEY'] = 'notasecert'
@@ -380,41 +404,67 @@ class FlaskOAuth2Tests(unittest.TestCase):
def two(): def two():
return 'Hello' return 'Hello'
def test_incremental_auth(self):
self._create_incremental_auth_app()
# No credentials, should redirect # No credentials, should redirect
with self.app.test_client() as c: with self.app.test_client() as client:
rv = c.get('/one') response = client.get('/one')
self.assertTrue('one' in rv.headers['Location']) self.assertTrue('one' in response.headers['Location'])
self.assertEqual(rv.status_code, httplib.FOUND) self.assertEqual(response.status_code, httplib.FOUND)
# Credentials for one. /one should allow, /two should redirect. # Credentials for one. /one should allow, /two should redirect.
credentials = self._generate_credentials(scopes=['one']) credentials = self._generate_credentials(scopes=['email', 'one'])
with self.app.test_client() as c: with self.app.test_client() as client:
with c.session_transaction() as session: with client.session_transaction() as session:
session['google_oauth2_credentials'] = credentials.to_json() session['google_oauth2_credentials'] = credentials.to_json()
rv = c.get('/one') response = client.get('/one')
self.assertEqual(rv.status_code, httplib.OK) self.assertEqual(response.status_code, httplib.OK)
rv = c.get('/two') response = client.get('/two')
self.assertTrue('two' in rv.headers['Location']) self.assertTrue('two' in response.headers['Location'])
self.assertEqual(rv.status_code, httplib.FOUND) self.assertEqual(response.status_code, httplib.FOUND)
# Starting the authorization flow should include the # Starting the authorization flow should include the
# include_granted_scopes parameter as well as the scopes. # include_granted_scopes parameter as well as the scopes.
rv = c.get(rv.headers['Location'][17:]) response = client.get(response.headers['Location'][17:])
q = urlparse.parse_qs(rv.headers['Location'].split('?', 1)[1]) q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertTrue('include_granted_scopes' in q) self.assertTrue('include_granted_scopes' in q)
self.assertEqual(q['scope'][0], 'email one two three') self.assertEqual(
set(q['scope'][0].split(' ')),
set(['one', 'email', 'two', 'three']))
# Actually call two() without a redirect. # Actually call two() without a redirect.
credentials2 = self._generate_credentials(scopes=['two', 'three']) credentials2 = self._generate_credentials(
with self.app.test_client() as c: scopes=['email', 'two', 'three'])
with c.session_transaction() as session:
with self.app.test_client() as client:
with client.session_transaction() as session:
session['google_oauth2_credentials'] = credentials2.to_json() session['google_oauth2_credentials'] = credentials2.to_json()
rv = c.get('/two') response = client.get('/two')
self.assertEqual(rv.status_code, httplib.OK) self.assertEqual(response.status_code, httplib.OK)
def test_incremental_auth_exchange(self):
self._create_incremental_auth_app()
with Http2Mock():
with self.app.test_client() as client:
state = self._setup_callback_state(
client,
return_url='/return_url',
# Incremental auth scopes.
scopes=['one', 'two'])
response = client.get(
'/oauth2callback?state={0}&code=codez'.format(state))
self.assertEqual(response.status_code, httplib.FOUND)
credentials = self.oauth2.credentials
self.assertTrue(
credentials.has_scopes(['email', 'one', 'two']))
def test_refresh(self): def test_refresh(self):
with self.app.test_request_context(): with self.app.test_request_context():