Credentials keep track of scopes.

* Added new attributes scopes and token_info_uri to OAuth2Credentials.
* Added method OAuth2Credentials.has_scopes to test which scopes the credentials have.
* OAuth2WebServerFlow now passes in the authorized scopes when constructing OAuth2Credentials.
* Added method Oauth2Credentials.refresh_scopes, which gets the canonical list of scopes from the OAuth2 tokeninfo service.
* Added new utility function string_to_scopes, the inverse of scopes_to_string.

This will fix #228 when merged.
This commit is contained in:
Jon Wayne Parrott
2015-08-03 11:04:52 -07:00
parent bb654536e5
commit 969b130eb0
5 changed files with 181 additions and 6 deletions

View File

@@ -6,3 +6,4 @@ GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/auth'
GOOGLE_DEVICE_URI = 'https://accounts.google.com/o/oauth2/device/code'
GOOGLE_REVOKE_URI = 'https://accounts.google.com/o/oauth2/revoke'
GOOGLE_TOKEN_URI = 'https://accounts.google.com/o/oauth2/token'
GOOGLE_TOKEN_INFO_URI = 'https://www.googleapis.com/oauth2/v2/tokeninfo'

View File

@@ -39,6 +39,7 @@ from oauth2client import GOOGLE_AUTH_URI
from oauth2client import GOOGLE_DEVICE_URI
from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import GOOGLE_TOKEN_INFO_URI
from oauth2client._helpers import _urlsafe_b64decode
from oauth2client import clientsecrets
from oauth2client import util
@@ -252,6 +253,8 @@ class Credentials(object):
for key, val in d.items():
if isinstance(val, bytes):
d[key] = val.decode('utf-8')
if isinstance(val, set):
d[key] = list(val)
return json.dumps(d)
def to_json(self):
@@ -459,7 +462,8 @@ class OAuth2Credentials(Credentials):
@util.positional(8)
def __init__(self, access_token, client_id, client_secret, refresh_token,
token_expiry, token_uri, user_agent, revoke_uri=None,
id_token=None, token_response=None):
id_token=None, token_response=None, scopes=None,
token_info_uri=None):
"""Create an instance of OAuth2Credentials.
This constructor is not usually called by the user, instead
@@ -479,6 +483,9 @@ class OAuth2Credentials(Credentials):
token_response: dict, the decoded response to the token request. None
if a token hasn't been requested yet. Stored because some providers
(e.g. wordpress.com) include extra fields that clients may want.
scopes: list, authorized scopes for these credentials.
token_info_uri: string, the URI for the token info endpoint. Defaults to
None; scopes can not be refreshed if this is None.
Notes:
store: callable, A callable that when passed a Credential
@@ -497,6 +504,8 @@ class OAuth2Credentials(Credentials):
self.revoke_uri = revoke_uri
self.id_token = id_token
self.token_response = token_response
self.scopes = set(util.string_to_scopes(scopes or []))
self.token_info_uri = token_info_uri
# True if the credentials have been revoked or expired and can't be
# refreshed.
@@ -614,6 +623,39 @@ class OAuth2Credentials(Credentials):
"""
headers['Authorization'] = 'Bearer ' + self.access_token
def has_scopes(self, scopes):
"""Verify that the credentials are authorized for the given scopes.
Returns True if the credentials authorized scopes contain all of the scopes
given.
Args:
scopes: list or string, the scopes to check.
Notes:
There are cases where the credentials are unaware of which scopes are
authorized. Notably, credentials obtained and stored before this code was
added will not have scopes, AccessTokenCredentials do not have scopes. In
both cases, you can use refresh_scopes() to obtain the canonical set of
scopes.
"""
scopes = util.string_to_scopes(scopes)
return set(scopes).issubset(self.scopes)
def retrieve_scopes(self, http):
"""Retrieves the canonical list of scopes for this access token from the
OAuth2 provider.
Args:
http: httplib2.Http, an http object to be used to make the refresh
request.
Returns:
A set of strings containing the canonical list of scopes.
"""
self._retrieve_scopes(http.request)
return self.scopes
def to_json(self):
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
@@ -648,7 +690,9 @@ class OAuth2Credentials(Credentials):
data['user_agent'],
revoke_uri=data.get('revoke_uri', None),
id_token=data.get('id_token', None),
token_response=data.get('token_response', None))
token_response=data.get('token_response', None),
scopes=data.get('scopes', None),
token_info_uri=data.get('token_info_uri', None))
retval.invalid = data['invalid']
return retval
@@ -858,6 +902,10 @@ class OAuth2Credentials(Credentials):
query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
resp, content = http_request(token_revoke_uri)
if six.PY3 and isinstance(content, bytes):
content = content.decode('utf-8')
if resp.status == 200:
self.invalid = True
else:
@@ -873,6 +921,48 @@ class OAuth2Credentials(Credentials):
if self.store:
self.store.delete()
def _retrieve_scopes(self, http_request):
"""Retrieves the list of authorized scopes from the OAuth2 provider.
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the revoke request.
"""
self._do_retrieve_scopes(http_request, self.access_token)
def _do_retrieve_scopes(self, http_request, token):
"""Retrieves the list of authorized scopes from the OAuth2 provider.
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the refresh request.
token: A string used as the token to identify the credentials to the
provider.
Raises:
Error: When refresh fails, indicating the the access token is invalid.
"""
logger.info('Refreshing scopes')
query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri, query_params)
resp, content = http_request(token_info_uri)
if six.PY3 and isinstance(content, bytes):
content = content.decode('utf-8')
if resp.status == 200:
d = json.loads(content)
self.scopes = set(util.string_to_scopes(d.get('scope', '')))
else:
error_msg = 'Invalid response %s.' % (resp.status,)
try:
d = json.loads(content)
if 'error_description' in d:
error_msg = d['error_description']
except (TypeError, ValueError):
pass
raise Error(error_msg)
class AccessTokenCredentials(OAuth2Credentials):
"""Credentials object for OAuth 2.0.
@@ -1650,7 +1740,8 @@ def credentials_from_code(client_id, client_secret, scope, code,
user_agent=None, token_uri=GOOGLE_TOKEN_URI,
auth_uri=GOOGLE_AUTH_URI,
revoke_uri=GOOGLE_REVOKE_URI,
device_uri=GOOGLE_DEVICE_URI):
device_uri=GOOGLE_DEVICE_URI,
token_info_uri=GOOGLE_TOKEN_INFO_URI):
"""Exchanges an authorization code for an OAuth2Credentials object.
Args:
@@ -1681,7 +1772,8 @@ def credentials_from_code(client_id, client_secret, scope, code,
flow = OAuth2WebServerFlow(client_id, client_secret, scope,
redirect_uri=redirect_uri, user_agent=user_agent,
auth_uri=auth_uri, token_uri=token_uri,
revoke_uri=revoke_uri, device_uri=device_uri)
revoke_uri=revoke_uri, device_uri=device_uri,
token_info_uri=token_info_uri)
credentials = flow.step2_exchange(code, http=http)
return credentials
@@ -1786,6 +1878,7 @@ class OAuth2WebServerFlow(Flow):
revoke_uri=GOOGLE_REVOKE_URI,
login_hint=None,
device_uri=GOOGLE_DEVICE_URI,
token_info_uri=GOOGLE_TOKEN_INFO_URI,
authorization_header=None,
**kwargs):
"""Constructor for OAuth2WebServerFlow.
@@ -1834,6 +1927,7 @@ class OAuth2WebServerFlow(Flow):
self.token_uri = token_uri
self.revoke_uri = revoke_uri
self.device_uri = device_uri
self.token_info_uri = token_info_uri
self.authorization_header = authorization_header
self.params = {
'access_type': 'offline',
@@ -2011,7 +2105,9 @@ class OAuth2WebServerFlow(Flow):
self.token_uri, self.user_agent,
revoke_uri=self.revoke_uri,
id_token=extracted_id_token,
token_response=d)
token_response=d,
scopes=self.scope,
token_info_uri=self.token_info_uri)
else:
logger.info('Failed to retrieve access token: %s', content)
if 'error' in d:

View File

@@ -163,6 +163,26 @@ def scopes_to_string(scopes):
return ' '.join(scopes)
def string_to_scopes(scopes):
"""Converts stringifed scope value to a list.
If scopes is a list then it is simply passed through. If scopes is an
string then a list of each individual scope is returned.
Args:
scopes: a string or iterable of strings, the scopes.
Returns:
The scopes in a list.
"""
if not scopes:
return []
if isinstance(scopes, six.string_types):
return scopes.split(' ')
else:
return scopes
def dict_to_tuple_key(dictionary):
"""Converts a dictionary to a tuple that can be used as an immutable key.

View File

@@ -40,6 +40,7 @@ from .http_mock import HttpMock
from .http_mock import HttpMockSequence
from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import GOOGLE_TOKEN_INFO_URI
from oauth2client import client
from oauth2client import util as oauth2client_util
from oauth2client.client import AccessTokenCredentials
@@ -50,6 +51,7 @@ from oauth2client.client import AssertionCredentials
from oauth2client.client import AUTHORIZED_USER
from oauth2client.client import Credentials
from oauth2client.client import DEFAULT_ENV_NAME
from oauth2client.client import Error
from oauth2client.client import ApplicationDefaultCredentialsError
from oauth2client.client import FlowExchangeError
from oauth2client.client import GoogleCredentials
@@ -647,7 +649,8 @@ class BasicCredentialsTests(unittest.TestCase):
self.credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
user_agent, revoke_uri=GOOGLE_REVOKE_URI)
user_agent, revoke_uri=GOOGLE_REVOKE_URI, scopes='foo',
token_info_uri=GOOGLE_TOKEN_INFO_URI)
# Provoke a failure if @util.positional is not respected.
self.old_positional_enforcement = (
@@ -843,6 +846,43 @@ class BasicCredentialsTests(unittest.TestCase):
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second, self.credentials.token_response)
def test_has_scopes(self):
self.assertTrue(self.credentials.has_scopes('foo'))
self.assertTrue(self.credentials.has_scopes(['foo']))
self.assertFalse(self.credentials.has_scopes('bar'))
self.assertFalse(self.credentials.has_scopes(['bar']))
self.credentials.scopes = set(['foo', 'bar'])
self.assertTrue(self.credentials.has_scopes('foo'))
self.assertTrue(self.credentials.has_scopes('bar'))
self.assertFalse(self.credentials.has_scopes('baz'))
self.assertTrue(self.credentials.has_scopes(['foo', 'bar']))
self.assertFalse(self.credentials.has_scopes(['foo', 'baz']))
self.credentials.scopes = set([])
self.assertFalse(self.credentials.has_scopes('foo'))
def test_retrieve_scopes(self):
info_response_first = {'scope': 'foo bar'}
info_response_second = {'error_description': 'abcdef'}
http = HttpMockSequence([
({'status': '200'}, json.dumps(info_response_first).encode('utf-8')),
({'status': '400'}, json.dumps(info_response_second).encode('utf-8')),
({'status': '500'}, b''),
])
self.credentials.retrieve_scopes(http)
self.assertEqual(set(['foo', 'bar']), self.credentials.scopes)
self.assertRaises(
Error,
self.credentials.retrieve_scopes,
http)
self.assertRaises(
Error,
self.credentials.retrieve_scopes,
http)
class AccessTokenCredentialsTests(unittest.TestCase):
@@ -1069,6 +1109,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
self.assertEqual('dummy_revoke_uri', credentials.revoke_uri)
self.assertEqual(set(['foo']), credentials.scopes)
def test_exchange_dictlike(self):
class FakeDict(object):
@@ -1095,6 +1136,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
self.assertEqual('dummy_revoke_uri', credentials.revoke_uri)
self.assertEqual(set(['foo']), credentials.scopes)
request_code = urllib.parse.parse_qs(http.requests[0]['body'])['code'][0]
self.assertEqual(code, request_code)
@@ -1229,6 +1271,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
http=http)
self.assertEqual(credentials.access_token, token)
self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual(set(['foo']), credentials.scopes)
def test_exchange_code_for_token_fail(self):
http = HttpMockSequence([
@@ -1254,6 +1297,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
self.code, http=http)
self.assertEqual(credentials.access_token, 'asdfghjkl')
self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual(set(['foo']), credentials.scopes)
def test_exchange_code_and_cached_file_for_token(self):
http = HttpMockSequence([
@@ -1266,6 +1310,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
'some_secrets', self.scope,
self.code, http=http, cache=cache_mock)
self.assertEqual(credentials.access_token, 'asdfghjkl')
self.assertEqual(set(['foo']), credentials.scopes)
def test_exchange_code_and_file_for_token_fail(self):
http = HttpMockSequence([

View File

@@ -27,6 +27,19 @@ class ScopeToStringTests(unittest.TestCase):
self.assertEqual(expected, util.scopes_to_string(case))
class StringToScopeTests(unittest.TestCase):
def test_conversion(self):
cases = [
(['a', 'b'], ['a', 'b']),
('', []),
('a', ['a']),
('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']),
]
for case, expected in cases:
self.assertEqual(expected, util.string_to_scopes(case))
class KeyConversionTests(unittest.TestCase):
def test_key_conversions(self):