Fix flask required decorator to redirect on expired credentials.

This commit is contained in:
Jon Wayne Parrott
2016-03-04 10:52:15 -08:00
parent 15c945f1c7
commit a82146a651
2 changed files with 78 additions and 23 deletions

View File

@@ -443,7 +443,14 @@ class UserOAuth2(object):
def has_credentials(self):
"""Returns True if there are valid credentials for the current user."""
return self.credentials and not self.credentials.invalid
if not self.credentials:
return False
# Is the access token expired? If so, do we have an refresh token?
elif (self.credentials.access_token_expired
and not self.credentials.refresh_token):
return False
else:
return True
@property
def email(self):

View File

@@ -14,9 +14,10 @@
"""Unit tests for the Flask utilities"""
import datetime
import httplib2
import json
import unittest
import unittest2
import flask
import six.moves.http_client as httplib
@@ -64,7 +65,7 @@ class Http2Mock(object):
return self
class FlaskOAuth2Tests(unittest.TestCase):
class FlaskOAuth2Tests(unittest2.TestCase):
def setUp(self):
self.app = flask.Flask(__name__)
@@ -81,7 +82,7 @@ class FlaskOAuth2Tests(unittest.TestCase):
'client_idz',
'client_secretz',
'refresh_tokenz',
'3600',
datetime.datetime.utcnow() + datetime.timedelta(seconds=3600),
GOOGLE_TOKEN_URI,
'Test',
id_token={
@@ -175,13 +176,13 @@ class FlaskOAuth2Tests(unittest.TestCase):
with self.app.test_request_context():
flow = self.oauth2._make_flow()
state = json.loads(flow.params['state'])
self.assertTrue('google_oauth2_csrf_token' in flask.session)
self.assertIn('google_oauth2_csrf_token', flask.session)
self.assertEqual(
flask.session['google_oauth2_csrf_token'], state['csrf_token'])
self.assertEqual(flow.client_id, self.oauth2.client_id)
self.assertEqual(flow.client_secret, self.oauth2.client_secret)
self.assertTrue('http' in flow.redirect_uri)
self.assertTrue('oauth2callback' in flow.redirect_uri)
self.assertIn('http', flow.redirect_uri)
self.assertIn('oauth2callback', flow.redirect_uri)
flow = self.oauth2._make_flow(return_url='/return_url')
state = json.loads(flow.params['state'])
@@ -208,9 +209,9 @@ class FlaskOAuth2Tests(unittest.TestCase):
q = urlparse.parse_qs(location.split('?', 1)[1])
state = json.loads(q['state'][0])
self.assertTrue(GOOGLE_AUTH_URI in location)
self.assertFalse(self.oauth2.client_secret in location)
self.assertTrue(self.oauth2.client_id in q['client_id'])
self.assertIn(GOOGLE_AUTH_URI, location)
self.assertNotIn(self.oauth2.client_secret, location)
self.assertIn(self.oauth2.client_id, q['client_id'])
self.assertEqual(
flask.session['google_oauth2_csrf_token'], state['csrf_token'])
self.assertEqual(state['return_url'], '/')
@@ -225,7 +226,7 @@ class FlaskOAuth2Tests(unittest.TestCase):
with self.app.test_client() as client:
response = client.get('/oauth2authorize?extra_param=test')
location = response.headers['Location']
self.assertTrue('extra_param=test' in location)
self.assertIn('extra_param=test', location)
def _setup_callback_state(self, client, **kwargs):
with self.app.test_request_context():
@@ -255,9 +256,9 @@ class FlaskOAuth2Tests(unittest.TestCase):
'/oauth2callback?state={0}&code=codez'.format(state))
self.assertEqual(response.status_code, httplib.FOUND)
self.assertTrue('/return_url' in response.headers['Location'])
self.assertTrue(self.oauth2.client_secret in http.body)
self.assertTrue('codez' in http.body)
self.assertIn('/return_url', response.headers['Location'])
self.assertIn(self.oauth2.client_secret, http.body)
self.assertIn('codez', http.body)
self.assertTrue(self.oauth2.storage.put.called)
def test_authorize_callback(self):
@@ -273,7 +274,7 @@ class FlaskOAuth2Tests(unittest.TestCase):
response = client.get('/oauth2callback?state={}&error=something')
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
self.assertTrue('something' in response.data.decode('utf-8'))
self.assertIn('something', response.data.decode('utf-8'))
# CSRF mismatch
with self.app.test_client() as client:
@@ -352,6 +353,24 @@ class FlaskOAuth2Tests(unittest.TestCase):
self.assertEqual(self.oauth2.email, 'user@example.com')
self.assertTrue(self.oauth2.http())
@mock.patch('oauth2client.client._UTCNOW')
def test_with_expired_credentials(self, utcnow):
utcnow.return_value = datetime.datetime(1990, 5, 29)
credentials = self._generate_credentials()
credentials.token_expiry = datetime.datetime(1990, 5, 28)
# Has a refresh token, so this should be fine.
with self.app.test_request_context():
self.oauth2.storage.put(credentials)
self.assertTrue(self.oauth2.has_credentials())
# Without a refresh token this should return false.
credentials.refresh_token = None
with self.app.test_request_context():
self.oauth2.storage.put(credentials)
self.assertFalse(self.oauth2.has_credentials())
def test_bad_id_token(self):
credentials = self._generate_credentials()
credentials.id_token = {}
@@ -370,8 +389,8 @@ class FlaskOAuth2Tests(unittest.TestCase):
with self.app.test_client() as client:
response = client.get('/protected')
self.assertEqual(response.status_code, httplib.FOUND)
self.assertTrue('oauth2authorize' in response.headers['Location'])
self.assertTrue('protected' in response.headers['Location'])
self.assertIn('oauth2authorize', response.headers['Location'])
self.assertIn('protected', response.headers['Location'])
credentials = self._generate_credentials(scopes=self.oauth2.scopes)
@@ -382,7 +401,36 @@ class FlaskOAuth2Tests(unittest.TestCase):
response = client.get('/protected')
self.assertEqual(response.status_code, httplib.OK)
self.assertTrue('Hello' in response.data.decode('utf-8'))
self.assertIn('Hello', response.data.decode('utf-8'))
# Expired credentials with refresh token, should allow.
credentials.token_expiry = datetime.datetime(1990, 5, 28)
with mock.patch('oauth2client.client._UTCNOW') as utcnow:
utcnow.return_value = datetime.datetime(1990, 5, 29)
with self.app.test_client() as client:
with client.session_transaction() as session:
session['google_oauth2_credentials'] = (
credentials.to_json())
response = client.get('/protected')
self.assertEqual(response.status_code, httplib.OK)
self.assertIn('Hello', response.data.decode('utf-8'))
# Expired credentials without a refresh token, should redirect.
credentials.refresh_token = None
with mock.patch('oauth2client.client._UTCNOW') as utcnow:
utcnow.return_value = datetime.datetime(1990, 5, 29)
with self.app.test_client() as client:
with client.session_transaction() as session:
session['google_oauth2_credentials'] = (
credentials.to_json())
response = client.get('/protected')
self.assertEqual(response.status_code, httplib.FOUND)
self.assertIn('oauth2authorize', response.headers['Location'])
self.assertIn('protected', response.headers['Location'])
def _create_incremental_auth_app(self):
self.app = flask.Flask(__name__)
@@ -410,7 +458,7 @@ class FlaskOAuth2Tests(unittest.TestCase):
# No credentials, should redirect
with self.app.test_client() as client:
response = client.get('/one')
self.assertTrue('one' in response.headers['Location'])
self.assertIn('one', response.headers['Location'])
self.assertEqual(response.status_code, httplib.FOUND)
# Credentials for one. /one should allow, /two should redirect.
@@ -424,14 +472,14 @@ class FlaskOAuth2Tests(unittest.TestCase):
self.assertEqual(response.status_code, httplib.OK)
response = client.get('/two')
self.assertTrue('two' in response.headers['Location'])
self.assertIn('two', response.headers['Location'])
self.assertEqual(response.status_code, httplib.FOUND)
# Starting the authorization flow should include the
# include_granted_scopes parameter as well as the scopes.
response = client.get(response.headers['Location'][17:])
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertTrue('include_granted_scopes' in q)
self.assertIn('include_granted_scopes', q)
self.assertEqual(
set(q['scope'][0].split(' ')),
set(['one', 'email', 'two', 'three']))
@@ -483,8 +531,8 @@ class FlaskOAuth2Tests(unittest.TestCase):
self.oauth2.storage.put(self._generate_credentials())
self.oauth2.storage.delete()
self.assertFalse('google_oauth2_credentials' in flask.session)
self.assertNotIn('google_oauth2_credentials', flask.session)
if __name__ == '__main__': # pragma: NO COVER
unittest.main()
unittest2.main()