Merge "Allow passing auth plugin as a parameter"
This commit is contained in:
commit
61322a3fee
keystoneclient
@ -114,7 +114,7 @@ class Session(object):
|
||||
@utils.positional(enforcement=utils.positional.WARN)
|
||||
def request(self, url, method, json=None, original_ip=None,
|
||||
user_agent=None, redirect=None, authenticated=None,
|
||||
endpoint_filter=None, **kwargs):
|
||||
endpoint_filter=None, auth=None, requests_auth=None, **kwargs):
|
||||
"""Send an HTTP request with the specified characteristics.
|
||||
|
||||
Wrapper around `requests.Session.request` to handle tasks such as
|
||||
@ -149,6 +149,14 @@ class Session(object):
|
||||
endpoint to use for this request. If not
|
||||
provided then URL is expected to be a
|
||||
fully qualified URL. (optional)
|
||||
:param auth: The auth plugin to use when authenticating this request.
|
||||
This will override the plugin that is attached to the
|
||||
session (if any). (optional)
|
||||
:type auth: :class:`keystoneclient.auth.base.BaseAuthPlugin`
|
||||
:param requests_auth: A requests library auth plugin that cannot be
|
||||
passed via kwarg because the `auth` kwarg
|
||||
collides with our own auth plugins. (optional)
|
||||
:type requests_auth: :class:`requests.auth.AuthBase`
|
||||
:param kwargs: any other parameter that can be passed to
|
||||
requests.Session.request (such as `headers`). Except:
|
||||
'data' will be overwritten by the data in 'json' param.
|
||||
@ -164,10 +172,10 @@ class Session(object):
|
||||
headers = kwargs.setdefault('headers', dict())
|
||||
|
||||
if authenticated is None:
|
||||
authenticated = self.auth is not None
|
||||
authenticated = bool(auth or self.auth)
|
||||
|
||||
if authenticated:
|
||||
token = self.get_token()
|
||||
token = self.get_token(auth)
|
||||
|
||||
if not token:
|
||||
raise exceptions.AuthorizationFailure("No token Available")
|
||||
@ -180,7 +188,7 @@ class Session(object):
|
||||
# requests. We check fully qualified here by the presence of a host.
|
||||
url_data = urllib.parse.urlparse(url)
|
||||
if endpoint_filter and not url_data.netloc:
|
||||
base_url = self.get_endpoint(**endpoint_filter)
|
||||
base_url = self.get_endpoint(auth, **endpoint_filter)
|
||||
|
||||
if not base_url:
|
||||
raise exceptions.EndpointNotFound()
|
||||
@ -210,6 +218,9 @@ class Session(object):
|
||||
|
||||
kwargs.setdefault('verify', self.verify)
|
||||
|
||||
if requests_auth:
|
||||
kwargs['auth'] = requests_auth
|
||||
|
||||
string_parts = ['curl -i']
|
||||
|
||||
# NOTE(jamielennox): None means let requests do its default validation
|
||||
@ -355,26 +366,45 @@ class Session(object):
|
||||
original_ip=kwargs.pop('original_ip', None),
|
||||
user_agent=kwargs.pop('user_agent', None))
|
||||
|
||||
def get_token(self):
|
||||
def get_token(self, auth=None):
|
||||
"""Return a token as provided by the auth plugin.
|
||||
|
||||
:param auth: The auth plugin to use for token. Overrides the plugin
|
||||
on the session. (optional)
|
||||
:type auth: :class:`keystoneclient.auth.base.BaseAuthPlugin`
|
||||
|
||||
:raises AuthorizationFailure: if a new token fetch fails.
|
||||
|
||||
:returns string: A valid token.
|
||||
"""
|
||||
if not self.auth:
|
||||
if not auth:
|
||||
auth = self.auth
|
||||
|
||||
if not auth:
|
||||
raise exceptions.MissingAuthPlugin("Token Required")
|
||||
|
||||
try:
|
||||
return self.auth.get_token(self)
|
||||
return auth.get_token(self)
|
||||
except exceptions.HTTPError as exc:
|
||||
raise exceptions.AuthorizationFailure("Authentication failure: "
|
||||
"%s" % exc)
|
||||
|
||||
def get_endpoint(self, **kwargs):
|
||||
"""Get an endpoint as provided by the auth plugin."""
|
||||
if not self.auth:
|
||||
def get_endpoint(self, auth=None, **kwargs):
|
||||
"""Get an endpoint as provided by the auth plugin.
|
||||
|
||||
:param auth: The auth plugin to use for token. Overrides the plugin on
|
||||
the session. (optional)
|
||||
:type auth: :class:`keystoneclient.auth.base.BaseAuthPlugin`
|
||||
|
||||
:raises MissingAuthPlugin: if a plugin is not available.
|
||||
|
||||
:returns string: An endpoint if available or None.
|
||||
"""
|
||||
if not auth:
|
||||
auth = self.auth
|
||||
|
||||
if not auth:
|
||||
raise exceptions.MissingAuthPlugin('An auth plugin is required to '
|
||||
'determine the endpoint URL.')
|
||||
|
||||
return self.auth.get_endpoint(self, **kwargs)
|
||||
return auth.get_endpoint(self, **kwargs)
|
||||
|
@ -304,6 +304,23 @@ class AuthPlugin(base.BaseAuthPlugin):
|
||||
return None
|
||||
|
||||
|
||||
class CalledAuthPlugin(base.BaseAuthPlugin):
|
||||
|
||||
ENDPOINT = 'http://fakeendpoint/'
|
||||
|
||||
def __init__(self):
|
||||
self.get_token_called = False
|
||||
self.get_endpoint_called = False
|
||||
|
||||
def get_token(self, session):
|
||||
self.get_token_called = True
|
||||
return 'aToken'
|
||||
|
||||
def get_endpoint(self, session, **kwargs):
|
||||
self.get_endpoint_called = True
|
||||
return self.ENDPOINT
|
||||
|
||||
|
||||
class SessionAuthTests(utils.TestCase):
|
||||
|
||||
TEST_URL = 'http://127.0.0.1:5000/'
|
||||
@ -375,3 +392,64 @@ class SessionAuthTests(utils.TestCase):
|
||||
sess.get, '/path',
|
||||
endpoint_filter={'service_type': 'unknown',
|
||||
'interface': 'public'})
|
||||
|
||||
@httpretty.activate
|
||||
def test_passed_auth_plugin(self):
|
||||
passed = CalledAuthPlugin()
|
||||
sess = client_session.Session()
|
||||
|
||||
httpretty.register_uri(httpretty.GET,
|
||||
CalledAuthPlugin.ENDPOINT + 'path',
|
||||
status=200)
|
||||
endpoint_filter = {'service_type': 'identity'}
|
||||
|
||||
# no plugin with authenticated won't work
|
||||
self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
|
||||
authenticated=True)
|
||||
|
||||
# no plugin with an endpoint filter won't work
|
||||
self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
|
||||
authenticated=False, endpoint_filter=endpoint_filter)
|
||||
|
||||
resp = sess.get('path', auth=passed, endpoint_filter=endpoint_filter)
|
||||
|
||||
self.assertEqual(200, resp.status_code)
|
||||
self.assertTrue(passed.get_endpoint_called)
|
||||
self.assertTrue(passed.get_token_called)
|
||||
|
||||
@httpretty.activate
|
||||
def test_passed_auth_plugin_overrides(self):
|
||||
fixed = CalledAuthPlugin()
|
||||
passed = CalledAuthPlugin()
|
||||
|
||||
sess = client_session.Session(fixed)
|
||||
|
||||
httpretty.register_uri(httpretty.GET,
|
||||
CalledAuthPlugin.ENDPOINT + 'path',
|
||||
status=200)
|
||||
|
||||
resp = sess.get('path', auth=passed,
|
||||
endpoint_filter={'service_type': 'identity'})
|
||||
|
||||
self.assertEqual(200, resp.status_code)
|
||||
self.assertTrue(passed.get_endpoint_called)
|
||||
self.assertTrue(passed.get_token_called)
|
||||
self.assertFalse(fixed.get_endpoint_called)
|
||||
self.assertFalse(fixed.get_token_called)
|
||||
|
||||
def test_requests_auth_plugin(self):
|
||||
sess = client_session.Session()
|
||||
|
||||
requests_auth = object()
|
||||
|
||||
FAKE_RESP = utils.TestResponse({'status_code': 200, 'text': 'resp'})
|
||||
RESP = mock.Mock(return_value=FAKE_RESP)
|
||||
|
||||
with mock.patch.object(sess.session, 'request', RESP) as mocked:
|
||||
sess.get(self.TEST_URL, requests_auth=requests_auth)
|
||||
|
||||
mocked.assert_called_once_with('GET', self.TEST_URL,
|
||||
headers=mock.ANY,
|
||||
allow_redirects=mock.ANY,
|
||||
auth=requests_auth,
|
||||
verify=mock.ANY)
|
||||
|
Loading…
x
Reference in New Issue
Block a user