diff --git a/keystone/federation/V8_backends/__init__.py b/keystone/federation/V8_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keystone/federation/V8_backends/sql.py b/keystone/federation/V8_backends/sql.py new file mode 100644 index 0000000000..851a91128d --- /dev/null +++ b/keystone/federation/V8_backends/sql.py @@ -0,0 +1,366 @@ +# Copyright 2014 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_serialization import jsonutils +from sqlalchemy import orm + +from keystone.common import sql +from keystone import exception +from keystone.federation import core + + +class FederationProtocolModel(sql.ModelBase, sql.DictBase): + __tablename__ = 'federation_protocol' + attributes = ['id', 'idp_id', 'mapping_id'] + mutable_attributes = frozenset(['mapping_id']) + + id = sql.Column(sql.String(64), primary_key=True) + idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id', + ondelete='CASCADE'), primary_key=True) + mapping_id = sql.Column(sql.String(64), nullable=False) + + @classmethod + def from_dict(cls, dictionary): + new_dictionary = dictionary.copy() + return cls(**new_dictionary) + + def to_dict(self): + """Return a dictionary with model's attributes.""" + d = dict() + for attr in self.__class__.attributes: + d[attr] = getattr(self, attr) + return d + + +class IdentityProviderModel(sql.ModelBase, sql.DictBase): + __tablename__ = 'identity_provider' + attributes = ['id', 'enabled', 'description', 'remote_ids'] + mutable_attributes = frozenset(['description', 'enabled', 'remote_ids']) + + id = sql.Column(sql.String(64), primary_key=True) + enabled = sql.Column(sql.Boolean, nullable=False) + description = sql.Column(sql.Text(), nullable=True) + remote_ids = orm.relationship('IdPRemoteIdsModel', + order_by='IdPRemoteIdsModel.remote_id', + cascade='all, delete-orphan') + + @classmethod + def from_dict(cls, dictionary): + new_dictionary = dictionary.copy() + remote_ids_list = new_dictionary.pop('remote_ids', None) + if not remote_ids_list: + remote_ids_list = [] + identity_provider = cls(**new_dictionary) + remote_ids = [] + # NOTE(fmarco76): the remote_ids_list contains only remote ids + # associated with the IdP because of the "relationship" established in + # sqlalchemy and corresponding to the FK in the idp_remote_ids table + for remote in remote_ids_list: + remote_ids.append(IdPRemoteIdsModel(remote_id=remote)) + identity_provider.remote_ids = remote_ids + return identity_provider + + def to_dict(self): + """Return a dictionary with model's attributes.""" + d = dict() + for attr in self.__class__.attributes: + d[attr] = getattr(self, attr) + d['remote_ids'] = [] + for remote in self.remote_ids: + d['remote_ids'].append(remote.remote_id) + return d + + +class IdPRemoteIdsModel(sql.ModelBase, sql.DictBase): + __tablename__ = 'idp_remote_ids' + attributes = ['idp_id', 'remote_id'] + mutable_attributes = frozenset(['idp_id', 'remote_id']) + + idp_id = sql.Column(sql.String(64), + sql.ForeignKey('identity_provider.id', + ondelete='CASCADE')) + remote_id = sql.Column(sql.String(255), + primary_key=True) + + @classmethod + def from_dict(cls, dictionary): + new_dictionary = dictionary.copy() + return cls(**new_dictionary) + + def to_dict(self): + """Return a dictionary with model's attributes.""" + d = dict() + for attr in self.__class__.attributes: + d[attr] = getattr(self, attr) + return d + + +class MappingModel(sql.ModelBase, sql.DictBase): + __tablename__ = 'mapping' + attributes = ['id', 'rules'] + + id = sql.Column(sql.String(64), primary_key=True) + rules = sql.Column(sql.JsonBlob(), nullable=False) + + @classmethod + def from_dict(cls, dictionary): + new_dictionary = dictionary.copy() + new_dictionary['rules'] = jsonutils.dumps(new_dictionary['rules']) + return cls(**new_dictionary) + + def to_dict(self): + """Return a dictionary with model's attributes.""" + d = dict() + for attr in self.__class__.attributes: + d[attr] = getattr(self, attr) + d['rules'] = jsonutils.loads(d['rules']) + return d + + +class ServiceProviderModel(sql.ModelBase, sql.DictBase): + __tablename__ = 'service_provider' + attributes = ['auth_url', 'id', 'enabled', 'description', + 'relay_state_prefix', 'sp_url'] + mutable_attributes = frozenset(['auth_url', 'description', 'enabled', + 'relay_state_prefix', 'sp_url']) + + id = sql.Column(sql.String(64), primary_key=True) + enabled = sql.Column(sql.Boolean, nullable=False) + description = sql.Column(sql.Text(), nullable=True) + auth_url = sql.Column(sql.String(256), nullable=False) + sp_url = sql.Column(sql.String(256), nullable=False) + relay_state_prefix = sql.Column(sql.String(256), nullable=False) + + @classmethod + def from_dict(cls, dictionary): + new_dictionary = dictionary.copy() + return cls(**new_dictionary) + + def to_dict(self): + """Return a dictionary with model's attributes.""" + d = dict() + for attr in self.__class__.attributes: + d[attr] = getattr(self, attr) + return d + + +class Federation(core.FederationDriverV8): + + # Identity Provider CRUD + @sql.handle_conflicts(conflict_type='identity_provider') + def create_idp(self, idp_id, idp): + idp['id'] = idp_id + with sql.transaction() as session: + idp_ref = IdentityProviderModel.from_dict(idp) + session.add(idp_ref) + return idp_ref.to_dict() + + def delete_idp(self, idp_id): + with sql.transaction() as session: + self._delete_assigned_protocols(session, idp_id) + idp_ref = self._get_idp(session, idp_id) + session.delete(idp_ref) + + def _get_idp(self, session, idp_id): + idp_ref = session.query(IdentityProviderModel).get(idp_id) + if not idp_ref: + raise exception.IdentityProviderNotFound(idp_id=idp_id) + return idp_ref + + def _get_idp_from_remote_id(self, session, remote_id): + q = session.query(IdPRemoteIdsModel) + q = q.filter_by(remote_id=remote_id) + try: + return q.one() + except sql.NotFound: + raise exception.IdentityProviderNotFound(idp_id=remote_id) + + def list_idps(self): + with sql.transaction() as session: + idps = session.query(IdentityProviderModel) + idps_list = [idp.to_dict() for idp in idps] + return idps_list + + def get_idp(self, idp_id): + with sql.transaction() as session: + idp_ref = self._get_idp(session, idp_id) + return idp_ref.to_dict() + + def get_idp_from_remote_id(self, remote_id): + with sql.transaction() as session: + ref = self._get_idp_from_remote_id(session, remote_id) + return ref.to_dict() + + def update_idp(self, idp_id, idp): + with sql.transaction() as session: + idp_ref = self._get_idp(session, idp_id) + old_idp = idp_ref.to_dict() + old_idp.update(idp) + new_idp = IdentityProviderModel.from_dict(old_idp) + for attr in IdentityProviderModel.mutable_attributes: + setattr(idp_ref, attr, getattr(new_idp, attr)) + return idp_ref.to_dict() + + # Protocol CRUD + def _get_protocol(self, session, idp_id, protocol_id): + q = session.query(FederationProtocolModel) + q = q.filter_by(id=protocol_id, idp_id=idp_id) + try: + return q.one() + except sql.NotFound: + kwargs = {'protocol_id': protocol_id, + 'idp_id': idp_id} + raise exception.FederatedProtocolNotFound(**kwargs) + + @sql.handle_conflicts(conflict_type='federation_protocol') + def create_protocol(self, idp_id, protocol_id, protocol): + protocol['id'] = protocol_id + protocol['idp_id'] = idp_id + with sql.transaction() as session: + self._get_idp(session, idp_id) + protocol_ref = FederationProtocolModel.from_dict(protocol) + session.add(protocol_ref) + return protocol_ref.to_dict() + + def update_protocol(self, idp_id, protocol_id, protocol): + with sql.transaction() as session: + proto_ref = self._get_protocol(session, idp_id, protocol_id) + old_proto = proto_ref.to_dict() + old_proto.update(protocol) + new_proto = FederationProtocolModel.from_dict(old_proto) + for attr in FederationProtocolModel.mutable_attributes: + setattr(proto_ref, attr, getattr(new_proto, attr)) + return proto_ref.to_dict() + + def get_protocol(self, idp_id, protocol_id): + with sql.transaction() as session: + protocol_ref = self._get_protocol(session, idp_id, protocol_id) + return protocol_ref.to_dict() + + def list_protocols(self, idp_id): + with sql.transaction() as session: + q = session.query(FederationProtocolModel) + q = q.filter_by(idp_id=idp_id) + protocols = [protocol.to_dict() for protocol in q] + return protocols + + def delete_protocol(self, idp_id, protocol_id): + with sql.transaction() as session: + key_ref = self._get_protocol(session, idp_id, protocol_id) + session.delete(key_ref) + + def _delete_assigned_protocols(self, session, idp_id): + query = session.query(FederationProtocolModel) + query = query.filter_by(idp_id=idp_id) + query.delete() + + # Mapping CRUD + def _get_mapping(self, session, mapping_id): + mapping_ref = session.query(MappingModel).get(mapping_id) + if not mapping_ref: + raise exception.MappingNotFound(mapping_id=mapping_id) + return mapping_ref + + @sql.handle_conflicts(conflict_type='mapping') + def create_mapping(self, mapping_id, mapping): + ref = {} + ref['id'] = mapping_id + ref['rules'] = mapping.get('rules') + with sql.transaction() as session: + mapping_ref = MappingModel.from_dict(ref) + session.add(mapping_ref) + return mapping_ref.to_dict() + + def delete_mapping(self, mapping_id): + with sql.transaction() as session: + mapping_ref = self._get_mapping(session, mapping_id) + session.delete(mapping_ref) + + def list_mappings(self): + with sql.transaction() as session: + mappings = session.query(MappingModel) + return [x.to_dict() for x in mappings] + + def get_mapping(self, mapping_id): + with sql.transaction() as session: + mapping_ref = self._get_mapping(session, mapping_id) + return mapping_ref.to_dict() + + @sql.handle_conflicts(conflict_type='mapping') + def update_mapping(self, mapping_id, mapping): + ref = {} + ref['id'] = mapping_id + ref['rules'] = mapping.get('rules') + with sql.transaction() as session: + mapping_ref = self._get_mapping(session, mapping_id) + old_mapping = mapping_ref.to_dict() + old_mapping.update(ref) + new_mapping = MappingModel.from_dict(old_mapping) + for attr in MappingModel.attributes: + setattr(mapping_ref, attr, getattr(new_mapping, attr)) + return mapping_ref.to_dict() + + def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): + with sql.transaction() as session: + protocol_ref = self._get_protocol(session, idp_id, protocol_id) + mapping_id = protocol_ref.mapping_id + mapping_ref = self._get_mapping(session, mapping_id) + return mapping_ref.to_dict() + + # Service Provider CRUD + @sql.handle_conflicts(conflict_type='service_provider') + def create_sp(self, sp_id, sp): + sp['id'] = sp_id + with sql.transaction() as session: + sp_ref = ServiceProviderModel.from_dict(sp) + session.add(sp_ref) + return sp_ref.to_dict() + + def delete_sp(self, sp_id): + with sql.transaction() as session: + sp_ref = self._get_sp(session, sp_id) + session.delete(sp_ref) + + def _get_sp(self, session, sp_id): + sp_ref = session.query(ServiceProviderModel).get(sp_id) + if not sp_ref: + raise exception.ServiceProviderNotFound(sp_id=sp_id) + return sp_ref + + def list_sps(self): + with sql.transaction() as session: + sps = session.query(ServiceProviderModel) + sps_list = [sp.to_dict() for sp in sps] + return sps_list + + def get_sp(self, sp_id): + with sql.transaction() as session: + sp_ref = self._get_sp(session, sp_id) + return sp_ref.to_dict() + + def update_sp(self, sp_id, sp): + with sql.transaction() as session: + sp_ref = self._get_sp(session, sp_id) + old_sp = sp_ref.to_dict() + old_sp.update(sp) + new_sp = ServiceProviderModel.from_dict(old_sp) + for attr in ServiceProviderModel.mutable_attributes: + setattr(sp_ref, attr, getattr(new_sp, attr)) + return sp_ref.to_dict() + + def get_enabled_service_providers(self): + with sql.transaction() as session: + service_providers = session.query(ServiceProviderModel) + service_providers = service_providers.filter_by(enabled=True) + return service_providers diff --git a/keystone/federation/backends/sql.py b/keystone/federation/backends/sql.py index 851a91128d..26fd51c1a6 100644 --- a/keystone/federation/backends/sql.py +++ b/keystone/federation/backends/sql.py @@ -155,7 +155,7 @@ class ServiceProviderModel(sql.ModelBase, sql.DictBase): return d -class Federation(core.FederationDriverV8): +class Federation(core.FederationDriverV9): # Identity Provider CRUD @sql.handle_conflicts(conflict_type='identity_provider') diff --git a/keystone/federation/core.py b/keystone/federation/core.py index 261cc00ba7..27223f9b59 100644 --- a/keystone/federation/core.py +++ b/keystone/federation/core.py @@ -15,6 +15,7 @@ import abc from oslo_config import cfg +from oslo_log import versionutils import six from keystone.common import dependency @@ -56,6 +57,14 @@ class Manager(manager.Manager): def __init__(self): super(Manager, self).__init__(CONF.federation.driver) + # Make sure it is a driver version we support, and if it is a legacy + # driver, then wrap it. + if isinstance(self.driver, FederationDriverV8): + self.driver = V9FederationWrapperForV8Driver(self.driver) + elif not isinstance(self.driver, FederationDriverV9): + raise exception.UnsupportedDriverVersion( + driver=CONF.federation.driver) + def get_enabled_service_providers(self): """List enabled service providers for Service Catalog @@ -90,8 +99,16 @@ class Manager(manager.Manager): return mapped_properties, mapping['id'] +# The FederationDriverBase class is the set of driver methods from earlier +# drivers that we still support, that have not been removed or modified. This +# class is then used to created the augmented V8 and V9 version abstract driver +# classes, without having to duplicate a lot of abstract method signatures. +# If you remove a method from V9, then move the abstact methods from this Base +# class to the V8 class. Do not modify any of the method signatures in the Base +# class - changes should only be made in the V8 and subsequent classes. + @six.add_metaclass(abc.ABCMeta) -class FederationDriverV8(object): +class FederationDriverBase(object): @abc.abstractmethod def create_idp(self, idp_id, idp): @@ -369,4 +386,129 @@ class FederationDriverV8(object): raise exception.NotImplemented() # pragma: no cover +class FederationDriverV8(FederationDriverBase): + """Removed or redefined methods from V8. + + Move the abstract methods of any methods removed or modified in later + versions of the driver from FederationDriverBase to here. We maintain this + so that legacy drivers, which will be a subclass of FederationDriverV8, can + still reference them. + + """ + + pass + + +class FederationDriverV9(FederationDriverBase): + """New or redefined methods from V8. + + Add any new V9 abstract methods (or those with modified signatures) to + this class. + + """ + + pass + + +class V9FederationWrapperForV8Driver(FederationDriverV9): + """Wrapper class to supported a V8 legacy driver. + + In order to support legacy drivers without having to make the manager code + driver-version aware, we wrap legacy drivers so that they look like the + latest version. For the various changes made in a new driver, here are the + actions needed in this wrapper: + + Method removed from new driver - remove the call-through method from this + class, since the manager will no longer be + calling it. + Method signature (or meaning) changed - wrap the old method in a new + signature here, and munge the input + and output parameters accordingly. + New method added to new driver - add a method to implement the new + functionality here if possible. If that is + not possible, then return NotImplemented, + since we do not guarantee to support new + functionality with legacy drivers. + + """ + + @versionutils.deprecated( + as_of=versionutils.deprecated.MITAKA, + what='keystone.federation.FederationDriverV8', + in_favor_of='keystone.federation.FederationDriverV9', + remove_in=+2) + def __init__(self, wrapped_driver): + self.driver = wrapped_driver + + def create_idp(self, idp_id, idp): + return self.driver.create_idp(idp_id, idp) + + def delete_idp(self, idp_id): + self.driver.delete_idp(idp_id) + + def list_idps(self): + return self.driver.list_idps() + + def get_idp(self, idp_id): + return self.driver.get_idp(idp_id) + + def get_idp_from_remote_id(self, remote_id): + return self.driver.get_idp_from_remote_id(remote_id) + + def update_idp(self, idp_id, idp): + return self.driver.update_idp(idp_id, idp) + + def create_protocol(self, idp_id, protocol_id, protocol): + return self.driver.create_protocol(idp_id, protocol_id, protocol) + + def update_protocol(self, idp_id, protocol_id, protocol): + return self.driver.update_protocol(idp_id, protocol_id, protocol) + + def get_protocol(self, idp_id, protocol_id): + return self.driver.get_protocol(idp_id, protocol_id) + + def list_protocols(self, idp_id): + return self.driver.list_protocols(idp_id) + + def delete_protocol(self, idp_id, protocol_id): + self.driver.delete_protocol(idp_id, protocol_id) + + def create_mapping(self, mapping_ref): + return self.driver.create_mapping(mapping_ref) + + def delete_mapping(self, mapping_id): + self.driver.delete_mapping(mapping_id) + + def update_mapping(self, mapping_id, mapping_ref): + return self.driver.update_mapping(mapping_id, mapping_ref) + + def list_mappings(self): + return self.driver.list_mappings() + + def get_mapping(self, mapping_id): + return self.driver.get_mapping(mapping_id) + + def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): + return self.driver.get_mapping_from_idp_and_protocol( + idp_id, protocol_id) + + def create_sp(self, sp_id, sp): + return self.driver.create_sp(sp_id, sp) + + def delete_sp(self, sp_id): + self.driver.delete_sp(sp_id) + + def list_sps(self): + return self.driver.list_sps() + + def get_sp(self, sp_id): + return self.driver.get_sp(sp_id) + + def update_sp(self, sp_id, sp): + return self.driver.update_sp(sp_id, sp) + + def get_enabled_service_providers(self): + return self.driver.get_enabled_service_providers() + + Driver = manager.create_legacy_driver(FederationDriverV8) diff --git a/keystone/tests/unit/backend/legacy_drivers/federation/V8/__init__.py b/keystone/tests/unit/backend/legacy_drivers/federation/V8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keystone/tests/unit/backend/legacy_drivers/federation/V8/api_v3.py b/keystone/tests/unit/backend/legacy_drivers/federation/V8/api_v3.py new file mode 100644 index 0000000000..48e5a71658 --- /dev/null +++ b/keystone/tests/unit/backend/legacy_drivers/federation/V8/api_v3.py @@ -0,0 +1,31 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from keystone.tests.unit import test_v3_federation + + +class FederatedIdentityProviderTestsV8( + test_v3_federation.FederatedIdentityProviderTests): + """Test that a V8 driver still passes the same tests. + + We use the SQL driver as an example of a V8 legacy driver. + + """ + + def config_overrides(self): + super(FederatedIdentityProviderTestsV8, self).config_overrides() + # V8 SQL specific driver overrides + self.config_fixture.config( + group='federation', + driver='keystone.federation.V8_backends.sql.Federation') + self.use_specific_sql_driver_version( + 'keystone.federation', 'backends', 'V8_') diff --git a/keystone/tests/unit/backend/legacy_drivers/federation/__init__.py b/keystone/tests/unit/backend/legacy_drivers/federation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keystone/tests/unit/core.py b/keystone/tests/unit/core.py index 9fa1682371..f90874a296 100644 --- a/keystone/tests/unit/core.py +++ b/keystone/tests/unit/core.py @@ -596,6 +596,7 @@ class TestCase(BaseTestCase): config, '_register_auth_plugin_opt', new=mocked_register_auth_plugin_opt)) + self.sql_driver_version_overrides = {} self.config_overrides() # NOTE(morganfainberg): ensure config_overrides has been called. self.addCleanup(self._assert_config_overrides_called) @@ -841,7 +842,6 @@ class SQLDriverOverrides(object): self.config_fixture.config(group='revoke', driver='sql') self.config_fixture.config(group='token', driver='sql') self.config_fixture.config(group='trust', driver='sql') - self.sql_driver_version_overrides = {} def use_specific_sql_driver_version(self, driver_path, versionless_backend, version_suffix): diff --git a/keystone/tests/unit/ksfixtures/database.py b/keystone/tests/unit/ksfixtures/database.py index 9a68262bbd..5aec38c2ca 100644 --- a/keystone/tests/unit/ksfixtures/database.py +++ b/keystone/tests/unit/ksfixtures/database.py @@ -109,6 +109,13 @@ def _load_sqlalchemy_models(version_specifiers): # At this point module_without_backends might be something like # 'keystone.assignment', while this_backend might be something # 'V8_backends'. + + if module_without_backends.startswith('keystone.contrib'): + # All the sql modules have now been moved into the core tree + # so no point in loading these again here (and, in fact, doing + # so might break trying to load a versioned driver. + continue + if module_without_backends in version_specifiers: # OK, so there is a request for a specific version of this one. # We therefore should skip any other versioned backend as well diff --git a/keystone/tests/unit/rest.py b/keystone/tests/unit/rest.py index 63000eb543..33362a49e0 100644 --- a/keystone/tests/unit/rest.py +++ b/keystone/tests/unit/rest.py @@ -61,7 +61,7 @@ class RestfulTestCase(unit.TestCase): # Will need to reset the plug-ins self.addCleanup(setattr, auth_controllers, 'AUTH_METHODS', {}) - self.useFixture(database.Database()) + self.useFixture(database.Database(self.sql_driver_version_overrides)) self.load_backends() self.load_fixtures(default_fixtures) diff --git a/releasenotes/notes/v9FederationDriver-cbebcf5f97e1eae2.yaml b/releasenotes/notes/v9FederationDriver-cbebcf5f97e1eae2.yaml new file mode 100644 index 0000000000..24a91178d0 --- /dev/null +++ b/releasenotes/notes/v9FederationDriver-cbebcf5f97e1eae2.yaml @@ -0,0 +1,7 @@ +--- +upgrade: + - The V8 Federation driver interface is deprecated, but still supported in + Mitaka, so any custom drivers based on the V8 interface should still work. +other: + - Support for the V8 Federation driver interface is planned to be removed in + the 'O' release of OpenStack. diff --git a/tox.ini b/tox.ini index af827cd9d0..151267534f 100644 --- a/tox.ini +++ b/tox.ini @@ -79,6 +79,8 @@ commands = keystone/tests/unit/backend/legacy_drivers/assignment/V8/sql.py nosetests -v \ keystone/tests/unit/backend/legacy_drivers/role/V8/sql.py + nosetests -v \ + keystone/tests/unit/backend/legacy_drivers/federation/V8/api_v3.py [testenv:pep8] commands =