diff --git a/oauth2client/__init__.py b/oauth2client/__init__.py index 8678037..f3c5f91 100644 --- a/oauth2client/__init__.py +++ b/oauth2client/__init__.py @@ -6,3 +6,4 @@ GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/auth' GOOGLE_DEVICE_URI = 'https://accounts.google.com/o/oauth2/device/code' GOOGLE_REVOKE_URI = 'https://accounts.google.com/o/oauth2/revoke' GOOGLE_TOKEN_URI = 'https://accounts.google.com/o/oauth2/token' +GOOGLE_TOKEN_INFO_URI = 'https://www.googleapis.com/oauth2/v2/tokeninfo' diff --git a/oauth2client/client.py b/oauth2client/client.py index 5956f44..54ec62b 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -39,6 +39,7 @@ from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_DEVICE_URI from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI +from oauth2client import GOOGLE_TOKEN_INFO_URI from oauth2client._helpers import _urlsafe_b64decode from oauth2client import clientsecrets from oauth2client import util @@ -252,6 +253,8 @@ class Credentials(object): for key, val in d.items(): if isinstance(val, bytes): d[key] = val.decode('utf-8') + if isinstance(val, set): + d[key] = list(val) return json.dumps(d) def to_json(self): @@ -459,7 +462,8 @@ class OAuth2Credentials(Credentials): @util.positional(8) def __init__(self, access_token, client_id, client_secret, refresh_token, token_expiry, token_uri, user_agent, revoke_uri=None, - id_token=None, token_response=None): + id_token=None, token_response=None, scopes=None, + token_info_uri=None): """Create an instance of OAuth2Credentials. This constructor is not usually called by the user, instead @@ -479,6 +483,9 @@ class OAuth2Credentials(Credentials): token_response: dict, the decoded response to the token request. None if a token hasn't been requested yet. Stored because some providers (e.g. wordpress.com) include extra fields that clients may want. + scopes: list, authorized scopes for these credentials. + token_info_uri: string, the URI for the token info endpoint. Defaults to + None; scopes can not be refreshed if this is None. Notes: store: callable, A callable that when passed a Credential @@ -497,6 +504,8 @@ class OAuth2Credentials(Credentials): self.revoke_uri = revoke_uri self.id_token = id_token self.token_response = token_response + self.scopes = set(util.string_to_scopes(scopes or [])) + self.token_info_uri = token_info_uri # True if the credentials have been revoked or expired and can't be # refreshed. @@ -614,6 +623,39 @@ class OAuth2Credentials(Credentials): """ headers['Authorization'] = 'Bearer ' + self.access_token + def has_scopes(self, scopes): + """Verify that the credentials are authorized for the given scopes. + + Returns True if the credentials authorized scopes contain all of the scopes + given. + + Args: + scopes: list or string, the scopes to check. + + Notes: + There are cases where the credentials are unaware of which scopes are + authorized. Notably, credentials obtained and stored before this code was + added will not have scopes, AccessTokenCredentials do not have scopes. In + both cases, you can use refresh_scopes() to obtain the canonical set of + scopes. + """ + scopes = util.string_to_scopes(scopes) + return set(scopes).issubset(self.scopes) + + def retrieve_scopes(self, http): + """Retrieves the canonical list of scopes for this access token from the + OAuth2 provider. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + + Returns: + A set of strings containing the canonical list of scopes. + """ + self._retrieve_scopes(http.request) + return self.scopes + def to_json(self): return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) @@ -648,7 +690,9 @@ class OAuth2Credentials(Credentials): data['user_agent'], revoke_uri=data.get('revoke_uri', None), id_token=data.get('id_token', None), - token_response=data.get('token_response', None)) + token_response=data.get('token_response', None), + scopes=data.get('scopes', None), + token_info_uri=data.get('token_info_uri', None)) retval.invalid = data['invalid'] return retval @@ -858,6 +902,10 @@ class OAuth2Credentials(Credentials): query_params = {'token': token} token_revoke_uri = _update_query_params(self.revoke_uri, query_params) resp, content = http_request(token_revoke_uri) + + if six.PY3 and isinstance(content, bytes): + content = content.decode('utf-8') + if resp.status == 200: self.invalid = True else: @@ -873,6 +921,48 @@ class OAuth2Credentials(Credentials): if self.store: self.store.delete() + def _retrieve_scopes(self, http_request): + """Retrieves the list of authorized scopes from the OAuth2 provider. + + Args: + http_request: callable, a callable that matches the method signature of + httplib2.Http.request, used to make the revoke request. + """ + self._do_retrieve_scopes(http_request, self.access_token) + + def _do_retrieve_scopes(self, http_request, token): + """Retrieves the list of authorized scopes from the OAuth2 provider. + + Args: + http_request: callable, a callable that matches the method signature of + httplib2.Http.request, used to make the refresh request. + token: A string used as the token to identify the credentials to the + provider. + + Raises: + Error: When refresh fails, indicating the the access token is invalid. + """ + logger.info('Refreshing scopes') + query_params = {'access_token': token, 'fields': 'scope'} + token_info_uri = _update_query_params(self.token_info_uri, query_params) + resp, content = http_request(token_info_uri) + + if six.PY3 and isinstance(content, bytes): + content = content.decode('utf-8') + + if resp.status == 200: + d = json.loads(content) + self.scopes = set(util.string_to_scopes(d.get('scope', ''))) + else: + error_msg = 'Invalid response %s.' % (resp.status,) + try: + d = json.loads(content) + if 'error_description' in d: + error_msg = d['error_description'] + except (TypeError, ValueError): + pass + raise Error(error_msg) + class AccessTokenCredentials(OAuth2Credentials): """Credentials object for OAuth 2.0. @@ -1650,7 +1740,8 @@ def credentials_from_code(client_id, client_secret, scope, code, user_agent=None, token_uri=GOOGLE_TOKEN_URI, auth_uri=GOOGLE_AUTH_URI, revoke_uri=GOOGLE_REVOKE_URI, - device_uri=GOOGLE_DEVICE_URI): + device_uri=GOOGLE_DEVICE_URI, + token_info_uri=GOOGLE_TOKEN_INFO_URI): """Exchanges an authorization code for an OAuth2Credentials object. Args: @@ -1681,7 +1772,8 @@ def credentials_from_code(client_id, client_secret, scope, code, flow = OAuth2WebServerFlow(client_id, client_secret, scope, redirect_uri=redirect_uri, user_agent=user_agent, auth_uri=auth_uri, token_uri=token_uri, - revoke_uri=revoke_uri, device_uri=device_uri) + revoke_uri=revoke_uri, device_uri=device_uri, + token_info_uri=token_info_uri) credentials = flow.step2_exchange(code, http=http) return credentials @@ -1786,6 +1878,7 @@ class OAuth2WebServerFlow(Flow): revoke_uri=GOOGLE_REVOKE_URI, login_hint=None, device_uri=GOOGLE_DEVICE_URI, + token_info_uri=GOOGLE_TOKEN_INFO_URI, authorization_header=None, **kwargs): """Constructor for OAuth2WebServerFlow. @@ -1834,6 +1927,7 @@ class OAuth2WebServerFlow(Flow): self.token_uri = token_uri self.revoke_uri = revoke_uri self.device_uri = device_uri + self.token_info_uri = token_info_uri self.authorization_header = authorization_header self.params = { 'access_type': 'offline', @@ -2011,7 +2105,9 @@ class OAuth2WebServerFlow(Flow): self.token_uri, self.user_agent, revoke_uri=self.revoke_uri, id_token=extracted_id_token, - token_response=d) + token_response=d, + scopes=self.scope, + token_info_uri=self.token_info_uri) else: logger.info('Failed to retrieve access token: %s', content) if 'error' in d: diff --git a/oauth2client/util.py b/oauth2client/util.py index a706f02..94c2523 100644 --- a/oauth2client/util.py +++ b/oauth2client/util.py @@ -163,6 +163,26 @@ def scopes_to_string(scopes): return ' '.join(scopes) +def string_to_scopes(scopes): + """Converts stringifed scope value to a list. + + If scopes is a list then it is simply passed through. If scopes is an + string then a list of each individual scope is returned. + + Args: + scopes: a string or iterable of strings, the scopes. + + Returns: + The scopes in a list. + """ + if not scopes: + return [] + if isinstance(scopes, six.string_types): + return scopes.split(' ') + else: + return scopes + + def dict_to_tuple_key(dictionary): """Converts a dictionary to a tuple that can be used as an immutable key. diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index 5be488e..04496d6 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -40,6 +40,7 @@ from .http_mock import HttpMock from .http_mock import HttpMockSequence from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI +from oauth2client import GOOGLE_TOKEN_INFO_URI from oauth2client import client from oauth2client import util as oauth2client_util from oauth2client.client import AccessTokenCredentials @@ -50,6 +51,7 @@ from oauth2client.client import AssertionCredentials from oauth2client.client import AUTHORIZED_USER from oauth2client.client import Credentials from oauth2client.client import DEFAULT_ENV_NAME +from oauth2client.client import Error from oauth2client.client import ApplicationDefaultCredentialsError from oauth2client.client import FlowExchangeError from oauth2client.client import GoogleCredentials @@ -647,7 +649,8 @@ class BasicCredentialsTests(unittest.TestCase): self.credentials = OAuth2Credentials( access_token, client_id, client_secret, refresh_token, token_expiry, GOOGLE_TOKEN_URI, - user_agent, revoke_uri=GOOGLE_REVOKE_URI) + user_agent, revoke_uri=GOOGLE_REVOKE_URI, scopes='foo', + token_info_uri=GOOGLE_TOKEN_INFO_URI) # Provoke a failure if @util.positional is not respected. self.old_positional_enforcement = ( @@ -843,6 +846,43 @@ class BasicCredentialsTests(unittest.TestCase): self.assertFalse(self.credentials.access_token_expired) self.assertEqual(token_response_second, self.credentials.token_response) + def test_has_scopes(self): + self.assertTrue(self.credentials.has_scopes('foo')) + self.assertTrue(self.credentials.has_scopes(['foo'])) + self.assertFalse(self.credentials.has_scopes('bar')) + self.assertFalse(self.credentials.has_scopes(['bar'])) + + self.credentials.scopes = set(['foo', 'bar']) + self.assertTrue(self.credentials.has_scopes('foo')) + self.assertTrue(self.credentials.has_scopes('bar')) + self.assertFalse(self.credentials.has_scopes('baz')) + self.assertTrue(self.credentials.has_scopes(['foo', 'bar'])) + self.assertFalse(self.credentials.has_scopes(['foo', 'baz'])) + + self.credentials.scopes = set([]) + self.assertFalse(self.credentials.has_scopes('foo')) + + def test_retrieve_scopes(self): + info_response_first = {'scope': 'foo bar'} + info_response_second = {'error_description': 'abcdef'} + http = HttpMockSequence([ + ({'status': '200'}, json.dumps(info_response_first).encode('utf-8')), + ({'status': '400'}, json.dumps(info_response_second).encode('utf-8')), + ({'status': '500'}, b''), + ]) + + self.credentials.retrieve_scopes(http) + self.assertEqual(set(['foo', 'bar']), self.credentials.scopes) + + self.assertRaises( + Error, + self.credentials.retrieve_scopes, + http) + + self.assertRaises( + Error, + self.credentials.retrieve_scopes, + http) class AccessTokenCredentialsTests(unittest.TestCase): @@ -1069,6 +1109,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase): self.assertNotEqual(None, credentials.token_expiry) self.assertEqual('8xLOxBtZp8', credentials.refresh_token) self.assertEqual('dummy_revoke_uri', credentials.revoke_uri) + self.assertEqual(set(['foo']), credentials.scopes) def test_exchange_dictlike(self): class FakeDict(object): @@ -1095,6 +1136,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase): self.assertNotEqual(None, credentials.token_expiry) self.assertEqual('8xLOxBtZp8', credentials.refresh_token) self.assertEqual('dummy_revoke_uri', credentials.revoke_uri) + self.assertEqual(set(['foo']), credentials.scopes) request_code = urllib.parse.parse_qs(http.requests[0]['body'])['code'][0] self.assertEqual(code, request_code) @@ -1229,6 +1271,7 @@ class CredentialsFromCodeTests(unittest.TestCase): http=http) self.assertEqual(credentials.access_token, token) self.assertNotEqual(None, credentials.token_expiry) + self.assertEqual(set(['foo']), credentials.scopes) def test_exchange_code_for_token_fail(self): http = HttpMockSequence([ @@ -1254,6 +1297,7 @@ class CredentialsFromCodeTests(unittest.TestCase): self.code, http=http) self.assertEqual(credentials.access_token, 'asdfghjkl') self.assertNotEqual(None, credentials.token_expiry) + self.assertEqual(set(['foo']), credentials.scopes) def test_exchange_code_and_cached_file_for_token(self): http = HttpMockSequence([ @@ -1266,6 +1310,7 @@ class CredentialsFromCodeTests(unittest.TestCase): 'some_secrets', self.scope, self.code, http=http, cache=cache_mock) self.assertEqual(credentials.access_token, 'asdfghjkl') + self.assertEqual(set(['foo']), credentials.scopes) def test_exchange_code_and_file_for_token_fail(self): http = HttpMockSequence([ diff --git a/tests/test_util.py b/tests/test_util.py index 2d67316..b3fc326 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -27,6 +27,19 @@ class ScopeToStringTests(unittest.TestCase): self.assertEqual(expected, util.scopes_to_string(case)) +class StringToScopeTests(unittest.TestCase): + + def test_conversion(self): + cases = [ + (['a', 'b'], ['a', 'b']), + ('', []), + ('a', ['a']), + ('a b c d e f', ['a', 'b', 'c', 'd', 'e', 'f']), + ] + + for case, expected in cases: + self.assertEqual(expected, util.string_to_scopes(case)) + class KeyConversionTests(unittest.TestCase): def test_key_conversions(self):