diff --git a/keystoneclient/session.py b/keystoneclient/session.py index 6aa7df84..27b850d7 100644 --- a/keystoneclient/session.py +++ b/keystoneclient/session.py @@ -114,7 +114,7 @@ class Session(object): @utils.positional(enforcement=utils.positional.WARN) def request(self, url, method, json=None, original_ip=None, user_agent=None, redirect=None, authenticated=None, - endpoint_filter=None, **kwargs): + endpoint_filter=None, auth=None, requests_auth=None, **kwargs): """Send an HTTP request with the specified characteristics. Wrapper around `requests.Session.request` to handle tasks such as @@ -149,6 +149,14 @@ class Session(object): endpoint to use for this request. If not provided then URL is expected to be a fully qualified URL. (optional) + :param auth: The auth plugin to use when authenticating this request. + This will override the plugin that is attached to the + session (if any). (optional) + :type auth: :class:`keystoneclient.auth.base.BaseAuthPlugin` + :param requests_auth: A requests library auth plugin that cannot be + passed via kwarg because the `auth` kwarg + collides with our own auth plugins. (optional) + :type requests_auth: :class:`requests.auth.AuthBase` :param kwargs: any other parameter that can be passed to requests.Session.request (such as `headers`). Except: 'data' will be overwritten by the data in 'json' param. @@ -164,10 +172,10 @@ class Session(object): headers = kwargs.setdefault('headers', dict()) if authenticated is None: - authenticated = self.auth is not None + authenticated = bool(auth or self.auth) if authenticated: - token = self.get_token() + token = self.get_token(auth) if not token: raise exceptions.AuthorizationFailure("No token Available") @@ -180,7 +188,7 @@ class Session(object): # requests. We check fully qualified here by the presence of a host. url_data = urllib.parse.urlparse(url) if endpoint_filter and not url_data.netloc: - base_url = self.get_endpoint(**endpoint_filter) + base_url = self.get_endpoint(auth, **endpoint_filter) if not base_url: raise exceptions.EndpointNotFound() @@ -210,6 +218,9 @@ class Session(object): kwargs.setdefault('verify', self.verify) + if requests_auth: + kwargs['auth'] = requests_auth + string_parts = ['curl -i'] # NOTE(jamielennox): None means let requests do its default validation @@ -355,26 +366,45 @@ class Session(object): original_ip=kwargs.pop('original_ip', None), user_agent=kwargs.pop('user_agent', None)) - def get_token(self): + def get_token(self, auth=None): """Return a token as provided by the auth plugin. + :param auth: The auth plugin to use for token. Overrides the plugin + on the session. (optional) + :type auth: :class:`keystoneclient.auth.base.BaseAuthPlugin` + :raises AuthorizationFailure: if a new token fetch fails. :returns string: A valid token. """ - if not self.auth: + if not auth: + auth = self.auth + + if not auth: raise exceptions.MissingAuthPlugin("Token Required") try: - return self.auth.get_token(self) + return auth.get_token(self) except exceptions.HTTPError as exc: raise exceptions.AuthorizationFailure("Authentication failure: " "%s" % exc) - def get_endpoint(self, **kwargs): - """Get an endpoint as provided by the auth plugin.""" - if not self.auth: + def get_endpoint(self, auth=None, **kwargs): + """Get an endpoint as provided by the auth plugin. + + :param auth: The auth plugin to use for token. Overrides the plugin on + the session. (optional) + :type auth: :class:`keystoneclient.auth.base.BaseAuthPlugin` + + :raises MissingAuthPlugin: if a plugin is not available. + + :returns string: An endpoint if available or None. + """ + if not auth: + auth = self.auth + + if not auth: raise exceptions.MissingAuthPlugin('An auth plugin is required to ' 'determine the endpoint URL.') - return self.auth.get_endpoint(self, **kwargs) + return auth.get_endpoint(self, **kwargs) diff --git a/keystoneclient/tests/test_session.py b/keystoneclient/tests/test_session.py index 7ebb1307..c86bc355 100644 --- a/keystoneclient/tests/test_session.py +++ b/keystoneclient/tests/test_session.py @@ -304,6 +304,23 @@ class AuthPlugin(base.BaseAuthPlugin): return None +class CalledAuthPlugin(base.BaseAuthPlugin): + + ENDPOINT = 'http://fakeendpoint/' + + def __init__(self): + self.get_token_called = False + self.get_endpoint_called = False + + def get_token(self, session): + self.get_token_called = True + return 'aToken' + + def get_endpoint(self, session, **kwargs): + self.get_endpoint_called = True + return self.ENDPOINT + + class SessionAuthTests(utils.TestCase): TEST_URL = 'http://127.0.0.1:5000/' @@ -375,3 +392,64 @@ class SessionAuthTests(utils.TestCase): sess.get, '/path', endpoint_filter={'service_type': 'unknown', 'interface': 'public'}) + + @httpretty.activate + def test_passed_auth_plugin(self): + passed = CalledAuthPlugin() + sess = client_session.Session() + + httpretty.register_uri(httpretty.GET, + CalledAuthPlugin.ENDPOINT + 'path', + status=200) + endpoint_filter = {'service_type': 'identity'} + + # no plugin with authenticated won't work + self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path', + authenticated=True) + + # no plugin with an endpoint filter won't work + self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path', + authenticated=False, endpoint_filter=endpoint_filter) + + resp = sess.get('path', auth=passed, endpoint_filter=endpoint_filter) + + self.assertEqual(200, resp.status_code) + self.assertTrue(passed.get_endpoint_called) + self.assertTrue(passed.get_token_called) + + @httpretty.activate + def test_passed_auth_plugin_overrides(self): + fixed = CalledAuthPlugin() + passed = CalledAuthPlugin() + + sess = client_session.Session(fixed) + + httpretty.register_uri(httpretty.GET, + CalledAuthPlugin.ENDPOINT + 'path', + status=200) + + resp = sess.get('path', auth=passed, + endpoint_filter={'service_type': 'identity'}) + + self.assertEqual(200, resp.status_code) + self.assertTrue(passed.get_endpoint_called) + self.assertTrue(passed.get_token_called) + self.assertFalse(fixed.get_endpoint_called) + self.assertFalse(fixed.get_token_called) + + def test_requests_auth_plugin(self): + sess = client_session.Session() + + requests_auth = object() + + FAKE_RESP = utils.TestResponse({'status_code': 200, 'text': 'resp'}) + RESP = mock.Mock(return_value=FAKE_RESP) + + with mock.patch.object(sess.session, 'request', RESP) as mocked: + sess.get(self.TEST_URL, requests_auth=requests_auth) + + mocked.assert_called_once_with('GET', self.TEST_URL, + headers=mock.ANY, + allow_redirects=mock.ANY, + auth=requests_auth, + verify=mock.ANY)