Merge "Move identity.backends.sql model code to sql_model.py"
This commit is contained in:
commit
a991d9edfe
|
@ -13,8 +13,6 @@
|
|||
# under the License.
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy import orm
|
||||
|
||||
from keystone.common import driver_hints
|
||||
from keystone.common import sql
|
||||
|
@ -22,160 +20,7 @@ from keystone.common import utils
|
|||
from keystone import exception
|
||||
from keystone.i18n import _
|
||||
from keystone.identity.backends import base
|
||||
|
||||
|
||||
class User(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'user'
|
||||
attributes = ['id', 'name', 'domain_id', 'password', 'enabled',
|
||||
'default_project_id']
|
||||
id = sql.Column(sql.String(64), primary_key=True)
|
||||
enabled = sql.Column(sql.Boolean)
|
||||
extra = sql.Column(sql.JsonBlob())
|
||||
default_project_id = sql.Column(sql.String(64))
|
||||
local_user = orm.relationship('LocalUser', uselist=False,
|
||||
single_parent=True, lazy='subquery',
|
||||
cascade='all,delete-orphan', backref='user')
|
||||
federated_users = orm.relationship('FederatedUser',
|
||||
single_parent=True,
|
||||
lazy='subquery',
|
||||
cascade='all,delete-orphan',
|
||||
backref='user')
|
||||
|
||||
# name property
|
||||
@hybrid_property
|
||||
def name(self):
|
||||
if self.local_user:
|
||||
return self.local_user.name
|
||||
elif self.federated_users:
|
||||
return self.federated_users[0].display_name
|
||||
else:
|
||||
return None
|
||||
|
||||
@name.setter
|
||||
def name(self, value):
|
||||
if not self.local_user:
|
||||
self.local_user = LocalUser()
|
||||
self.local_user.name = value
|
||||
|
||||
@name.expression
|
||||
def name(cls):
|
||||
return LocalUser.name
|
||||
|
||||
# password property
|
||||
@hybrid_property
|
||||
def password(self):
|
||||
if self.local_user and self.local_user.passwords:
|
||||
return self.local_user.passwords[0].password
|
||||
else:
|
||||
return None
|
||||
|
||||
@password.setter
|
||||
def password(self, value):
|
||||
if not value:
|
||||
if self.local_user and self.local_user.passwords:
|
||||
self.local_user.passwords = []
|
||||
else:
|
||||
if not self.local_user:
|
||||
self.local_user = LocalUser()
|
||||
if not self.local_user.passwords:
|
||||
self.local_user.passwords.append(Password())
|
||||
self.local_user.passwords[0].password = value
|
||||
|
||||
@password.expression
|
||||
def password(cls):
|
||||
return Password.password
|
||||
|
||||
# domain_id property
|
||||
@hybrid_property
|
||||
def domain_id(self):
|
||||
if self.local_user:
|
||||
return self.local_user.domain_id
|
||||
else:
|
||||
return None
|
||||
|
||||
@domain_id.setter
|
||||
def domain_id(self, value):
|
||||
if not self.local_user:
|
||||
self.local_user = LocalUser()
|
||||
self.local_user.domain_id = value
|
||||
|
||||
@domain_id.expression
|
||||
def domain_id(cls):
|
||||
return LocalUser.domain_id
|
||||
|
||||
def to_dict(self, include_extra_dict=False):
|
||||
d = super(User, self).to_dict(include_extra_dict=include_extra_dict)
|
||||
if 'default_project_id' in d and d['default_project_id'] is None:
|
||||
del d['default_project_id']
|
||||
return d
|
||||
|
||||
|
||||
class LocalUser(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'local_user'
|
||||
attributes = ['id', 'user_id', 'domain_id', 'name']
|
||||
id = sql.Column(sql.Integer, primary_key=True)
|
||||
user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
|
||||
ondelete='CASCADE'), unique=True)
|
||||
domain_id = sql.Column(sql.String(64), nullable=False)
|
||||
name = sql.Column(sql.String(255), nullable=False)
|
||||
passwords = orm.relationship('Password', single_parent=True,
|
||||
cascade='all,delete-orphan',
|
||||
backref='local_user')
|
||||
__table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
|
||||
|
||||
|
||||
class Password(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'password'
|
||||
attributes = ['id', 'local_user_id', 'password']
|
||||
id = sql.Column(sql.Integer, primary_key=True)
|
||||
local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id',
|
||||
ondelete='CASCADE'))
|
||||
password = sql.Column(sql.String(128))
|
||||
|
||||
|
||||
class FederatedUser(sql.ModelBase, sql.ModelDictMixin):
|
||||
__tablename__ = 'federated_user'
|
||||
attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id',
|
||||
'display_name']
|
||||
id = sql.Column(sql.Integer, primary_key=True)
|
||||
user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
|
||||
ondelete='CASCADE'))
|
||||
idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
|
||||
ondelete='CASCADE'))
|
||||
protocol_id = sql.Column(sql.String(64), nullable=False)
|
||||
unique_id = sql.Column(sql.String(255), nullable=False)
|
||||
display_name = sql.Column(sql.String(255), nullable=True)
|
||||
__table_args__ = (
|
||||
sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'),
|
||||
sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'],
|
||||
['federation_protocol.id',
|
||||
'federation_protocol.idp_id'])
|
||||
)
|
||||
|
||||
|
||||
class Group(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'group'
|
||||
attributes = ['id', 'name', 'domain_id', 'description']
|
||||
id = sql.Column(sql.String(64), primary_key=True)
|
||||
name = sql.Column(sql.String(64), nullable=False)
|
||||
domain_id = sql.Column(sql.String(64), nullable=False)
|
||||
description = sql.Column(sql.Text())
|
||||
extra = sql.Column(sql.JsonBlob())
|
||||
# Unique constraint across two columns to create the separation
|
||||
# rather than just only 'name' being unique
|
||||
__table_args__ = (sql.UniqueConstraint('domain_id', 'name'),)
|
||||
|
||||
|
||||
class UserGroupMembership(sql.ModelBase, sql.DictBase):
|
||||
"""Group membership join table."""
|
||||
|
||||
__tablename__ = 'user_group_membership'
|
||||
user_id = sql.Column(sql.String(64),
|
||||
sql.ForeignKey('user.id'),
|
||||
primary_key=True)
|
||||
group_id = sql.Column(sql.String(64),
|
||||
sql.ForeignKey('group.id'),
|
||||
primary_key=True)
|
||||
from keystone.identity.backends import sql_model as model
|
||||
|
||||
|
||||
class Identity(base.IdentityDriverV8):
|
||||
|
@ -218,19 +63,19 @@ class Identity(base.IdentityDriverV8):
|
|||
def create_user(self, user_id, user):
|
||||
user = utils.hash_user_password(user)
|
||||
with sql.session_for_write() as session:
|
||||
user_ref = User.from_dict(user)
|
||||
user_ref = model.User.from_dict(user)
|
||||
session.add(user_ref)
|
||||
return base.filter_user(user_ref.to_dict())
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_users(self, hints):
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(User).outerjoin(LocalUser)
|
||||
user_refs = sql.filter_limit_query(User, query, hints)
|
||||
query = session.query(model.User).outerjoin(model.LocalUser)
|
||||
user_refs = sql.filter_limit_query(model.User, query, hints)
|
||||
return [base.filter_user(x.to_dict()) for x in user_refs]
|
||||
|
||||
def _get_user(self, session, user_id):
|
||||
user_ref = session.query(User).get(user_id)
|
||||
user_ref = session.query(model.User).get(user_id)
|
||||
if not user_ref:
|
||||
raise exception.UserNotFound(user_id=user_id)
|
||||
return user_ref
|
||||
|
@ -242,9 +87,10 @@ class Identity(base.IdentityDriverV8):
|
|||
|
||||
def get_user_by_name(self, user_name, domain_id):
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(User).join(LocalUser)
|
||||
query = query.filter(sqlalchemy.and_(LocalUser.name == user_name,
|
||||
LocalUser.domain_id == domain_id))
|
||||
query = session.query(model.User).join(model.LocalUser)
|
||||
query = query.filter(sqlalchemy.and_(
|
||||
model.LocalUser.name == user_name,
|
||||
model.LocalUser.domain_id == domain_id))
|
||||
try:
|
||||
user_ref = query.one()
|
||||
except sql.NotFound:
|
||||
|
@ -259,8 +105,8 @@ class Identity(base.IdentityDriverV8):
|
|||
user = utils.hash_user_password(user)
|
||||
for k in user:
|
||||
old_user_dict[k] = user[k]
|
||||
new_user = User.from_dict(old_user_dict)
|
||||
for attr in User.attributes:
|
||||
new_user = model.User.from_dict(old_user_dict)
|
||||
for attr in model.User.attributes:
|
||||
if attr != 'id':
|
||||
setattr(user_ref, attr, getattr(new_user, attr))
|
||||
user_ref.extra = new_user.extra
|
||||
|
@ -271,21 +117,21 @@ class Identity(base.IdentityDriverV8):
|
|||
with sql.session_for_write() as session:
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
query = session.query(UserGroupMembership)
|
||||
query = session.query(model.UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
rv = query.first()
|
||||
if rv:
|
||||
return
|
||||
|
||||
session.add(UserGroupMembership(user_id=user_id,
|
||||
group_id=group_id))
|
||||
session.add(model.UserGroupMembership(user_id=user_id,
|
||||
group_id=group_id))
|
||||
|
||||
def check_user_in_group(self, user_id, group_id):
|
||||
with sql.session_for_read() as session:
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
query = session.query(UserGroupMembership)
|
||||
query = session.query(model.UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
if not query.first():
|
||||
|
@ -298,7 +144,7 @@ class Identity(base.IdentityDriverV8):
|
|||
# We don't check if user or group are still valid and let the remove
|
||||
# be tried anyway - in case this is some kind of clean-up operation
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(UserGroupMembership)
|
||||
query = session.query(model.UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
membership_ref = query.first()
|
||||
|
@ -316,25 +162,26 @@ class Identity(base.IdentityDriverV8):
|
|||
def list_groups_for_user(self, user_id, hints):
|
||||
with sql.session_for_read() as session:
|
||||
self.get_user(user_id)
|
||||
query = session.query(Group).join(UserGroupMembership)
|
||||
query = query.filter(UserGroupMembership.user_id == user_id)
|
||||
query = sql.filter_limit_query(Group, query, hints)
|
||||
query = session.query(model.Group).join(model.UserGroupMembership)
|
||||
query = query.filter(model.UserGroupMembership.user_id == user_id)
|
||||
query = sql.filter_limit_query(model.Group, query, hints)
|
||||
return [g.to_dict() for g in query]
|
||||
|
||||
def list_users_in_group(self, group_id, hints):
|
||||
with sql.session_for_read() as session:
|
||||
self.get_group(group_id)
|
||||
query = session.query(User).outerjoin(LocalUser)
|
||||
query = query.join(UserGroupMembership)
|
||||
query = query.filter(UserGroupMembership.group_id == group_id)
|
||||
query = sql.filter_limit_query(User, query, hints)
|
||||
query = session.query(model.User).outerjoin(model.LocalUser)
|
||||
query = query.join(model.UserGroupMembership)
|
||||
query = query.filter(
|
||||
model.UserGroupMembership.group_id == group_id)
|
||||
query = sql.filter_limit_query(model.User, query, hints)
|
||||
return [base.filter_user(u.to_dict()) for u in query]
|
||||
|
||||
def delete_user(self, user_id):
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_user(session, user_id)
|
||||
|
||||
q = session.query(UserGroupMembership)
|
||||
q = session.query(model.UserGroupMembership)
|
||||
q = q.filter_by(user_id=user_id)
|
||||
q.delete(False)
|
||||
|
||||
|
@ -345,19 +192,19 @@ class Identity(base.IdentityDriverV8):
|
|||
@sql.handle_conflicts(conflict_type='group')
|
||||
def create_group(self, group_id, group):
|
||||
with sql.session_for_write() as session:
|
||||
ref = Group.from_dict(group)
|
||||
ref = model.Group.from_dict(group)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_groups(self, hints):
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Group)
|
||||
refs = sql.filter_limit_query(Group, query, hints)
|
||||
query = session.query(model.Group)
|
||||
refs = sql.filter_limit_query(model.Group, query, hints)
|
||||
return [ref.to_dict() for ref in refs]
|
||||
|
||||
def _get_group(self, session, group_id):
|
||||
ref = session.query(Group).get(group_id)
|
||||
ref = session.query(model.Group).get(group_id)
|
||||
if not ref:
|
||||
raise exception.GroupNotFound(group_id=group_id)
|
||||
return ref
|
||||
|
@ -368,7 +215,7 @@ class Identity(base.IdentityDriverV8):
|
|||
|
||||
def get_group_by_name(self, group_name, domain_id):
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Group)
|
||||
query = session.query(model.Group)
|
||||
query = query.filter_by(name=group_name)
|
||||
query = query.filter_by(domain_id=domain_id)
|
||||
try:
|
||||
|
@ -384,8 +231,8 @@ class Identity(base.IdentityDriverV8):
|
|||
old_dict = ref.to_dict()
|
||||
for k in group:
|
||||
old_dict[k] = group[k]
|
||||
new_group = Group.from_dict(old_dict)
|
||||
for attr in Group.attributes:
|
||||
new_group = model.Group.from_dict(old_dict)
|
||||
for attr in model.Group.attributes:
|
||||
if attr != 'id':
|
||||
setattr(ref, attr, getattr(new_group, attr))
|
||||
ref.extra = new_group.extra
|
||||
|
@ -395,7 +242,7 @@ class Identity(base.IdentityDriverV8):
|
|||
with sql.session_for_write() as session:
|
||||
ref = self._get_group(session, group_id)
|
||||
|
||||
q = session.query(UserGroupMembership)
|
||||
q = session.query(model.UserGroupMembership)
|
||||
q = q.filter_by(group_id=group_id)
|
||||
q.delete(False)
|
||||
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
# Copyright 2012 OpenStack Foundation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy import orm
|
||||
|
||||
from keystone.common import sql
|
||||
|
||||
|
||||
class User(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'user'
|
||||
attributes = ['id', 'name', 'domain_id', 'password', 'enabled',
|
||||
'default_project_id']
|
||||
id = sql.Column(sql.String(64), primary_key=True)
|
||||
enabled = sql.Column(sql.Boolean)
|
||||
extra = sql.Column(sql.JsonBlob())
|
||||
default_project_id = sql.Column(sql.String(64))
|
||||
local_user = orm.relationship('LocalUser', uselist=False,
|
||||
single_parent=True, lazy='subquery',
|
||||
cascade='all,delete-orphan', backref='user')
|
||||
federated_users = orm.relationship('FederatedUser',
|
||||
single_parent=True,
|
||||
lazy='subquery',
|
||||
cascade='all,delete-orphan',
|
||||
backref='user')
|
||||
|
||||
# name property
|
||||
@hybrid_property
|
||||
def name(self):
|
||||
if self.local_user:
|
||||
return self.local_user.name
|
||||
elif self.federated_users:
|
||||
return self.federated_users[0].display_name
|
||||
else:
|
||||
return None
|
||||
|
||||
@name.setter
|
||||
def name(self, value):
|
||||
if not self.local_user:
|
||||
self.local_user = LocalUser()
|
||||
self.local_user.name = value
|
||||
|
||||
@name.expression
|
||||
def name(cls):
|
||||
return LocalUser.name
|
||||
|
||||
# password property
|
||||
@hybrid_property
|
||||
def password(self):
|
||||
if self.local_user and self.local_user.passwords:
|
||||
return self.local_user.passwords[0].password
|
||||
else:
|
||||
return None
|
||||
|
||||
@password.setter
|
||||
def password(self, value):
|
||||
if not value:
|
||||
if self.local_user and self.local_user.passwords:
|
||||
self.local_user.passwords = []
|
||||
else:
|
||||
if not self.local_user:
|
||||
self.local_user = LocalUser()
|
||||
if not self.local_user.passwords:
|
||||
self.local_user.passwords.append(Password())
|
||||
self.local_user.passwords[0].password = value
|
||||
|
||||
@password.expression
|
||||
def password(cls):
|
||||
return Password.password
|
||||
|
||||
# domain_id property
|
||||
@hybrid_property
|
||||
def domain_id(self):
|
||||
if self.local_user:
|
||||
return self.local_user.domain_id
|
||||
else:
|
||||
return None
|
||||
|
||||
@domain_id.setter
|
||||
def domain_id(self, value):
|
||||
if not self.local_user:
|
||||
self.local_user = LocalUser()
|
||||
self.local_user.domain_id = value
|
||||
|
||||
@domain_id.expression
|
||||
def domain_id(cls):
|
||||
return LocalUser.domain_id
|
||||
|
||||
def to_dict(self, include_extra_dict=False):
|
||||
d = super(User, self).to_dict(include_extra_dict=include_extra_dict)
|
||||
if 'default_project_id' in d and d['default_project_id'] is None:
|
||||
del d['default_project_id']
|
||||
return d
|
||||
|
||||
|
||||
class LocalUser(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'local_user'
|
||||
attributes = ['id', 'user_id', 'domain_id', 'name']
|
||||
id = sql.Column(sql.Integer, primary_key=True)
|
||||
user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
|
||||
ondelete='CASCADE'), unique=True)
|
||||
domain_id = sql.Column(sql.String(64), nullable=False)
|
||||
name = sql.Column(sql.String(255), nullable=False)
|
||||
passwords = orm.relationship('Password', single_parent=True,
|
||||
cascade='all,delete-orphan',
|
||||
backref='local_user')
|
||||
__table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
|
||||
|
||||
|
||||
class Password(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'password'
|
||||
attributes = ['id', 'local_user_id', 'password']
|
||||
id = sql.Column(sql.Integer, primary_key=True)
|
||||
local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id',
|
||||
ondelete='CASCADE'))
|
||||
password = sql.Column(sql.String(128))
|
||||
|
||||
|
||||
class FederatedUser(sql.ModelBase, sql.ModelDictMixin):
|
||||
__tablename__ = 'federated_user'
|
||||
attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id',
|
||||
'display_name']
|
||||
id = sql.Column(sql.Integer, primary_key=True)
|
||||
user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
|
||||
ondelete='CASCADE'))
|
||||
idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
|
||||
ondelete='CASCADE'))
|
||||
protocol_id = sql.Column(sql.String(64), nullable=False)
|
||||
unique_id = sql.Column(sql.String(255), nullable=False)
|
||||
display_name = sql.Column(sql.String(255), nullable=True)
|
||||
__table_args__ = (
|
||||
sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'),
|
||||
sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'],
|
||||
['federation_protocol.id',
|
||||
'federation_protocol.idp_id'])
|
||||
)
|
||||
|
||||
|
||||
class Group(sql.ModelBase, sql.DictBase):
|
||||
__tablename__ = 'group'
|
||||
attributes = ['id', 'name', 'domain_id', 'description']
|
||||
id = sql.Column(sql.String(64), primary_key=True)
|
||||
name = sql.Column(sql.String(64), nullable=False)
|
||||
domain_id = sql.Column(sql.String(64), nullable=False)
|
||||
description = sql.Column(sql.Text())
|
||||
extra = sql.Column(sql.JsonBlob())
|
||||
# Unique constraint across two columns to create the separation
|
||||
# rather than just only 'name' being unique
|
||||
__table_args__ = (sql.UniqueConstraint('domain_id', 'name'),)
|
||||
|
||||
|
||||
class UserGroupMembership(sql.ModelBase, sql.DictBase):
|
||||
"""Group membership join table."""
|
||||
|
||||
__tablename__ = 'user_group_membership'
|
||||
user_id = sql.Column(sql.String(64),
|
||||
sql.ForeignKey('user.id'),
|
||||
primary_key=True)
|
||||
group_id = sql.Column(sql.String(64),
|
||||
sql.ForeignKey('group.id'),
|
||||
primary_key=True)
|
|
@ -15,7 +15,7 @@ import uuid
|
|||
from keystone.common import sql
|
||||
from keystone import exception
|
||||
from keystone.identity.backends import base as identity_base
|
||||
from keystone.identity.backends import sql as model
|
||||
from keystone.identity.backends import sql_model as model
|
||||
from keystone.identity.shadow_backends import base
|
||||
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from testtools import matchers
|
|||
from keystone.common import driver_hints
|
||||
from keystone.common import sql
|
||||
from keystone import exception
|
||||
from keystone.identity.backends import sql as identity_sql
|
||||
from keystone.identity.backends import sql_model as identity_sql
|
||||
from keystone.resource.backends import base as resource
|
||||
from keystone.tests import unit
|
||||
from keystone.tests.unit.assignment import test_backends as assignment_tests
|
||||
|
|
Loading…
Reference in New Issue