diff --git a/keystoneauth1/service_token.py b/keystoneauth1/service_token.py new file mode 100644 index 00000000..d06402b6 --- /dev/null +++ b/keystoneauth1/service_token.py @@ -0,0 +1,73 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from keystoneauth1 import plugin + +SERVICE_AUTH_HEADER_NAME = 'X-Service-Token' + +__all__ = ('ServiceTokenAuthWrapper',) + + +class ServiceTokenAuthWrapper(plugin.BaseAuthPlugin): + + def __init__(self, user_auth, service_auth): + self.user_auth = user_auth + self.service_auth = service_auth + + def get_headers(self, session, **kwargs): + headers = self.user_auth.get_headers(session, **kwargs) + + token = self.service_auth.get_token(session, **kwargs) + headers[SERVICE_AUTH_HEADER_NAME] = token + + return headers + + def invalidate(self): + # NOTE(jamielennox): hmm, what to do here? Should we invalidate both + # the service and user auth? Only one? There's no way to know what the + # failure was to selectively invalidate. + user = self.user_auth.invalidate() + service = self.service_auth.invalidate() + return user or service + + def get_connection_params(self, *args, **kwargs): + # NOTE(jamielennox): This is also a bit of a guess but unlikely to be a + # problem in practice. We don't know how merging connection parameters + # between these plugins will conflict - but there aren't many plugins + # that set this anyway. + # Take the service auth params first so that user auth params will be + # given priority. + params = self.service_auth.get_connection_params(*args, **kwargs) + params.update(self.user_auth.get_connection_params(*args, **kwargs)) + return params + + # TODO(jamielennox): Everything below here is a generic wrapper that could + # be extracted into a base wrapper class. We can do this as soon as there + # is a need for it, but we may never actually need it. + + def get_token(self, *args, **kwargs): + return self.user_auth.get_token(*args, **kwargs) + + def get_endpoint(self, *args, **kwargs): + return self.user_auth.get_endpoint(*args, **kwargs) + + def get_user_id(self, *args, **kwargs): + return self.user_auth.get_user_id(*args, **kwargs) + + def get_project_id(self, *args, **kwargs): + return self.user_auth.get_project_id(*args, **kwargs) + + def get_sp_auth_url(self, *args, **kwargs): + return self.user_auth.get_sp_auth_url(*args, **kwargs) + + def get_sp_url(self, *args, **kwargs): + return self.user_auth.get_sp_url(*args, **kwargs) diff --git a/keystoneauth1/tests/unit/test_service_token.py b/keystoneauth1/tests/unit/test_service_token.py new file mode 100644 index 00000000..1fe20562 --- /dev/null +++ b/keystoneauth1/tests/unit/test_service_token.py @@ -0,0 +1,116 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import uuid + +from keystoneauth1 import fixture +from keystoneauth1 import identity +from keystoneauth1 import service_token +from keystoneauth1 import session +from keystoneauth1.tests.unit import utils + + +class ServiceTokenTests(utils.TestCase): + + TEST_URL = 'http://test.example.com/path/' + USER_URL = 'http://user-keystone.example.com/v3' + SERVICE_URL = 'http://service-keystone.example.com/v3' + + def setUp(self): + super(ServiceTokenTests, self).setUp() + + self.user_token_id = uuid.uuid4().hex + self.user_token = fixture.V3Token() + self.user_token.set_project_scope() + self.user_auth = identity.V3Password(auth_url=self.USER_URL, + user_id=uuid.uuid4().hex, + password=uuid.uuid4().hex, + project_id=uuid.uuid4().hex) + + self.service_token_id = uuid.uuid4().hex + self.service_token = fixture.V3Token() + self.service_token.set_project_scope() + self.service_auth = identity.V3Password(auth_url=self.SERVICE_URL, + user_id=uuid.uuid4().hex, + password=uuid.uuid4().hex, + project_id=uuid.uuid4().hex) + + for t in (self.user_token, self.service_token): + s = t.add_service('identity') + s.add_standard_endpoints(public='http://keystone.example.com', + admin='http://keystone.example.com', + internal='http://keystone.example.com') + + self.test_data = {'data': uuid.uuid4().hex} + + self.user_mock = self.requests_mock.post( + self.USER_URL + '/auth/tokens', + json=self.user_token, + headers={'X-Subject-Token': self.user_token_id}) + + self.service_mock = self.requests_mock.post( + self.SERVICE_URL + '/auth/tokens', + json=self.service_token, + headers={'X-Subject-Token': self.service_token_id}) + + self.requests_mock.get(self.TEST_URL, json=self.test_data) + + self.combined_auth = service_token.ServiceTokenAuthWrapper( + self.user_auth, + self.service_auth) + + self.session = session.Session(auth=self.combined_auth) + + def test_setting_service_token(self): + self.session.get(self.TEST_URL) + + headers = self.requests_mock.last_request.headers + + self.assertEqual(self.user_token_id, headers['X-Auth-Token']) + self.assertEqual(self.service_token_id, headers['X-Service-Token']) + + self.assertTrue(self.user_mock.called_once) + self.assertTrue(self.service_mock.called_once) + + def test_invalidation(self): + text = uuid.uuid4().hex + test_url = 'http://test.example.com/abc' + + response_list = [{'status_code': 401}, {'text': text}] + mock = self.requests_mock.get(test_url, response_list=response_list) + + resp = self.session.get(test_url) + self.assertEqual(text, resp.text) + + self.assertEqual(2, mock.call_count) + self.assertEqual(2, self.user_mock.call_count) + self.assertEqual(2, self.service_mock.call_count) + + def test_pass_throughs(self): + self.assertEqual(self.user_auth.get_token(self.session), + self.combined_auth.get_token(self.session)) + + self.assertEqual( + self.user_auth.get_endpoint(self.session, 'identity'), + self.combined_auth.get_endpoint(self.session, 'identity')) + + self.assertEqual(self.user_auth.get_user_id(self.session), + self.combined_auth.get_user_id(self.session)) + + self.assertEqual(self.user_auth.get_project_id(self.session), + self.combined_auth.get_project_id(self.session)) + + self.assertEqual(self.user_auth.get_sp_auth_url(self.session, 'a'), + self.combined_auth.get_sp_auth_url(self.session, 'a')) + + self.assertEqual(self.user_auth.get_sp_url(self.session, 'a'), + self.combined_auth.get_sp_url(self.session, 'a'))