diff --git a/oauth2client/flask_util.py b/oauth2client/flask_util.py index 2c455ac..eabf2d2 100644 --- a/oauth2client/flask_util.py +++ b/oauth2client/flask_util.py @@ -23,7 +23,7 @@ available. Configuration ============= -To configure, you'll need a set of OAuth2 client ID from the +To configure, you'll need a set of OAuth2 web application credentials from the `Google Developer's Console `__. @@ -164,6 +164,7 @@ available outside of a request context, you will need to implement your own import hashlib import json import os +import pickle from functools import wraps import six.moves.http_client as httplib @@ -185,12 +186,26 @@ from oauth2client.client import OAuth2Credentials from oauth2client.client import OAuth2WebServerFlow from oauth2client.client import Storage from oauth2client import clientsecrets -from oauth2client import util __author__ = 'jonwayne@google.com (Jon Wayne Parrott)' -DEFAULT_SCOPES = ('email',) +_DEFAULT_SCOPES = ('email',) +_CREDENTIALS_KEY = 'google_oauth2_credentials' +_FLOW_KEY = 'google_oauth2_flow_{0}' +_CSRF_KEY = 'google_oauth2_csrf_token' + + +def _get_flow_for_token(csrf_token): + """Retrieves the flow instance associated with a given CSRF token from + the Flask session.""" + flow_pickle = session.get( + _FLOW_KEY.format(csrf_token), None) + + if flow_pickle is None: + return None + else: + return pickle.loads(flow_pickle) class UserOAuth2(object): @@ -250,7 +265,7 @@ class UserOAuth2(object): self.storage = storage if scopes is None: - scopes = app.config.get('GOOGLE_OAUTH2_SCOPES', DEFAULT_SCOPES) + scopes = app.config.get('GOOGLE_OAUTH2_SCOPES', _DEFAULT_SCOPES) self.scopes = scopes self._load_config(client_secrets_file, client_id, client_secret) @@ -300,7 +315,8 @@ class UserOAuth2(object): client_type, client_info = clientsecrets.loadfile(filename) if client_type != clientsecrets.TYPE_WEB: raise ValueError( - 'The flow specified in %s is not supported.' % client_type) + 'The flow specified in {0} is not supported.'.format( + client_type)) self.client_id = client_info['client_id'] self.client_secret = client_info['client_secret'] @@ -310,7 +326,7 @@ class UserOAuth2(object): # Generate a CSRF token to prevent malicious requests. csrf_token = hashlib.sha256(os.urandom(1024)).hexdigest() - session['google_oauth2_csrf_token'] = csrf_token + session[_CSRF_KEY] = csrf_token state = json.dumps({ 'csrf_token': csrf_token, @@ -320,10 +336,10 @@ class UserOAuth2(object): kw = self.flow_kwargs.copy() kw.update(kwargs) - extra_scopes = util.scopes_to_string(kw.pop('scopes', '')) - scopes = ' '.join([util.scopes_to_string(self.scopes), extra_scopes]) + extra_scopes = kw.pop('scopes', []) + scopes = set(self.scopes).union(set(extra_scopes)) - return OAuth2WebServerFlow( + flow = OAuth2WebServerFlow( client_id=self.client_id, client_secret=self.client_secret, scope=scopes, @@ -331,6 +347,11 @@ class UserOAuth2(object): redirect_uri=url_for('oauth2.callback', _external=True), **kw) + flow_key = _FLOW_KEY.format(csrf_token) + session[flow_key] = pickle.dumps(flow) + + return flow + def _create_blueprint(self): bp = Blueprint('oauth2', __name__) bp.add_url_rule('/oauth2authorize', 'authorize', self.authorize_view) @@ -367,11 +388,12 @@ class UserOAuth2(object): if 'error' in request.args: reason = request.args.get( 'error_description', request.args.get('error', '')) - return 'Authorization failed: %s' % reason, httplib.BAD_REQUEST + return ('Authorization failed: {0}'.format(reason), + httplib.BAD_REQUEST) try: encoded_state = request.args['state'] - server_csrf = session['google_oauth2_csrf_token'] + server_csrf = session[_CSRF_KEY] code = request.args['code'] except KeyError: return 'Invalid request', httplib.BAD_REQUEST @@ -386,14 +408,17 @@ class UserOAuth2(object): if client_csrf != server_csrf: return 'Invalid request state', httplib.BAD_REQUEST - flow = self._make_flow() + flow = _get_flow_for_token(server_csrf) + + if flow is None: + return 'Invalid request state', httplib.BAD_REQUEST # Exchange the auth code for credentials. try: credentials = flow.step2_exchange(code) except FlowExchangeError as exchange_error: current_app.logger.exception(exchange_error) - content = 'An error occurred: %s' % (exchange_error,) + content = 'An error occurred: {0}'.format(exchange_error) return content, httplib.BAD_REQUEST # Save the credentials to the storage. @@ -409,7 +434,7 @@ class UserOAuth2(object): """The credentials for the current user or None if unavailable.""" ctx = _app_ctx_stack.top - if not hasattr(ctx, 'google_oauth2_credentials'): + if not hasattr(ctx, _CREDENTIALS_KEY): ctx.google_oauth2_credentials = self.storage.get() return ctx.google_oauth2_credentials @@ -432,7 +457,7 @@ class UserOAuth2(object): return self.credentials.id_token['email'] except KeyError: current_app.logger.error( - 'Invalid id_token %s', self.credentials.id_token) + 'Invalid id_token {0}'.format(self.credentials.id_token)) @property def user_id(self): @@ -448,7 +473,7 @@ class UserOAuth2(object): return self.credentials.id_token['sub'] except KeyError: current_app.logger.error( - 'Invalid id_token %s', self.credentials.id_token) + 'Invalid id_token {0}'.format(self.credentials.id_token)) def authorize_url(self, return_url, **kwargs): """Creates a URL that can be used to start the authorization flow. @@ -473,28 +498,30 @@ class UserOAuth2(object): def curry_wrapper(wrapped_function): @wraps(wrapped_function) def required_wrapper(*args, **kwargs): - return_url = decorator_kwargs.pop('return_url', request.url) - # No credentials, redirect for new authorization. - if not self.has_credentials(): + requested_scopes = set(self.scopes) + if scopes is not None: + requested_scopes |= set(scopes) + if self.has_credentials(): + requested_scopes |= self.credentials.scopes + + requested_scopes = list(requested_scopes) + + # Does the user have credentials and does the credentials have + # all of the needed scopes? + if (self.has_credentials() and + self.credentials.has_scopes(requested_scopes)): + return wrapped_function(*args, **kwargs) + # Otherwise, redirect to authorization + else: auth_url = self.authorize_url( return_url, - scopes=scopes, + scopes=requested_scopes, **decorator_kwargs) - return redirect(auth_url) - # Existing credentials but mismatching scopes, redirect for - # incremental authorization. - if scopes and not self.credentials.has_scopes(scopes): - auth_url = self.authorize_url( - return_url, - scopes=list(self.credentials.scopes) + scopes, - **decorator_kwargs) return redirect(auth_url) - return wrapped_function(*args, **kwargs) - return required_wrapper if decorated_function: @@ -530,7 +557,7 @@ class FlaskSessionStorage(Storage): """ def locked_get(self): - serialized = session.get('google_oauth2_credentials') + serialized = session.get(_CREDENTIALS_KEY) if serialized is None: return None @@ -541,8 +568,8 @@ class FlaskSessionStorage(Storage): return credentials def locked_put(self, credentials): - session['google_oauth2_credentials'] = credentials.to_json() + session[_CREDENTIALS_KEY] = credentials.to_json() def locked_delete(self): - if 'google_oauth2_credentials' in session: - del session['google_oauth2_credentials'] + if _CREDENTIALS_KEY in session: + del session[_CREDENTIALS_KEY] diff --git a/tests/test_flask_util.py b/tests/test_flask_util.py index 83fcb95..5ceb7b7 100644 --- a/tests/test_flask_util.py +++ b/tests/test_flask_util.py @@ -26,6 +26,7 @@ import six.moves.urllib.parse as urlparse from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_TOKEN_URI from oauth2client import clientsecrets +from oauth2client.flask_util import _get_flow_for_token from oauth2client.flask_util import UserOAuth2 as FlaskOAuth2 from oauth2client.client import OAuth2Credentials @@ -201,9 +202,9 @@ class FlaskOAuth2Tests(unittest.TestCase): self.assertEqual(flow.params['extra_arg'], 'test') def test_authorize_view(self): - with self.app.test_client() as c: - rv = c.get('/oauth2authorize') - location = rv.headers['Location'] + with self.app.test_client() as client: + response = client.get('/oauth2authorize') + location = response.headers['Location'] q = urlparse.parse_qs(location.split('?', 1)[1]) state = json.loads(q['state'][0]) @@ -214,35 +215,47 @@ class FlaskOAuth2Tests(unittest.TestCase): flask.session['google_oauth2_csrf_token'], state['csrf_token']) self.assertEqual(state['return_url'], '/') - with self.app.test_client() as c: - rv = c.get('/oauth2authorize?return_url=/test') - location = rv.headers['Location'] + with self.app.test_client() as client: + response = client.get('/oauth2authorize?return_url=/test') + location = response.headers['Location'] q = urlparse.parse_qs(location.split('?', 1)[1]) state = json.loads(q['state'][0]) self.assertEqual(state['return_url'], '/test') - with self.app.test_client() as c: - rv = c.get('/oauth2authorize?extra_param=test') - location = rv.headers['Location'] + with self.app.test_client() as client: + response = client.get('/oauth2authorize?extra_param=test') + location = response.headers['Location'] self.assertTrue('extra_param=test' in location) + def _setup_callback_state(self, client, **kwargs): + with self.app.test_request_context(): + # Flask doesn't create a request context with a session + # transaction for some reason, so, set up the flow here, + # then apply it to the session in the transaction. + if not kwargs: + self.oauth2._make_flow(return_url='/return_url') + else: + self.oauth2._make_flow(**kwargs) + + with client.session_transaction() as session: + session.update(flask.session) + csrf_token = session['google_oauth2_csrf_token'] + flow = _get_flow_for_token(csrf_token) + state = flow.params['state'] + + return state + def test_callback_view(self): self.oauth2.storage = mock.Mock() - - with self.app.test_client() as c: + with self.app.test_client() as client: with Http2Mock() as http: - with c.session_transaction() as session: - session['google_oauth2_csrf_token'] = 'tokenz' + state = self._setup_callback_state(client) - state = json.dumps({ - 'csrf_token': 'tokenz', - 'return_url': '/return_url' - }) + response = client.get( + '/oauth2callback?state={0}&code=codez'.format(state)) - rv = c.get('/oauth2callback?state=%s&code=codez' % state) - - self.assertEqual(rv.status_code, httplib.FOUND) - self.assertTrue('/return_url' in rv.headers['Location']) + self.assertEqual(response.status_code, httplib.FOUND) + self.assertTrue('/return_url' in response.headers['Location']) self.assertTrue(self.oauth2.client_secret in http.body) self.assertTrue('codez' in http.body) self.assertTrue(self.oauth2.storage.put.called) @@ -254,17 +267,17 @@ class FlaskOAuth2Tests(unittest.TestCase): def test_callback_view_errors(self): # Error supplied to callback - with self.app.test_client() as c: - with c.session_transaction() as session: + with self.app.test_client() as client: + with client.session_transaction() as session: session['google_oauth2_csrf_token'] = 'tokenz' - rv = c.get('/oauth2callback?state={}&error=something') - self.assertEqual(rv.status_code, httplib.BAD_REQUEST) - self.assertTrue('something' in rv.data.decode('utf-8')) + response = client.get('/oauth2callback?state={}&error=something') + self.assertEqual(response.status_code, httplib.BAD_REQUEST) + self.assertTrue('something' in response.data.decode('utf-8')) # CSRF mismatch - with self.app.test_client() as c: - with c.session_transaction() as session: + with self.app.test_client() as client: + with client.session_transaction() as session: session['google_oauth2_csrf_token'] = 'goodstate' state = json.dumps({ @@ -272,36 +285,47 @@ class FlaskOAuth2Tests(unittest.TestCase): 'return_url': '/return_url' }) - rv = c.get('/oauth2callback?state=%s&code=codez' % state) - self.assertEqual(rv.status_code, httplib.BAD_REQUEST) + response = client.get( + '/oauth2callback?state={0}&code=codez'.format(state)) + self.assertEqual(response.status_code, httplib.BAD_REQUEST) # KeyError, no CSRF state. - with self.app.test_client() as c: - rv = c.get('/oauth2callback?state={}&code=codez') - self.assertEqual(rv.status_code, httplib.BAD_REQUEST) + with self.app.test_client() as client: + response = client.get('/oauth2callback?state={}&code=codez') + self.assertEqual(response.status_code, httplib.BAD_REQUEST) # Code exchange error - with self.app.test_client() as c: - with Http2Mock(status=500): - with c.session_transaction() as session: - session['google_oauth2_csrf_token'] = 'tokenz' + with self.app.test_client() as client: + state = self._setup_callback_state(client) - state = json.dumps({ - 'csrf_token': 'tokenz', - 'return_url': '/return_url' - }) - - rv = c.get('/oauth2callback?state=%s&code=codez' % state) - self.assertEqual(rv.status_code, httplib.BAD_REQUEST) + with Http2Mock(status=httplib.INTERNAL_SERVER_ERROR): + response = client.get( + '/oauth2callback?state={0}&code=codez'.format(state)) + self.assertEqual(response.status_code, httplib.BAD_REQUEST) # Invalid state json - with self.app.test_client() as c: - with c.session_transaction() as session: + with self.app.test_client() as client: + with client.session_transaction() as session: session['google_oauth2_csrf_token'] = 'tokenz' state = '[{' - rv = c.get('/oauth2callback?state=%s&code=codez' % state) - self.assertEqual(rv.status_code, httplib.BAD_REQUEST) + response = client.get( + '/oauth2callback?state={0}&code=codez'.format(state)) + self.assertEqual(response.status_code, httplib.BAD_REQUEST) + + # Missing flow. + with self.app.test_client() as client: + with client.session_transaction() as session: + session['google_oauth2_csrf_token'] = 'tokenz' + + state = json.dumps({ + 'csrf_token': 'tokenz', + 'return_url': '/return_url' + }) + + response = client.get( + '/oauth2callback?state={0}&code=codez'.format(state)) + self.assertEqual(response.status_code, httplib.BAD_REQUEST) def test_no_credentials(self): with self.app.test_request_context(): @@ -343,24 +367,24 @@ class FlaskOAuth2Tests(unittest.TestCase): return 'Hello' # No credentials, should redirect - with self.app.test_client() as c: - rv = c.get('/protected') - self.assertEqual(rv.status_code, httplib.FOUND) - self.assertTrue('oauth2authorize' in rv.headers['Location']) - self.assertTrue('protected' in rv.headers['Location']) + with self.app.test_client() as client: + response = client.get('/protected') + self.assertEqual(response.status_code, httplib.FOUND) + self.assertTrue('oauth2authorize' in response.headers['Location']) + self.assertTrue('protected' in response.headers['Location']) - credentials = self._generate_credentials() + credentials = self._generate_credentials(scopes=self.oauth2.scopes) # With credentials, should allow - with self.app.test_client() as c: - with c.session_transaction() as session: + with self.app.test_client() as client: + with client.session_transaction() as session: session['google_oauth2_credentials'] = credentials.to_json() - rv = c.get('/protected') - self.assertEqual(rv.status_code, httplib.OK) - self.assertTrue('Hello' in rv.data.decode('utf-8')) + response = client.get('/protected') + self.assertEqual(response.status_code, httplib.OK) + self.assertTrue('Hello' in response.data.decode('utf-8')) - def test_incremental_auth(self): + def _create_incremental_auth_app(self): self.app = flask.Flask(__name__) self.app.testing = True self.app.config['SECRET_KEY'] = 'notasecert' @@ -380,41 +404,67 @@ class FlaskOAuth2Tests(unittest.TestCase): def two(): return 'Hello' + def test_incremental_auth(self): + self._create_incremental_auth_app() + # No credentials, should redirect - with self.app.test_client() as c: - rv = c.get('/one') - self.assertTrue('one' in rv.headers['Location']) - self.assertEqual(rv.status_code, httplib.FOUND) + with self.app.test_client() as client: + response = client.get('/one') + self.assertTrue('one' in response.headers['Location']) + self.assertEqual(response.status_code, httplib.FOUND) # Credentials for one. /one should allow, /two should redirect. - credentials = self._generate_credentials(scopes=['one']) + credentials = self._generate_credentials(scopes=['email', 'one']) - with self.app.test_client() as c: - with c.session_transaction() as session: + with self.app.test_client() as client: + with client.session_transaction() as session: session['google_oauth2_credentials'] = credentials.to_json() - rv = c.get('/one') - self.assertEqual(rv.status_code, httplib.OK) + response = client.get('/one') + self.assertEqual(response.status_code, httplib.OK) - rv = c.get('/two') - self.assertTrue('two' in rv.headers['Location']) - self.assertEqual(rv.status_code, httplib.FOUND) + response = client.get('/two') + self.assertTrue('two' in response.headers['Location']) + self.assertEqual(response.status_code, httplib.FOUND) # Starting the authorization flow should include the # include_granted_scopes parameter as well as the scopes. - rv = c.get(rv.headers['Location'][17:]) - q = urlparse.parse_qs(rv.headers['Location'].split('?', 1)[1]) + response = client.get(response.headers['Location'][17:]) + q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1]) self.assertTrue('include_granted_scopes' in q) - self.assertEqual(q['scope'][0], 'email one two three') + self.assertEqual( + set(q['scope'][0].split(' ')), + set(['one', 'email', 'two', 'three'])) # Actually call two() without a redirect. - credentials2 = self._generate_credentials(scopes=['two', 'three']) - with self.app.test_client() as c: - with c.session_transaction() as session: + credentials2 = self._generate_credentials( + scopes=['email', 'two', 'three']) + + with self.app.test_client() as client: + with client.session_transaction() as session: session['google_oauth2_credentials'] = credentials2.to_json() - rv = c.get('/two') - self.assertEqual(rv.status_code, httplib.OK) + response = client.get('/two') + self.assertEqual(response.status_code, httplib.OK) + + def test_incremental_auth_exchange(self): + self._create_incremental_auth_app() + + with Http2Mock(): + with self.app.test_client() as client: + state = self._setup_callback_state( + client, + return_url='/return_url', + # Incremental auth scopes. + scopes=['one', 'two']) + + response = client.get( + '/oauth2callback?state={0}&code=codez'.format(state)) + self.assertEqual(response.status_code, httplib.FOUND) + + credentials = self.oauth2.credentials + self.assertTrue( + credentials.has_scopes(['email', 'one', 'two'])) def test_refresh(self): with self.app.test_request_context():