From fb3bc6c73fa471aa77c5b8e77b239f81e085c825 Mon Sep 17 00:00:00 2001 From: Ronald De Rose Date: Mon, 14 Mar 2016 21:12:19 +0000 Subject: [PATCH] Move identity.backends.sql model code to sql_model.py This patch moves identity.backends.sql model code to sql_model.py so that they can be shared with the shadow users sql backend. In addition, I think it is a cleaner implementation to separate the model from the driver. Change-Id: I9bc44e0b9837bf5800dddd94bb36804c44581240 --- keystone/identity/backends/sql.py | 219 ++++------------------- keystone/identity/backends/sql_model.py | 173 ++++++++++++++++++ keystone/identity/shadow_backends/sql.py | 2 +- keystone/tests/unit/test_backend_sql.py | 2 +- 4 files changed, 208 insertions(+), 188 deletions(-) create mode 100644 keystone/identity/backends/sql_model.py diff --git a/keystone/identity/backends/sql.py b/keystone/identity/backends/sql.py index 1a775a90b5..f7fad29beb 100644 --- a/keystone/identity/backends/sql.py +++ b/keystone/identity/backends/sql.py @@ -13,8 +13,6 @@ # under the License. import sqlalchemy -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy import orm from keystone.common import driver_hints from keystone.common import sql @@ -22,160 +20,7 @@ from keystone.common import utils from keystone import exception from keystone.i18n import _ from keystone.identity.backends import base - - -class User(sql.ModelBase, sql.DictBase): - __tablename__ = 'user' - attributes = ['id', 'name', 'domain_id', 'password', 'enabled', - 'default_project_id'] - id = sql.Column(sql.String(64), primary_key=True) - enabled = sql.Column(sql.Boolean) - extra = sql.Column(sql.JsonBlob()) - default_project_id = sql.Column(sql.String(64)) - local_user = orm.relationship('LocalUser', uselist=False, - single_parent=True, lazy='subquery', - cascade='all,delete-orphan', backref='user') - federated_users = orm.relationship('FederatedUser', - single_parent=True, - lazy='subquery', - cascade='all,delete-orphan', - backref='user') - - # name property - @hybrid_property - def name(self): - if self.local_user: - return self.local_user.name - elif self.federated_users: - return self.federated_users[0].display_name - else: - return None - - @name.setter - def name(self, value): - if not self.local_user: - self.local_user = LocalUser() - self.local_user.name = value - - @name.expression - def name(cls): - return LocalUser.name - - # password property - @hybrid_property - def password(self): - if self.local_user and self.local_user.passwords: - return self.local_user.passwords[0].password - else: - return None - - @password.setter - def password(self, value): - if not value: - if self.local_user and self.local_user.passwords: - self.local_user.passwords = [] - else: - if not self.local_user: - self.local_user = LocalUser() - if not self.local_user.passwords: - self.local_user.passwords.append(Password()) - self.local_user.passwords[0].password = value - - @password.expression - def password(cls): - return Password.password - - # domain_id property - @hybrid_property - def domain_id(self): - if self.local_user: - return self.local_user.domain_id - else: - return None - - @domain_id.setter - def domain_id(self, value): - if not self.local_user: - self.local_user = LocalUser() - self.local_user.domain_id = value - - @domain_id.expression - def domain_id(cls): - return LocalUser.domain_id - - def to_dict(self, include_extra_dict=False): - d = super(User, self).to_dict(include_extra_dict=include_extra_dict) - if 'default_project_id' in d and d['default_project_id'] is None: - del d['default_project_id'] - return d - - -class LocalUser(sql.ModelBase, sql.DictBase): - __tablename__ = 'local_user' - attributes = ['id', 'user_id', 'domain_id', 'name'] - id = sql.Column(sql.Integer, primary_key=True) - user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', - ondelete='CASCADE'), unique=True) - domain_id = sql.Column(sql.String(64), nullable=False) - name = sql.Column(sql.String(255), nullable=False) - passwords = orm.relationship('Password', single_parent=True, - cascade='all,delete-orphan', - backref='local_user') - __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) - - -class Password(sql.ModelBase, sql.DictBase): - __tablename__ = 'password' - attributes = ['id', 'local_user_id', 'password'] - id = sql.Column(sql.Integer, primary_key=True) - local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id', - ondelete='CASCADE')) - password = sql.Column(sql.String(128)) - - -class FederatedUser(sql.ModelBase, sql.ModelDictMixin): - __tablename__ = 'federated_user' - attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id', - 'display_name'] - id = sql.Column(sql.Integer, primary_key=True) - user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', - ondelete='CASCADE')) - idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id', - ondelete='CASCADE')) - protocol_id = sql.Column(sql.String(64), nullable=False) - unique_id = sql.Column(sql.String(255), nullable=False) - display_name = sql.Column(sql.String(255), nullable=True) - __table_args__ = ( - sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'), - sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'], - ['federation_protocol.id', - 'federation_protocol.idp_id']) - ) - - -class Group(sql.ModelBase, sql.DictBase): - __tablename__ = 'group' - attributes = ['id', 'name', 'domain_id', 'description'] - id = sql.Column(sql.String(64), primary_key=True) - name = sql.Column(sql.String(64), nullable=False) - domain_id = sql.Column(sql.String(64), nullable=False) - description = sql.Column(sql.Text()) - extra = sql.Column(sql.JsonBlob()) - # Unique constraint across two columns to create the separation - # rather than just only 'name' being unique - __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),) - - -class UserGroupMembership(sql.ModelBase, sql.DictBase): - """Group membership join table.""" - - __tablename__ = 'user_group_membership' - user_id = sql.Column(sql.String(64), - sql.ForeignKey('user.id'), - primary_key=True) - group_id = sql.Column(sql.String(64), - sql.ForeignKey('group.id'), - primary_key=True) +from keystone.identity.backends import sql_model as model class Identity(base.IdentityDriverV8): @@ -218,19 +63,19 @@ class Identity(base.IdentityDriverV8): def create_user(self, user_id, user): user = utils.hash_user_password(user) with sql.session_for_write() as session: - user_ref = User.from_dict(user) + user_ref = model.User.from_dict(user) session.add(user_ref) return base.filter_user(user_ref.to_dict()) @driver_hints.truncated def list_users(self, hints): with sql.session_for_read() as session: - query = session.query(User).outerjoin(LocalUser) - user_refs = sql.filter_limit_query(User, query, hints) + query = session.query(model.User).outerjoin(model.LocalUser) + user_refs = sql.filter_limit_query(model.User, query, hints) return [base.filter_user(x.to_dict()) for x in user_refs] def _get_user(self, session, user_id): - user_ref = session.query(User).get(user_id) + user_ref = session.query(model.User).get(user_id) if not user_ref: raise exception.UserNotFound(user_id=user_id) return user_ref @@ -242,9 +87,10 @@ class Identity(base.IdentityDriverV8): def get_user_by_name(self, user_name, domain_id): with sql.session_for_read() as session: - query = session.query(User).join(LocalUser) - query = query.filter(sqlalchemy.and_(LocalUser.name == user_name, - LocalUser.domain_id == domain_id)) + query = session.query(model.User).join(model.LocalUser) + query = query.filter(sqlalchemy.and_( + model.LocalUser.name == user_name, + model.LocalUser.domain_id == domain_id)) try: user_ref = query.one() except sql.NotFound: @@ -259,8 +105,8 @@ class Identity(base.IdentityDriverV8): user = utils.hash_user_password(user) for k in user: old_user_dict[k] = user[k] - new_user = User.from_dict(old_user_dict) - for attr in User.attributes: + new_user = model.User.from_dict(old_user_dict) + for attr in model.User.attributes: if attr != 'id': setattr(user_ref, attr, getattr(new_user, attr)) user_ref.extra = new_user.extra @@ -271,21 +117,21 @@ class Identity(base.IdentityDriverV8): with sql.session_for_write() as session: self.get_group(group_id) self.get_user(user_id) - query = session.query(UserGroupMembership) + query = session.query(model.UserGroupMembership) query = query.filter_by(user_id=user_id) query = query.filter_by(group_id=group_id) rv = query.first() if rv: return - session.add(UserGroupMembership(user_id=user_id, - group_id=group_id)) + session.add(model.UserGroupMembership(user_id=user_id, + group_id=group_id)) def check_user_in_group(self, user_id, group_id): with sql.session_for_read() as session: self.get_group(group_id) self.get_user(user_id) - query = session.query(UserGroupMembership) + query = session.query(model.UserGroupMembership) query = query.filter_by(user_id=user_id) query = query.filter_by(group_id=group_id) if not query.first(): @@ -298,7 +144,7 @@ class Identity(base.IdentityDriverV8): # We don't check if user or group are still valid and let the remove # be tried anyway - in case this is some kind of clean-up operation with sql.session_for_write() as session: - query = session.query(UserGroupMembership) + query = session.query(model.UserGroupMembership) query = query.filter_by(user_id=user_id) query = query.filter_by(group_id=group_id) membership_ref = query.first() @@ -316,25 +162,26 @@ class Identity(base.IdentityDriverV8): def list_groups_for_user(self, user_id, hints): with sql.session_for_read() as session: self.get_user(user_id) - query = session.query(Group).join(UserGroupMembership) - query = query.filter(UserGroupMembership.user_id == user_id) - query = sql.filter_limit_query(Group, query, hints) + query = session.query(model.Group).join(model.UserGroupMembership) + query = query.filter(model.UserGroupMembership.user_id == user_id) + query = sql.filter_limit_query(model.Group, query, hints) return [g.to_dict() for g in query] def list_users_in_group(self, group_id, hints): with sql.session_for_read() as session: self.get_group(group_id) - query = session.query(User).outerjoin(LocalUser) - query = query.join(UserGroupMembership) - query = query.filter(UserGroupMembership.group_id == group_id) - query = sql.filter_limit_query(User, query, hints) + query = session.query(model.User).outerjoin(model.LocalUser) + query = query.join(model.UserGroupMembership) + query = query.filter( + model.UserGroupMembership.group_id == group_id) + query = sql.filter_limit_query(model.User, query, hints) return [base.filter_user(u.to_dict()) for u in query] def delete_user(self, user_id): with sql.session_for_write() as session: ref = self._get_user(session, user_id) - q = session.query(UserGroupMembership) + q = session.query(model.UserGroupMembership) q = q.filter_by(user_id=user_id) q.delete(False) @@ -345,19 +192,19 @@ class Identity(base.IdentityDriverV8): @sql.handle_conflicts(conflict_type='group') def create_group(self, group_id, group): with sql.session_for_write() as session: - ref = Group.from_dict(group) + ref = model.Group.from_dict(group) session.add(ref) return ref.to_dict() @driver_hints.truncated def list_groups(self, hints): with sql.session_for_read() as session: - query = session.query(Group) - refs = sql.filter_limit_query(Group, query, hints) + query = session.query(model.Group) + refs = sql.filter_limit_query(model.Group, query, hints) return [ref.to_dict() for ref in refs] def _get_group(self, session, group_id): - ref = session.query(Group).get(group_id) + ref = session.query(model.Group).get(group_id) if not ref: raise exception.GroupNotFound(group_id=group_id) return ref @@ -368,7 +215,7 @@ class Identity(base.IdentityDriverV8): def get_group_by_name(self, group_name, domain_id): with sql.session_for_read() as session: - query = session.query(Group) + query = session.query(model.Group) query = query.filter_by(name=group_name) query = query.filter_by(domain_id=domain_id) try: @@ -384,8 +231,8 @@ class Identity(base.IdentityDriverV8): old_dict = ref.to_dict() for k in group: old_dict[k] = group[k] - new_group = Group.from_dict(old_dict) - for attr in Group.attributes: + new_group = model.Group.from_dict(old_dict) + for attr in model.Group.attributes: if attr != 'id': setattr(ref, attr, getattr(new_group, attr)) ref.extra = new_group.extra @@ -395,7 +242,7 @@ class Identity(base.IdentityDriverV8): with sql.session_for_write() as session: ref = self._get_group(session, group_id) - q = session.query(UserGroupMembership) + q = session.query(model.UserGroupMembership) q = q.filter_by(group_id=group_id) q.delete(False) diff --git a/keystone/identity/backends/sql_model.py b/keystone/identity/backends/sql_model.py new file mode 100644 index 0000000000..4b4263dfca --- /dev/null +++ b/keystone/identity/backends/sql_model.py @@ -0,0 +1,173 @@ +# Copyright 2012 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sqlalchemy +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy import orm + +from keystone.common import sql + + +class User(sql.ModelBase, sql.DictBase): + __tablename__ = 'user' + attributes = ['id', 'name', 'domain_id', 'password', 'enabled', + 'default_project_id'] + id = sql.Column(sql.String(64), primary_key=True) + enabled = sql.Column(sql.Boolean) + extra = sql.Column(sql.JsonBlob()) + default_project_id = sql.Column(sql.String(64)) + local_user = orm.relationship('LocalUser', uselist=False, + single_parent=True, lazy='subquery', + cascade='all,delete-orphan', backref='user') + federated_users = orm.relationship('FederatedUser', + single_parent=True, + lazy='subquery', + cascade='all,delete-orphan', + backref='user') + + # name property + @hybrid_property + def name(self): + if self.local_user: + return self.local_user.name + elif self.federated_users: + return self.federated_users[0].display_name + else: + return None + + @name.setter + def name(self, value): + if not self.local_user: + self.local_user = LocalUser() + self.local_user.name = value + + @name.expression + def name(cls): + return LocalUser.name + + # password property + @hybrid_property + def password(self): + if self.local_user and self.local_user.passwords: + return self.local_user.passwords[0].password + else: + return None + + @password.setter + def password(self, value): + if not value: + if self.local_user and self.local_user.passwords: + self.local_user.passwords = [] + else: + if not self.local_user: + self.local_user = LocalUser() + if not self.local_user.passwords: + self.local_user.passwords.append(Password()) + self.local_user.passwords[0].password = value + + @password.expression + def password(cls): + return Password.password + + # domain_id property + @hybrid_property + def domain_id(self): + if self.local_user: + return self.local_user.domain_id + else: + return None + + @domain_id.setter + def domain_id(self, value): + if not self.local_user: + self.local_user = LocalUser() + self.local_user.domain_id = value + + @domain_id.expression + def domain_id(cls): + return LocalUser.domain_id + + def to_dict(self, include_extra_dict=False): + d = super(User, self).to_dict(include_extra_dict=include_extra_dict) + if 'default_project_id' in d and d['default_project_id'] is None: + del d['default_project_id'] + return d + + +class LocalUser(sql.ModelBase, sql.DictBase): + __tablename__ = 'local_user' + attributes = ['id', 'user_id', 'domain_id', 'name'] + id = sql.Column(sql.Integer, primary_key=True) + user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', + ondelete='CASCADE'), unique=True) + domain_id = sql.Column(sql.String(64), nullable=False) + name = sql.Column(sql.String(255), nullable=False) + passwords = orm.relationship('Password', single_parent=True, + cascade='all,delete-orphan', + backref='local_user') + __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + + +class Password(sql.ModelBase, sql.DictBase): + __tablename__ = 'password' + attributes = ['id', 'local_user_id', 'password'] + id = sql.Column(sql.Integer, primary_key=True) + local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id', + ondelete='CASCADE')) + password = sql.Column(sql.String(128)) + + +class FederatedUser(sql.ModelBase, sql.ModelDictMixin): + __tablename__ = 'federated_user' + attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id', + 'display_name'] + id = sql.Column(sql.Integer, primary_key=True) + user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', + ondelete='CASCADE')) + idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id', + ondelete='CASCADE')) + protocol_id = sql.Column(sql.String(64), nullable=False) + unique_id = sql.Column(sql.String(255), nullable=False) + display_name = sql.Column(sql.String(255), nullable=True) + __table_args__ = ( + sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'), + sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'], + ['federation_protocol.id', + 'federation_protocol.idp_id']) + ) + + +class Group(sql.ModelBase, sql.DictBase): + __tablename__ = 'group' + attributes = ['id', 'name', 'domain_id', 'description'] + id = sql.Column(sql.String(64), primary_key=True) + name = sql.Column(sql.String(64), nullable=False) + domain_id = sql.Column(sql.String(64), nullable=False) + description = sql.Column(sql.Text()) + extra = sql.Column(sql.JsonBlob()) + # Unique constraint across two columns to create the separation + # rather than just only 'name' being unique + __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),) + + +class UserGroupMembership(sql.ModelBase, sql.DictBase): + """Group membership join table.""" + + __tablename__ = 'user_group_membership' + user_id = sql.Column(sql.String(64), + sql.ForeignKey('user.id'), + primary_key=True) + group_id = sql.Column(sql.String(64), + sql.ForeignKey('group.id'), + primary_key=True) diff --git a/keystone/identity/shadow_backends/sql.py b/keystone/identity/shadow_backends/sql.py index 2f92ce7721..461b01ab2c 100644 --- a/keystone/identity/shadow_backends/sql.py +++ b/keystone/identity/shadow_backends/sql.py @@ -15,7 +15,7 @@ import uuid from keystone.common import sql from keystone import exception from keystone.identity.backends import base as identity_base -from keystone.identity.backends import sql as model +from keystone.identity.backends import sql_model as model from keystone.identity.shadow_backends import base diff --git a/keystone/tests/unit/test_backend_sql.py b/keystone/tests/unit/test_backend_sql.py index 466969c910..295abeb290 100644 --- a/keystone/tests/unit/test_backend_sql.py +++ b/keystone/tests/unit/test_backend_sql.py @@ -28,7 +28,7 @@ from testtools import matchers from keystone.common import driver_hints from keystone.common import sql from keystone import exception -from keystone.identity.backends import sql as identity_sql +from keystone.identity.backends import sql_model as identity_sql from keystone.resource.backends import base as resource from keystone.tests import unit from keystone.tests.unit.assignment import test_backends as assignment_tests