Credentials keep track of scopes.
* Added new attributes scopes and token_info_uri to OAuth2Credentials. * Added method OAuth2Credentials.has_scopes to test which scopes the credentials have. * OAuth2WebServerFlow now passes in the authorized scopes when constructing OAuth2Credentials. * Added method Oauth2Credentials.refresh_scopes, which gets the canonical list of scopes from the OAuth2 tokeninfo service. * Added new utility function string_to_scopes, the inverse of scopes_to_string. This will fix #228 when merged.
This commit is contained in:
@@ -6,3 +6,4 @@ GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/auth'
|
||||
GOOGLE_DEVICE_URI = 'https://accounts.google.com/o/oauth2/device/code'
|
||||
GOOGLE_REVOKE_URI = 'https://accounts.google.com/o/oauth2/revoke'
|
||||
GOOGLE_TOKEN_URI = 'https://accounts.google.com/o/oauth2/token'
|
||||
GOOGLE_TOKEN_INFO_URI = 'https://www.googleapis.com/oauth2/v2/tokeninfo'
|
||||
|
||||
@@ -39,6 +39,7 @@ from oauth2client import GOOGLE_AUTH_URI
|
||||
from oauth2client import GOOGLE_DEVICE_URI
|
||||
from oauth2client import GOOGLE_REVOKE_URI
|
||||
from oauth2client import GOOGLE_TOKEN_URI
|
||||
from oauth2client import GOOGLE_TOKEN_INFO_URI
|
||||
from oauth2client._helpers import _urlsafe_b64decode
|
||||
from oauth2client import clientsecrets
|
||||
from oauth2client import util
|
||||
@@ -252,6 +253,8 @@ class Credentials(object):
|
||||
for key, val in d.items():
|
||||
if isinstance(val, bytes):
|
||||
d[key] = val.decode('utf-8')
|
||||
if isinstance(val, set):
|
||||
d[key] = list(val)
|
||||
return json.dumps(d)
|
||||
|
||||
def to_json(self):
|
||||
@@ -459,7 +462,8 @@ class OAuth2Credentials(Credentials):
|
||||
@util.positional(8)
|
||||
def __init__(self, access_token, client_id, client_secret, refresh_token,
|
||||
token_expiry, token_uri, user_agent, revoke_uri=None,
|
||||
id_token=None, token_response=None):
|
||||
id_token=None, token_response=None, scopes=None,
|
||||
token_info_uri=None):
|
||||
"""Create an instance of OAuth2Credentials.
|
||||
|
||||
This constructor is not usually called by the user, instead
|
||||
@@ -479,6 +483,9 @@ class OAuth2Credentials(Credentials):
|
||||
token_response: dict, the decoded response to the token request. None
|
||||
if a token hasn't been requested yet. Stored because some providers
|
||||
(e.g. wordpress.com) include extra fields that clients may want.
|
||||
scopes: list, authorized scopes for these credentials.
|
||||
token_info_uri: string, the URI for the token info endpoint. Defaults to
|
||||
None; scopes can not be refreshed if this is None.
|
||||
|
||||
Notes:
|
||||
store: callable, A callable that when passed a Credential
|
||||
@@ -497,6 +504,8 @@ class OAuth2Credentials(Credentials):
|
||||
self.revoke_uri = revoke_uri
|
||||
self.id_token = id_token
|
||||
self.token_response = token_response
|
||||
self.scopes = set(util.string_to_scopes(scopes or []))
|
||||
self.token_info_uri = token_info_uri
|
||||
|
||||
# True if the credentials have been revoked or expired and can't be
|
||||
# refreshed.
|
||||
@@ -614,6 +623,39 @@ class OAuth2Credentials(Credentials):
|
||||
"""
|
||||
headers['Authorization'] = 'Bearer ' + self.access_token
|
||||
|
||||
def has_scopes(self, scopes):
|
||||
"""Verify that the credentials are authorized for the given scopes.
|
||||
|
||||
Returns True if the credentials authorized scopes contain all of the scopes
|
||||
given.
|
||||
|
||||
Args:
|
||||
scopes: list or string, the scopes to check.
|
||||
|
||||
Notes:
|
||||
There are cases where the credentials are unaware of which scopes are
|
||||
authorized. Notably, credentials obtained and stored before this code was
|
||||
added will not have scopes, AccessTokenCredentials do not have scopes. In
|
||||
both cases, you can use refresh_scopes() to obtain the canonical set of
|
||||
scopes.
|
||||
"""
|
||||
scopes = util.string_to_scopes(scopes)
|
||||
return set(scopes).issubset(self.scopes)
|
||||
|
||||
def retrieve_scopes(self, http):
|
||||
"""Retrieves the canonical list of scopes for this access token from the
|
||||
OAuth2 provider.
|
||||
|
||||
Args:
|
||||
http: httplib2.Http, an http object to be used to make the refresh
|
||||
request.
|
||||
|
||||
Returns:
|
||||
A set of strings containing the canonical list of scopes.
|
||||
"""
|
||||
self._retrieve_scopes(http.request)
|
||||
return self.scopes
|
||||
|
||||
def to_json(self):
|
||||
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
|
||||
|
||||
@@ -648,7 +690,9 @@ class OAuth2Credentials(Credentials):
|
||||
data['user_agent'],
|
||||
revoke_uri=data.get('revoke_uri', None),
|
||||
id_token=data.get('id_token', None),
|
||||
token_response=data.get('token_response', None))
|
||||
token_response=data.get('token_response', None),
|
||||
scopes=data.get('scopes', None),
|
||||
token_info_uri=data.get('token_info_uri', None))
|
||||
retval.invalid = data['invalid']
|
||||
return retval
|
||||
|
||||
@@ -858,6 +902,10 @@ class OAuth2Credentials(Credentials):
|
||||
query_params = {'token': token}
|
||||
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
|
||||
resp, content = http_request(token_revoke_uri)
|
||||
|
||||
if six.PY3 and isinstance(content, bytes):
|
||||
content = content.decode('utf-8')
|
||||
|
||||
if resp.status == 200:
|
||||
self.invalid = True
|
||||
else:
|
||||
@@ -873,6 +921,48 @@ class OAuth2Credentials(Credentials):
|
||||
if self.store:
|
||||
self.store.delete()
|
||||
|
||||
def _retrieve_scopes(self, http_request):
|
||||
"""Retrieves the list of authorized scopes from the OAuth2 provider.
|
||||
|
||||
Args:
|
||||
http_request: callable, a callable that matches the method signature of
|
||||
httplib2.Http.request, used to make the revoke request.
|
||||
"""
|
||||
self._do_retrieve_scopes(http_request, self.access_token)
|
||||
|
||||
def _do_retrieve_scopes(self, http_request, token):
|
||||
"""Retrieves the list of authorized scopes from the OAuth2 provider.
|
||||
|
||||
Args:
|
||||
http_request: callable, a callable that matches the method signature of
|
||||
httplib2.Http.request, used to make the refresh request.
|
||||
token: A string used as the token to identify the credentials to the
|
||||
provider.
|
||||
|
||||
Raises:
|
||||
Error: When refresh fails, indicating the the access token is invalid.
|
||||
"""
|
||||
logger.info('Refreshing scopes')
|
||||
query_params = {'access_token': token, 'fields': 'scope'}
|
||||
token_info_uri = _update_query_params(self.token_info_uri, query_params)
|
||||
resp, content = http_request(token_info_uri)
|
||||
|
||||
if six.PY3 and isinstance(content, bytes):
|
||||
content = content.decode('utf-8')
|
||||
|
||||
if resp.status == 200:
|
||||
d = json.loads(content)
|
||||
self.scopes = set(util.string_to_scopes(d.get('scope', '')))
|
||||
else:
|
||||
error_msg = 'Invalid response %s.' % (resp.status,)
|
||||
try:
|
||||
d = json.loads(content)
|
||||
if 'error_description' in d:
|
||||
error_msg = d['error_description']
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
raise Error(error_msg)
|
||||
|
||||
|
||||
class AccessTokenCredentials(OAuth2Credentials):
|
||||
"""Credentials object for OAuth 2.0.
|
||||
@@ -1650,7 +1740,8 @@ def credentials_from_code(client_id, client_secret, scope, code,
|
||||
user_agent=None, token_uri=GOOGLE_TOKEN_URI,
|
||||
auth_uri=GOOGLE_AUTH_URI,
|
||||
revoke_uri=GOOGLE_REVOKE_URI,
|
||||
device_uri=GOOGLE_DEVICE_URI):
|
||||
device_uri=GOOGLE_DEVICE_URI,
|
||||
token_info_uri=GOOGLE_TOKEN_INFO_URI):
|
||||
"""Exchanges an authorization code for an OAuth2Credentials object.
|
||||
|
||||
Args:
|
||||
@@ -1681,7 +1772,8 @@ def credentials_from_code(client_id, client_secret, scope, code,
|
||||
flow = OAuth2WebServerFlow(client_id, client_secret, scope,
|
||||
redirect_uri=redirect_uri, user_agent=user_agent,
|
||||
auth_uri=auth_uri, token_uri=token_uri,
|
||||
revoke_uri=revoke_uri, device_uri=device_uri)
|
||||
revoke_uri=revoke_uri, device_uri=device_uri,
|
||||
token_info_uri=token_info_uri)
|
||||
|
||||
credentials = flow.step2_exchange(code, http=http)
|
||||
return credentials
|
||||
@@ -1786,6 +1878,7 @@ class OAuth2WebServerFlow(Flow):
|
||||
revoke_uri=GOOGLE_REVOKE_URI,
|
||||
login_hint=None,
|
||||
device_uri=GOOGLE_DEVICE_URI,
|
||||
token_info_uri=GOOGLE_TOKEN_INFO_URI,
|
||||
authorization_header=None,
|
||||
**kwargs):
|
||||
"""Constructor for OAuth2WebServerFlow.
|
||||
@@ -1834,6 +1927,7 @@ class OAuth2WebServerFlow(Flow):
|
||||
self.token_uri = token_uri
|
||||
self.revoke_uri = revoke_uri
|
||||
self.device_uri = device_uri
|
||||
self.token_info_uri = token_info_uri
|
||||
self.authorization_header = authorization_header
|
||||
self.params = {
|
||||
'access_type': 'offline',
|
||||
@@ -2011,7 +2105,9 @@ class OAuth2WebServerFlow(Flow):
|
||||
self.token_uri, self.user_agent,
|
||||
revoke_uri=self.revoke_uri,
|
||||
id_token=extracted_id_token,
|
||||
token_response=d)
|
||||
token_response=d,
|
||||
scopes=self.scope,
|
||||
token_info_uri=self.token_info_uri)
|
||||
else:
|
||||
logger.info('Failed to retrieve access token: %s', content)
|
||||
if 'error' in d:
|
||||
|
||||
@@ -163,6 +163,26 @@ def scopes_to_string(scopes):
|
||||
return ' '.join(scopes)
|
||||
|
||||
|
||||
def string_to_scopes(scopes):
|
||||
"""Converts stringifed scope value to a list.
|
||||
|
||||
If scopes is a list then it is simply passed through. If scopes is an
|
||||
string then a list of each individual scope is returned.
|
||||
|
||||
Args:
|
||||
scopes: a string or iterable of strings, the scopes.
|
||||
|
||||
Returns:
|
||||
The scopes in a list.
|
||||
"""
|
||||
if not scopes:
|
||||
return []
|
||||
if isinstance(scopes, six.string_types):
|
||||
return scopes.split(' ')
|
||||
else:
|
||||
return scopes
|
||||
|
||||
|
||||
def dict_to_tuple_key(dictionary):
|
||||
"""Converts a dictionary to a tuple that can be used as an immutable key.
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from .http_mock import HttpMock
|
||||
from .http_mock import HttpMockSequence
|
||||
from oauth2client import GOOGLE_REVOKE_URI
|
||||
from oauth2client import GOOGLE_TOKEN_URI
|
||||
from oauth2client import GOOGLE_TOKEN_INFO_URI
|
||||
from oauth2client import client
|
||||
from oauth2client import util as oauth2client_util
|
||||
from oauth2client.client import AccessTokenCredentials
|
||||
@@ -50,6 +51,7 @@ from oauth2client.client import AssertionCredentials
|
||||
from oauth2client.client import AUTHORIZED_USER
|
||||
from oauth2client.client import Credentials
|
||||
from oauth2client.client import DEFAULT_ENV_NAME
|
||||
from oauth2client.client import Error
|
||||
from oauth2client.client import ApplicationDefaultCredentialsError
|
||||
from oauth2client.client import FlowExchangeError
|
||||
from oauth2client.client import GoogleCredentials
|
||||
@@ -647,7 +649,8 @@ class BasicCredentialsTests(unittest.TestCase):
|
||||
self.credentials = OAuth2Credentials(
|
||||
access_token, client_id, client_secret,
|
||||
refresh_token, token_expiry, GOOGLE_TOKEN_URI,
|
||||
user_agent, revoke_uri=GOOGLE_REVOKE_URI)
|
||||
user_agent, revoke_uri=GOOGLE_REVOKE_URI, scopes='foo',
|
||||
token_info_uri=GOOGLE_TOKEN_INFO_URI)
|
||||
|
||||
# Provoke a failure if @util.positional is not respected.
|
||||
self.old_positional_enforcement = (
|
||||
@@ -843,6 +846,43 @@ class BasicCredentialsTests(unittest.TestCase):
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second, self.credentials.token_response)
|
||||
|
||||
def test_has_scopes(self):
|
||||
self.assertTrue(self.credentials.has_scopes('foo'))
|
||||
self.assertTrue(self.credentials.has_scopes(['foo']))
|
||||
self.assertFalse(self.credentials.has_scopes('bar'))
|
||||
self.assertFalse(self.credentials.has_scopes(['bar']))
|
||||
|
||||
self.credentials.scopes = set(['foo', 'bar'])
|
||||
self.assertTrue(self.credentials.has_scopes('foo'))
|
||||
self.assertTrue(self.credentials.has_scopes('bar'))
|
||||
self.assertFalse(self.credentials.has_scopes('baz'))
|
||||
self.assertTrue(self.credentials.has_scopes(['foo', 'bar']))
|
||||
self.assertFalse(self.credentials.has_scopes(['foo', 'baz']))
|
||||
|
||||
self.credentials.scopes = set([])
|
||||
self.assertFalse(self.credentials.has_scopes('foo'))
|
||||
|
||||
def test_retrieve_scopes(self):
|
||||
info_response_first = {'scope': 'foo bar'}
|
||||
info_response_second = {'error_description': 'abcdef'}
|
||||
http = HttpMockSequence([
|
||||
({'status': '200'}, json.dumps(info_response_first).encode('utf-8')),
|
||||
({'status': '400'}, json.dumps(info_response_second).encode('utf-8')),
|
||||
({'status': '500'}, b''),
|
||||
])
|
||||
|
||||
self.credentials.retrieve_scopes(http)
|
||||
self.assertEqual(set(['foo', 'bar']), self.credentials.scopes)
|
||||
|
||||
self.assertRaises(
|
||||
Error,
|
||||
self.credentials.retrieve_scopes,
|
||||
http)
|
||||
|
||||
self.assertRaises(
|
||||
Error,
|
||||
self.credentials.retrieve_scopes,
|
||||
http)
|
||||
|
||||
class AccessTokenCredentialsTests(unittest.TestCase):
|
||||
|
||||
@@ -1069,6 +1109,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
||||
self.assertNotEqual(None, credentials.token_expiry)
|
||||
self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
|
||||
self.assertEqual('dummy_revoke_uri', credentials.revoke_uri)
|
||||
self.assertEqual(set(['foo']), credentials.scopes)
|
||||
|
||||
def test_exchange_dictlike(self):
|
||||
class FakeDict(object):
|
||||
@@ -1095,6 +1136,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
|
||||
self.assertNotEqual(None, credentials.token_expiry)
|
||||
self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
|
||||
self.assertEqual('dummy_revoke_uri', credentials.revoke_uri)
|
||||
self.assertEqual(set(['foo']), credentials.scopes)
|
||||
request_code = urllib.parse.parse_qs(http.requests[0]['body'])['code'][0]
|
||||
self.assertEqual(code, request_code)
|
||||
|
||||
@@ -1229,6 +1271,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
|
||||
http=http)
|
||||
self.assertEqual(credentials.access_token, token)
|
||||
self.assertNotEqual(None, credentials.token_expiry)
|
||||
self.assertEqual(set(['foo']), credentials.scopes)
|
||||
|
||||
def test_exchange_code_for_token_fail(self):
|
||||
http = HttpMockSequence([
|
||||
@@ -1254,6 +1297,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
|
||||
self.code, http=http)
|
||||
self.assertEqual(credentials.access_token, 'asdfghjkl')
|
||||
self.assertNotEqual(None, credentials.token_expiry)
|
||||
self.assertEqual(set(['foo']), credentials.scopes)
|
||||
|
||||
def test_exchange_code_and_cached_file_for_token(self):
|
||||
http = HttpMockSequence([
|
||||
@@ -1266,6 +1310,7 @@ class CredentialsFromCodeTests(unittest.TestCase):
|
||||
'some_secrets', self.scope,
|
||||
self.code, http=http, cache=cache_mock)
|
||||
self.assertEqual(credentials.access_token, 'asdfghjkl')
|
||||
self.assertEqual(set(['foo']), credentials.scopes)
|
||||
|
||||
def test_exchange_code_and_file_for_token_fail(self):
|
||||
http = HttpMockSequence([
|
||||
|
||||
@@ -27,6 +27,19 @@ class ScopeToStringTests(unittest.TestCase):
|
||||
self.assertEqual(expected, util.scopes_to_string(case))
|
||||
|
||||
|
||||
class StringToScopeTests(unittest.TestCase):
|
||||
|
||||
def test_conversion(self):
|
||||
cases = [
|
||||
(['a', 'b'], ['a', 'b']),
|
||||
('', []),
|
||||
('a', ['a']),
|
||||
('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']),
|
||||
]
|
||||
|
||||
for case, expected in cases:
|
||||
self.assertEqual(expected, util.string_to_scopes(case))
|
||||
|
||||
class KeyConversionTests(unittest.TestCase):
|
||||
|
||||
def test_key_conversions(self):
|
||||
|
||||
Reference in New Issue
Block a user