Merge "Use the new enginefacade from oslo.db"
This commit is contained in:
commit
d37af165d0
@ -56,7 +56,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
return 'sql'
|
return 'sql'
|
||||||
|
|
||||||
def list_user_ids_for_project(self, tenant_id):
|
def list_user_ids_for_project(self, tenant_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleAssignment.actor_id)
|
query = session.query(RoleAssignment.actor_id)
|
||||||
query = query.filter_by(type=AssignmentType.USER_PROJECT)
|
query = query.filter_by(type=AssignmentType.USER_PROJECT)
|
||||||
query = query.filter_by(target_id=tenant_id)
|
query = query.filter_by(target_id=tenant_id)
|
||||||
@ -71,7 +71,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
assignment_type = AssignmentType.calculate_type(
|
assignment_type = AssignmentType.calculate_type(
|
||||||
user_id, group_id, project_id, domain_id)
|
user_id, group_id, project_id, domain_id)
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
session.add(RoleAssignment(
|
session.add(RoleAssignment(
|
||||||
type=assignment_type,
|
type=assignment_type,
|
||||||
actor_id=user_id or group_id,
|
actor_id=user_id or group_id,
|
||||||
@ -85,7 +85,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
def list_grant_role_ids(self, user_id=None, group_id=None,
|
def list_grant_role_ids(self, user_id=None, group_id=None,
|
||||||
domain_id=None, project_id=None,
|
domain_id=None, project_id=None,
|
||||||
inherited_to_projects=False):
|
inherited_to_projects=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
q = session.query(RoleAssignment.role_id)
|
q = session.query(RoleAssignment.role_id)
|
||||||
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
|
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
|
||||||
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
|
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
|
||||||
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
|
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
|
||||||
domain_id=None, project_id=None,
|
domain_id=None, project_id=None,
|
||||||
inherited_to_projects=False):
|
inherited_to_projects=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
try:
|
try:
|
||||||
q = self._build_grant_filter(
|
q = self._build_grant_filter(
|
||||||
session, role_id, user_id, group_id, domain_id, project_id,
|
session, role_id, user_id, group_id, domain_id, project_id,
|
||||||
@ -120,7 +120,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
def delete_grant(self, role_id, user_id=None, group_id=None,
|
def delete_grant(self, role_id, user_id=None, group_id=None,
|
||||||
domain_id=None, project_id=None,
|
domain_id=None, project_id=None,
|
||||||
inherited_to_projects=False):
|
inherited_to_projects=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = self._build_grant_filter(
|
q = self._build_grant_filter(
|
||||||
session, role_id, user_id, group_id, domain_id, project_id,
|
session, role_id, user_id, group_id, domain_id, project_id,
|
||||||
inherited_to_projects)
|
inherited_to_projects)
|
||||||
@ -145,11 +145,11 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
RoleAssignment.inherited == inherited,
|
RoleAssignment.inherited == inherited,
|
||||||
RoleAssignment.actor_id.in_(actors))
|
RoleAssignment.actor_id.in_(actors))
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleAssignment.target_id).filter(
|
query = session.query(RoleAssignment.target_id).filter(
|
||||||
sql_constraints).distinct()
|
sql_constraints).distinct()
|
||||||
|
|
||||||
return [x.target_id for x in query.all()]
|
return [x.target_id for x in query.all()]
|
||||||
|
|
||||||
def list_project_ids_for_user(self, user_id, group_ids, hints,
|
def list_project_ids_for_user(self, user_id, group_ids, hints,
|
||||||
inherited=False):
|
inherited=False):
|
||||||
@ -161,7 +161,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
|
|
||||||
def list_domain_ids_for_user(self, user_id, group_ids, hints,
|
def list_domain_ids_for_user(self, user_id, group_ids, hints,
|
||||||
inherited=False):
|
inherited=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleAssignment.target_id)
|
query = session.query(RoleAssignment.target_id)
|
||||||
filters = []
|
filters = []
|
||||||
|
|
||||||
@ -197,10 +197,10 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
RoleAssignment.inherited == false(),
|
RoleAssignment.inherited == false(),
|
||||||
RoleAssignment.actor_id.in_(group_ids))
|
RoleAssignment.actor_id.in_(group_ids))
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleAssignment.role_id).filter(
|
query = session.query(RoleAssignment.role_id).filter(
|
||||||
sql_constraints).distinct()
|
sql_constraints).distinct()
|
||||||
return [role.role_id for role in query.all()]
|
return [role.role_id for role in query.all()]
|
||||||
|
|
||||||
def list_role_ids_for_groups_on_project(
|
def list_role_ids_for_groups_on_project(
|
||||||
self, group_ids, project_id, project_domain_id, project_parents):
|
self, group_ids, project_id, project_domain_id, project_parents):
|
||||||
@ -237,13 +237,13 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
sql_constraints = sqlalchemy.and_(
|
sql_constraints = sqlalchemy.and_(
|
||||||
sql_constraints, RoleAssignment.actor_id.in_(group_ids))
|
sql_constraints, RoleAssignment.actor_id.in_(group_ids))
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
# NOTE(morganfainberg): Only select the columns we actually care
|
# NOTE(morganfainberg): Only select the columns we actually care
|
||||||
# about here, in this case role_id.
|
# about here, in this case role_id.
|
||||||
query = session.query(RoleAssignment.role_id).filter(
|
query = session.query(RoleAssignment.role_id).filter(
|
||||||
sql_constraints).distinct()
|
sql_constraints).distinct()
|
||||||
|
|
||||||
return [result.role_id for result in query.all()]
|
return [result.role_id for result in query.all()]
|
||||||
|
|
||||||
def list_project_ids_for_groups(self, group_ids, hints,
|
def list_project_ids_for_groups(self, group_ids, hints,
|
||||||
inherited=False):
|
inherited=False):
|
||||||
@ -260,14 +260,14 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
RoleAssignment.inherited == inherited,
|
RoleAssignment.inherited == inherited,
|
||||||
RoleAssignment.actor_id.in_(group_ids))
|
RoleAssignment.actor_id.in_(group_ids))
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleAssignment.target_id).filter(
|
query = session.query(RoleAssignment.target_id).filter(
|
||||||
group_sql_conditions).distinct()
|
group_sql_conditions).distinct()
|
||||||
return [x.target_id for x in query.all()]
|
return [x.target_id for x in query.all()]
|
||||||
|
|
||||||
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
|
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
session.add(RoleAssignment(
|
session.add(RoleAssignment(
|
||||||
type=AssignmentType.USER_PROJECT,
|
type=AssignmentType.USER_PROJECT,
|
||||||
actor_id=user_id, target_id=tenant_id,
|
actor_id=user_id, target_id=tenant_id,
|
||||||
@ -278,7 +278,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
raise exception.Conflict(type='role grant', details=msg)
|
raise exception.Conflict(type='role grant', details=msg)
|
||||||
|
|
||||||
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
|
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(actor_id=user_id)
|
q = q.filter_by(actor_id=user_id)
|
||||||
q = q.filter_by(target_id=tenant_id)
|
q = q.filter_by(target_id=tenant_id)
|
||||||
@ -368,7 +368,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
assignment['inherited_to_projects'] = 'projects'
|
assignment['inherited_to_projects'] = 'projects'
|
||||||
return assignment
|
return assignment
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
assignment_types = self._get_assignment_types(
|
assignment_types = self._get_assignment_types(
|
||||||
user_id, group_ids, project_ids, domain_id)
|
user_id, group_ids, project_ids, domain_id)
|
||||||
|
|
||||||
@ -400,25 +400,25 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
|||||||
return [denormalize_role(ref) for ref in query.all()]
|
return [denormalize_role(ref) for ref in query.all()]
|
||||||
|
|
||||||
def delete_project_assignments(self, project_id):
|
def delete_project_assignments(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(target_id=project_id)
|
q = q.filter_by(target_id=project_id)
|
||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
|
||||||
def delete_role_assignments(self, role_id):
|
def delete_role_assignments(self, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(role_id=role_id)
|
q = q.filter_by(role_id=role_id)
|
||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
|
||||||
def delete_user_assignments(self, user_id):
|
def delete_user_assignments(self, user_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(actor_id=user_id)
|
q = q.filter_by(actor_id=user_id)
|
||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
|
||||||
def delete_group_assignments(self, group_id):
|
def delete_group_assignments(self, group_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(actor_id=group_id)
|
q = q.filter_by(actor_id=group_id)
|
||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
@ -19,14 +19,14 @@ class Role(assignment.RoleDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='role')
|
@sql.handle_conflicts(conflict_type='role')
|
||||||
def create_role(self, role_id, role):
|
def create_role(self, role_id, role):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = RoleTable.from_dict(role)
|
ref = RoleTable.from_dict(role)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
@sql.truncated
|
@sql.truncated
|
||||||
def list_roles(self, hints):
|
def list_roles(self, hints):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleTable)
|
query = session.query(RoleTable)
|
||||||
refs = sql.filter_limit_query(RoleTable, query, hints)
|
refs = sql.filter_limit_query(RoleTable, query, hints)
|
||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
@ -35,7 +35,7 @@ class Role(assignment.RoleDriverV8):
|
|||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleTable)
|
query = session.query(RoleTable)
|
||||||
query = query.filter(RoleTable.id.in_(ids))
|
query = query.filter(RoleTable.id.in_(ids))
|
||||||
role_refs = query.all()
|
role_refs = query.all()
|
||||||
@ -48,12 +48,12 @@ class Role(assignment.RoleDriverV8):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_role(self, role_id):
|
def get_role(self, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
return self._get_role(session, role_id).to_dict()
|
return self._get_role(session, role_id).to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='role')
|
@sql.handle_conflicts(conflict_type='role')
|
||||||
def update_role(self, role_id, role):
|
def update_role(self, role_id, role):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_role(session, role_id)
|
ref = self._get_role(session, role_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
for k in role:
|
for k in role:
|
||||||
@ -66,7 +66,7 @@ class Role(assignment.RoleDriverV8):
|
|||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_role(self, role_id):
|
def delete_role(self, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_role(session, role_id)
|
ref = self._get_role(session, role_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
assignment_type = AssignmentType.calculate_type(
|
assignment_type = AssignmentType.calculate_type(
|
||||||
user_id, group_id, project_id, domain_id)
|
user_id, group_id, project_id, domain_id)
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
session.add(RoleAssignment(
|
session.add(RoleAssignment(
|
||||||
type=assignment_type,
|
type=assignment_type,
|
||||||
actor_id=user_id or group_id,
|
actor_id=user_id or group_id,
|
||||||
@ -69,7 +69,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
def list_grant_role_ids(self, user_id=None, group_id=None,
|
def list_grant_role_ids(self, user_id=None, group_id=None,
|
||||||
domain_id=None, project_id=None,
|
domain_id=None, project_id=None,
|
||||||
inherited_to_projects=False):
|
inherited_to_projects=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
q = session.query(RoleAssignment.role_id)
|
q = session.query(RoleAssignment.role_id)
|
||||||
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
|
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
|
||||||
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
|
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
|
||||||
@ -88,7 +88,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
|
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
|
||||||
domain_id=None, project_id=None,
|
domain_id=None, project_id=None,
|
||||||
inherited_to_projects=False):
|
inherited_to_projects=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
try:
|
try:
|
||||||
q = self._build_grant_filter(
|
q = self._build_grant_filter(
|
||||||
session, role_id, user_id, group_id, domain_id, project_id,
|
session, role_id, user_id, group_id, domain_id, project_id,
|
||||||
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
def delete_grant(self, role_id, user_id=None, group_id=None,
|
def delete_grant(self, role_id, user_id=None, group_id=None,
|
||||||
domain_id=None, project_id=None,
|
domain_id=None, project_id=None,
|
||||||
inherited_to_projects=False):
|
inherited_to_projects=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = self._build_grant_filter(
|
q = self._build_grant_filter(
|
||||||
session, role_id, user_id, group_id, domain_id, project_id,
|
session, role_id, user_id, group_id, domain_id, project_id,
|
||||||
inherited_to_projects)
|
inherited_to_projects)
|
||||||
@ -117,7 +117,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
|
|
||||||
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
|
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
session.add(RoleAssignment(
|
session.add(RoleAssignment(
|
||||||
type=AssignmentType.USER_PROJECT,
|
type=AssignmentType.USER_PROJECT,
|
||||||
actor_id=user_id, target_id=tenant_id,
|
actor_id=user_id, target_id=tenant_id,
|
||||||
@ -128,7 +128,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
raise exception.Conflict(type='role grant', details=msg)
|
raise exception.Conflict(type='role grant', details=msg)
|
||||||
|
|
||||||
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
|
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(actor_id=user_id)
|
q = q.filter_by(actor_id=user_id)
|
||||||
q = q.filter_by(target_id=tenant_id)
|
q = q.filter_by(target_id=tenant_id)
|
||||||
@ -218,7 +218,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
assignment['inherited_to_projects'] = 'projects'
|
assignment['inherited_to_projects'] = 'projects'
|
||||||
return assignment
|
return assignment
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
assignment_types = self._get_assignment_types(
|
assignment_types = self._get_assignment_types(
|
||||||
user_id, group_ids, project_ids, domain_id)
|
user_id, group_ids, project_ids, domain_id)
|
||||||
|
|
||||||
@ -250,7 +250,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
return [denormalize_role(ref) for ref in query.all()]
|
return [denormalize_role(ref) for ref in query.all()]
|
||||||
|
|
||||||
def delete_project_assignments(self, project_id):
|
def delete_project_assignments(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(target_id=project_id).filter(
|
q = q.filter_by(target_id=project_id).filter(
|
||||||
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
|
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
|
||||||
@ -259,13 +259,13 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
|
||||||
def delete_role_assignments(self, role_id):
|
def delete_role_assignments(self, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(role_id=role_id)
|
q = q.filter_by(role_id=role_id)
|
||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
|
||||||
def delete_user_assignments(self, user_id):
|
def delete_user_assignments(self, user_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(actor_id=user_id).filter(
|
q = q.filter_by(actor_id=user_id).filter(
|
||||||
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
|
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
|
||||||
@ -274,7 +274,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
|||||||
q.delete(False)
|
q.delete(False)
|
||||||
|
|
||||||
def delete_group_assignments(self, group_id):
|
def delete_group_assignments(self, group_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
q = session.query(RoleAssignment)
|
q = session.query(RoleAssignment)
|
||||||
q = q.filter_by(actor_id=group_id).filter(
|
q = q.filter_by(actor_id=group_id).filter(
|
||||||
RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT,
|
RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT,
|
||||||
|
@ -29,7 +29,7 @@ class Role(assignment.RoleDriverV9):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='role')
|
@sql.handle_conflicts(conflict_type='role')
|
||||||
def create_role(self, role_id, role):
|
def create_role(self, role_id, role):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = RoleTable.from_dict(role)
|
ref = RoleTable.from_dict(role)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
@ -46,7 +46,7 @@ class Role(assignment.RoleDriverV9):
|
|||||||
if (f['name'] == 'domain_id' and f['value'] is None):
|
if (f['name'] == 'domain_id' and f['value'] is None):
|
||||||
f['value'] = NULL_DOMAIN_ID
|
f['value'] = NULL_DOMAIN_ID
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleTable)
|
query = session.query(RoleTable)
|
||||||
refs = sql.filter_limit_query(RoleTable, query, hints)
|
refs = sql.filter_limit_query(RoleTable, query, hints)
|
||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
@ -55,7 +55,7 @@ class Role(assignment.RoleDriverV9):
|
|||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RoleTable)
|
query = session.query(RoleTable)
|
||||||
query = query.filter(RoleTable.id.in_(ids))
|
query = query.filter(RoleTable.id.in_(ids))
|
||||||
role_refs = query.all()
|
role_refs = query.all()
|
||||||
@ -68,12 +68,12 @@ class Role(assignment.RoleDriverV9):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_role(self, role_id):
|
def get_role(self, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
return self._get_role(session, role_id).to_dict()
|
return self._get_role(session, role_id).to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='role')
|
@sql.handle_conflicts(conflict_type='role')
|
||||||
def update_role(self, role_id, role):
|
def update_role(self, role_id, role):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_role(session, role_id)
|
ref = self._get_role(session, role_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
for k in role:
|
for k in role:
|
||||||
@ -86,7 +86,7 @@ class Role(assignment.RoleDriverV9):
|
|||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_role(self, role_id):
|
def delete_role(self, role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_role(session, role_id)
|
ref = self._get_role(session, role_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ class Role(assignment.RoleDriverV9):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='implied_role')
|
@sql.handle_conflicts(conflict_type='implied_role')
|
||||||
def create_implied_role(self, prior_role_id, implied_role_id):
|
def create_implied_role(self, prior_role_id, implied_role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
inference = {'prior_role_id': prior_role_id,
|
inference = {'prior_role_id': prior_role_id,
|
||||||
'implied_role_id': implied_role_id}
|
'implied_role_id': implied_role_id}
|
||||||
ref = ImpliedRoleTable.from_dict(inference)
|
ref = ImpliedRoleTable.from_dict(inference)
|
||||||
@ -119,13 +119,13 @@ class Role(assignment.RoleDriverV9):
|
|||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_implied_role(self, prior_role_id, implied_role_id):
|
def delete_implied_role(self, prior_role_id, implied_role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_implied_role(session, prior_role_id,
|
ref = self._get_implied_role(session, prior_role_id,
|
||||||
implied_role_id)
|
implied_role_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
def list_implied_roles(self, prior_role_id):
|
def list_implied_roles(self, prior_role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(
|
query = session.query(
|
||||||
ImpliedRoleTable).filter(
|
ImpliedRoleTable).filter(
|
||||||
ImpliedRoleTable.prior_role_id == prior_role_id)
|
ImpliedRoleTable.prior_role_id == prior_role_id)
|
||||||
@ -133,13 +133,13 @@ class Role(assignment.RoleDriverV9):
|
|||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
|
|
||||||
def list_role_inference_rules(self):
|
def list_role_inference_rules(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(ImpliedRoleTable)
|
query = session.query(ImpliedRoleTable)
|
||||||
refs = query.all()
|
refs = query.all()
|
||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
|
|
||||||
def get_implied_role(self, prior_role_id, implied_role_id):
|
def get_implied_role(self, prior_role_id, implied_role_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
ref = self._get_implied_role(session, prior_role_id,
|
ref = self._get_implied_role(session, prior_role_id,
|
||||||
implied_role_id)
|
implied_role_id)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
@ -84,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase):
|
|||||||
class Catalog(catalog.CatalogDriverV8):
|
class Catalog(catalog.CatalogDriverV8):
|
||||||
# Regions
|
# Regions
|
||||||
def list_regions(self, hints):
|
def list_regions(self, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
regions = session.query(Region)
|
regions = session.query(Region)
|
||||||
regions = sql.filter_limit_query(Region, regions, hints)
|
regions = sql.filter_limit_query(Region, regions, hints)
|
||||||
return [s.to_dict() for s in list(regions)]
|
return [s.to_dict() for s in list(regions)]
|
||||||
|
|
||||||
def _get_region(self, session, region_id):
|
def _get_region(self, session, region_id):
|
||||||
ref = session.query(Region).get(region_id)
|
ref = session.query(Region).get(region_id)
|
||||||
@ -136,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_region(self, region_id):
|
def get_region(self, region_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
return self._get_region(session, region_id).to_dict()
|
return self._get_region(session, region_id).to_dict()
|
||||||
|
|
||||||
def delete_region(self, region_id):
|
def delete_region(self, region_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = self._get_region(session, region_id)
|
ref = self._get_region(session, region_id)
|
||||||
if self._has_endpoints(session, ref, ref):
|
if self._has_endpoints(session, ref, ref):
|
||||||
raise exception.RegionDeletionError(region_id=region_id)
|
raise exception.RegionDeletionError(region_id=region_id)
|
||||||
@ -150,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='region')
|
@sql.handle_conflicts(conflict_type='region')
|
||||||
def create_region(self, region_ref):
|
def create_region(self, region_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
self._check_parent_region(session, region_ref)
|
self._check_parent_region(session, region_ref)
|
||||||
region = Region.from_dict(region_ref)
|
region = Region.from_dict(region_ref)
|
||||||
session.add(region)
|
session.add(region)
|
||||||
return region.to_dict()
|
return region.to_dict()
|
||||||
|
|
||||||
def update_region(self, region_id, region_ref):
|
def update_region(self, region_id, region_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
self._check_parent_region(session, region_ref)
|
self._check_parent_region(session, region_ref)
|
||||||
ref = self._get_region(session, region_id)
|
ref = self._get_region(session, region_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
@ -169,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
for attr in Region.attributes:
|
for attr in Region.attributes:
|
||||||
if attr != 'id':
|
if attr != 'id':
|
||||||
setattr(ref, attr, getattr(new_region, attr))
|
setattr(ref, attr, getattr(new_region, attr))
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
# Services
|
# Services
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_services(self, hints):
|
def list_services(self, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
services = session.query(Service)
|
services = session.query(Service)
|
||||||
services = sql.filter_limit_query(Service, services, hints)
|
services = sql.filter_limit_query(Service, services, hints)
|
||||||
return [s.to_dict() for s in list(services)]
|
return [s.to_dict() for s in list(services)]
|
||||||
|
|
||||||
def _get_service(self, session, service_id):
|
def _get_service(self, session, service_id):
|
||||||
ref = session.query(Service).get(service_id)
|
ref = session.query(Service).get(service_id)
|
||||||
@ -186,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_service(self, service_id):
|
def get_service(self, service_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
return self._get_service(session, service_id).to_dict()
|
return self._get_service(session, service_id).to_dict()
|
||||||
|
|
||||||
def delete_service(self, service_id):
|
def delete_service(self, service_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = self._get_service(session, service_id)
|
ref = self._get_service(session, service_id)
|
||||||
session.query(Endpoint).filter_by(service_id=service_id).delete()
|
session.query(Endpoint).filter_by(service_id=service_id).delete()
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
def create_service(self, service_id, service_ref):
|
def create_service(self, service_id, service_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
service = Service.from_dict(service_ref)
|
service = Service.from_dict(service_ref)
|
||||||
session.add(service)
|
session.add(service)
|
||||||
return service.to_dict()
|
return service.to_dict()
|
||||||
|
|
||||||
def update_service(self, service_id, service_ref):
|
def update_service(self, service_id, service_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = self._get_service(session, service_id)
|
ref = self._get_service(session, service_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
old_dict.update(service_ref)
|
old_dict.update(service_ref)
|
||||||
@ -214,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
if attr != 'id':
|
if attr != 'id':
|
||||||
setattr(ref, attr, getattr(new_service, attr))
|
setattr(ref, attr, getattr(new_service, attr))
|
||||||
ref.extra = new_service.extra
|
ref.extra = new_service.extra
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
# Endpoints
|
# Endpoints
|
||||||
def create_endpoint(self, endpoint_id, endpoint_ref):
|
def create_endpoint(self, endpoint_id, endpoint_ref):
|
||||||
session = sql.get_session()
|
|
||||||
new_endpoint = Endpoint.from_dict(endpoint_ref)
|
new_endpoint = Endpoint.from_dict(endpoint_ref)
|
||||||
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
session.add(new_endpoint)
|
session.add(new_endpoint)
|
||||||
return new_endpoint.to_dict()
|
return new_endpoint.to_dict()
|
||||||
|
|
||||||
def delete_endpoint(self, endpoint_id):
|
def delete_endpoint(self, endpoint_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = self._get_endpoint(session, endpoint_id)
|
ref = self._get_endpoint(session, endpoint_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
@ -238,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
|
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
|
||||||
|
|
||||||
def get_endpoint(self, endpoint_id):
|
def get_endpoint(self, endpoint_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
return self._get_endpoint(session, endpoint_id).to_dict()
|
return self._get_endpoint(session, endpoint_id).to_dict()
|
||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_endpoints(self, hints):
|
def list_endpoints(self, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
endpoints = session.query(Endpoint)
|
endpoints = session.query(Endpoint)
|
||||||
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
|
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
|
||||||
return [e.to_dict() for e in list(endpoints)]
|
return [e.to_dict() for e in list(endpoints)]
|
||||||
|
|
||||||
def update_endpoint(self, endpoint_id, endpoint_ref):
|
def update_endpoint(self, endpoint_id, endpoint_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_endpoint(session, endpoint_id)
|
ref = self._get_endpoint(session, endpoint_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
old_dict.update(endpoint_ref)
|
old_dict.update(endpoint_ref)
|
||||||
@ -260,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
if attr != 'id':
|
if attr != 'id':
|
||||||
setattr(ref, attr, getattr(new_endpoint, attr))
|
setattr(ref, attr, getattr(new_endpoint, attr))
|
||||||
ref.extra = new_endpoint.extra
|
ref.extra = new_endpoint.extra
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def get_catalog(self, user_id, tenant_id):
|
def get_catalog(self, user_id, tenant_id):
|
||||||
"""Retrieve and format the V2 service catalog.
|
"""Retrieve and format the V2 service catalog.
|
||||||
@ -289,40 +278,40 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
else:
|
else:
|
||||||
silent_keyerror_failures = ['tenant_id', 'project_id', ]
|
silent_keyerror_failures = ['tenant_id', 'project_id', ]
|
||||||
|
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
endpoints = (session.query(Endpoint).
|
endpoints = (session.query(Endpoint).
|
||||||
options(sql.joinedload(Endpoint.service)).
|
options(sql.joinedload(Endpoint.service)).
|
||||||
filter(Endpoint.enabled == true()).all())
|
filter(Endpoint.enabled == true()).all())
|
||||||
|
|
||||||
catalog = {}
|
catalog = {}
|
||||||
|
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
if not endpoint.service['enabled']:
|
if not endpoint.service['enabled']:
|
||||||
continue
|
|
||||||
try:
|
|
||||||
formatted_url = core.format_url(
|
|
||||||
endpoint['url'], substitutions,
|
|
||||||
silent_keyerror_failures=silent_keyerror_failures)
|
|
||||||
if formatted_url is not None:
|
|
||||||
url = formatted_url
|
|
||||||
else:
|
|
||||||
continue
|
continue
|
||||||
except exception.MalformedEndpoint:
|
try:
|
||||||
continue # this failure is already logged in format_url()
|
formatted_url = core.format_url(
|
||||||
|
endpoint['url'], substitutions,
|
||||||
|
silent_keyerror_failures=silent_keyerror_failures)
|
||||||
|
if formatted_url is not None:
|
||||||
|
url = formatted_url
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
except exception.MalformedEndpoint:
|
||||||
|
continue # this failure is already logged in format_url()
|
||||||
|
|
||||||
region = endpoint['region_id']
|
region = endpoint['region_id']
|
||||||
service_type = endpoint.service['type']
|
service_type = endpoint.service['type']
|
||||||
default_service = {
|
default_service = {
|
||||||
'id': endpoint['id'],
|
'id': endpoint['id'],
|
||||||
'name': endpoint.service.extra.get('name', ''),
|
'name': endpoint.service.extra.get('name', ''),
|
||||||
'publicURL': ''
|
'publicURL': ''
|
||||||
}
|
}
|
||||||
catalog.setdefault(region, {})
|
catalog.setdefault(region, {})
|
||||||
catalog[region].setdefault(service_type, default_service)
|
catalog[region].setdefault(service_type, default_service)
|
||||||
interface_url = '%sURL' % endpoint['interface']
|
interface_url = '%sURL' % endpoint['interface']
|
||||||
catalog[region][service_type][interface_url] = url
|
catalog[region][service_type][interface_url] = url
|
||||||
|
|
||||||
return catalog
|
return catalog
|
||||||
|
|
||||||
def get_v3_catalog(self, user_id, tenant_id):
|
def get_v3_catalog(self, user_id, tenant_id):
|
||||||
"""Retrieve and format the current V3 service catalog.
|
"""Retrieve and format the current V3 service catalog.
|
||||||
@ -349,44 +338,46 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
else:
|
else:
|
||||||
silent_keyerror_failures = ['tenant_id', 'project_id', ]
|
silent_keyerror_failures = ['tenant_id', 'project_id', ]
|
||||||
|
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
services = (session.query(Service).filter(Service.enabled == true()).
|
services = (session.query(Service).filter(
|
||||||
options(sql.joinedload(Service.endpoints)).
|
Service.enabled == true()).options(
|
||||||
all())
|
sql.joinedload(Service.endpoints)).all())
|
||||||
|
|
||||||
def make_v3_endpoints(endpoints):
|
def make_v3_endpoints(endpoints):
|
||||||
for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled):
|
for endpoint in (ep.to_dict()
|
||||||
del endpoint['service_id']
|
for ep in endpoints if ep.enabled):
|
||||||
del endpoint['legacy_endpoint_id']
|
del endpoint['service_id']
|
||||||
del endpoint['enabled']
|
del endpoint['legacy_endpoint_id']
|
||||||
endpoint['region'] = endpoint['region_id']
|
del endpoint['enabled']
|
||||||
try:
|
endpoint['region'] = endpoint['region_id']
|
||||||
formatted_url = core.format_url(
|
try:
|
||||||
endpoint['url'], d,
|
formatted_url = core.format_url(
|
||||||
silent_keyerror_failures=silent_keyerror_failures)
|
endpoint['url'], d,
|
||||||
if formatted_url:
|
silent_keyerror_failures=silent_keyerror_failures)
|
||||||
endpoint['url'] = formatted_url
|
if formatted_url:
|
||||||
else:
|
endpoint['url'] = formatted_url
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
except exception.MalformedEndpoint:
|
||||||
|
# this failure is already logged in format_url()
|
||||||
continue
|
continue
|
||||||
except exception.MalformedEndpoint:
|
|
||||||
continue # this failure is already logged in format_url()
|
|
||||||
|
|
||||||
yield endpoint
|
yield endpoint
|
||||||
|
|
||||||
# TODO(davechen): If there is service with no endpoints, we should skip
|
# TODO(davechen): If there is service with no endpoints, we should
|
||||||
# the service instead of keeping it in the catalog, see bug #1436704.
|
# skip the service instead of keeping it in the catalog,
|
||||||
def make_v3_service(svc):
|
# see bug #1436704.
|
||||||
eps = list(make_v3_endpoints(svc.endpoints))
|
def make_v3_service(svc):
|
||||||
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
|
eps = list(make_v3_endpoints(svc.endpoints))
|
||||||
service['name'] = svc.extra.get('name', '')
|
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
|
||||||
return service
|
service['name'] = svc.extra.get('name', '')
|
||||||
|
return service
|
||||||
|
|
||||||
return [make_v3_service(svc) for svc in services]
|
return [make_v3_service(svc) for svc in services]
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='project_endpoint')
|
@sql.handle_conflicts(conflict_type='project_endpoint')
|
||||||
def add_endpoint_to_project(self, endpoint_id, project_id):
|
def add_endpoint_to_project(self, endpoint_id, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
|
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
|
||||||
project_id=project_id)
|
project_id=project_id)
|
||||||
session.add(endpoint_filter_ref)
|
session.add(endpoint_filter_ref)
|
||||||
@ -402,50 +393,46 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
return endpoint_filter_ref
|
return endpoint_filter_ref
|
||||||
|
|
||||||
def check_endpoint_in_project(self, endpoint_id, project_id):
|
def check_endpoint_in_project(self, endpoint_id, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
self._get_project_endpoint_ref(session, endpoint_id, project_id)
|
self._get_project_endpoint_ref(session, endpoint_id, project_id)
|
||||||
|
|
||||||
def remove_endpoint_from_project(self, endpoint_id, project_id):
|
def remove_endpoint_from_project(self, endpoint_id, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
endpoint_filter_ref = self._get_project_endpoint_ref(
|
endpoint_filter_ref = self._get_project_endpoint_ref(
|
||||||
session, endpoint_id, project_id)
|
session, endpoint_id, project_id)
|
||||||
with session.begin():
|
|
||||||
session.delete(endpoint_filter_ref)
|
session.delete(endpoint_filter_ref)
|
||||||
|
|
||||||
def list_endpoints_for_project(self, project_id):
|
def list_endpoints_for_project(self, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(ProjectEndpoint)
|
query = session.query(ProjectEndpoint)
|
||||||
query = query.filter_by(project_id=project_id)
|
query = query.filter_by(project_id=project_id)
|
||||||
endpoint_filter_refs = query.all()
|
endpoint_filter_refs = query.all()
|
||||||
return [ref.to_dict() for ref in endpoint_filter_refs]
|
return [ref.to_dict() for ref in endpoint_filter_refs]
|
||||||
|
|
||||||
def list_projects_for_endpoint(self, endpoint_id):
|
def list_projects_for_endpoint(self, endpoint_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(ProjectEndpoint)
|
query = session.query(ProjectEndpoint)
|
||||||
query = query.filter_by(endpoint_id=endpoint_id)
|
query = query.filter_by(endpoint_id=endpoint_id)
|
||||||
endpoint_filter_refs = query.all()
|
endpoint_filter_refs = query.all()
|
||||||
return [ref.to_dict() for ref in endpoint_filter_refs]
|
return [ref.to_dict() for ref in endpoint_filter_refs]
|
||||||
|
|
||||||
def delete_association_by_endpoint(self, endpoint_id):
|
def delete_association_by_endpoint(self, endpoint_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
query = session.query(ProjectEndpoint)
|
query = session.query(ProjectEndpoint)
|
||||||
query = query.filter_by(endpoint_id=endpoint_id)
|
query = query.filter_by(endpoint_id=endpoint_id)
|
||||||
query.delete(synchronize_session=False)
|
query.delete(synchronize_session=False)
|
||||||
|
|
||||||
def delete_association_by_project(self, project_id):
|
def delete_association_by_project(self, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
query = session.query(ProjectEndpoint)
|
query = session.query(ProjectEndpoint)
|
||||||
query = query.filter_by(project_id=project_id)
|
query = query.filter_by(project_id=project_id)
|
||||||
query.delete(synchronize_session=False)
|
query.delete(synchronize_session=False)
|
||||||
|
|
||||||
def create_endpoint_group(self, endpoint_group_id, endpoint_group):
|
def create_endpoint_group(self, endpoint_group_id, endpoint_group):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
endpoint_group_ref = EndpointGroup.from_dict(endpoint_group)
|
endpoint_group_ref = EndpointGroup.from_dict(endpoint_group)
|
||||||
session.add(endpoint_group_ref)
|
session.add(endpoint_group_ref)
|
||||||
return endpoint_group_ref.to_dict()
|
return endpoint_group_ref.to_dict()
|
||||||
|
|
||||||
def _get_endpoint_group(self, session, endpoint_group_id):
|
def _get_endpoint_group(self, session, endpoint_group_id):
|
||||||
endpoint_group_ref = session.query(EndpointGroup).get(
|
endpoint_group_ref = session.query(EndpointGroup).get(
|
||||||
@ -456,14 +443,13 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
return endpoint_group_ref
|
return endpoint_group_ref
|
||||||
|
|
||||||
def get_endpoint_group(self, endpoint_group_id):
|
def get_endpoint_group(self, endpoint_group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
endpoint_group_ref = self._get_endpoint_group(session,
|
endpoint_group_ref = self._get_endpoint_group(session,
|
||||||
endpoint_group_id)
|
endpoint_group_id)
|
||||||
return endpoint_group_ref.to_dict()
|
return endpoint_group_ref.to_dict()
|
||||||
|
|
||||||
def update_endpoint_group(self, endpoint_group_id, endpoint_group):
|
def update_endpoint_group(self, endpoint_group_id, endpoint_group):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
endpoint_group_ref = self._get_endpoint_group(session,
|
endpoint_group_ref = self._get_endpoint_group(session,
|
||||||
endpoint_group_id)
|
endpoint_group_id)
|
||||||
old_endpoint_group = endpoint_group_ref.to_dict()
|
old_endpoint_group = endpoint_group_ref.to_dict()
|
||||||
@ -472,29 +458,26 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
for attr in EndpointGroup.mutable_attributes:
|
for attr in EndpointGroup.mutable_attributes:
|
||||||
setattr(endpoint_group_ref, attr,
|
setattr(endpoint_group_ref, attr,
|
||||||
getattr(new_endpoint_group, attr))
|
getattr(new_endpoint_group, attr))
|
||||||
return endpoint_group_ref.to_dict()
|
return endpoint_group_ref.to_dict()
|
||||||
|
|
||||||
def delete_endpoint_group(self, endpoint_group_id):
|
def delete_endpoint_group(self, endpoint_group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
endpoint_group_ref = self._get_endpoint_group(session,
|
endpoint_group_ref = self._get_endpoint_group(session,
|
||||||
endpoint_group_id)
|
endpoint_group_id)
|
||||||
with session.begin():
|
|
||||||
self._delete_endpoint_group_association_by_endpoint_group(
|
self._delete_endpoint_group_association_by_endpoint_group(
|
||||||
session, endpoint_group_id)
|
session, endpoint_group_id)
|
||||||
session.delete(endpoint_group_ref)
|
session.delete(endpoint_group_ref)
|
||||||
|
|
||||||
def get_endpoint_group_in_project(self, endpoint_group_id, project_id):
|
def get_endpoint_group_in_project(self, endpoint_group_id, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
ref = self._get_endpoint_group_in_project(session,
|
ref = self._get_endpoint_group_in_project(session,
|
||||||
endpoint_group_id,
|
endpoint_group_id,
|
||||||
project_id)
|
project_id)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='project_endpoint_group')
|
@sql.handle_conflicts(conflict_type='project_endpoint_group')
|
||||||
def add_endpoint_group_to_project(self, endpoint_group_id, project_id):
|
def add_endpoint_group_to_project(self, endpoint_group_id, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
# Create a new Project Endpoint group entity
|
# Create a new Project Endpoint group entity
|
||||||
endpoint_group_project_ref = ProjectEndpointGroupMembership(
|
endpoint_group_project_ref = ProjectEndpointGroupMembership(
|
||||||
endpoint_group_id=endpoint_group_id, project_id=project_id)
|
endpoint_group_id=endpoint_group_id, project_id=project_id)
|
||||||
@ -512,32 +495,31 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
return endpoint_group_project_ref
|
return endpoint_group_project_ref
|
||||||
|
|
||||||
def list_endpoint_groups(self):
|
def list_endpoint_groups(self):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(EndpointGroup)
|
query = session.query(EndpointGroup)
|
||||||
endpoint_group_refs = query.all()
|
endpoint_group_refs = query.all()
|
||||||
return [e.to_dict() for e in endpoint_group_refs]
|
return [e.to_dict() for e in endpoint_group_refs]
|
||||||
|
|
||||||
def list_endpoint_groups_for_project(self, project_id):
|
def list_endpoint_groups_for_project(self, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(ProjectEndpointGroupMembership)
|
query = session.query(ProjectEndpointGroupMembership)
|
||||||
query = query.filter_by(project_id=project_id)
|
query = query.filter_by(project_id=project_id)
|
||||||
endpoint_group_refs = query.all()
|
endpoint_group_refs = query.all()
|
||||||
return [ref.to_dict() for ref in endpoint_group_refs]
|
return [ref.to_dict() for ref in endpoint_group_refs]
|
||||||
|
|
||||||
def remove_endpoint_group_from_project(self, endpoint_group_id,
|
def remove_endpoint_group_from_project(self, endpoint_group_id,
|
||||||
project_id):
|
project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
endpoint_group_project_ref = self._get_endpoint_group_in_project(
|
endpoint_group_project_ref = self._get_endpoint_group_in_project(
|
||||||
session, endpoint_group_id, project_id)
|
session, endpoint_group_id, project_id)
|
||||||
with session.begin():
|
|
||||||
session.delete(endpoint_group_project_ref)
|
session.delete(endpoint_group_project_ref)
|
||||||
|
|
||||||
def list_projects_associated_with_endpoint_group(self, endpoint_group_id):
|
def list_projects_associated_with_endpoint_group(self, endpoint_group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(ProjectEndpointGroupMembership)
|
query = session.query(ProjectEndpointGroupMembership)
|
||||||
query = query.filter_by(endpoint_group_id=endpoint_group_id)
|
query = query.filter_by(endpoint_group_id=endpoint_group_id)
|
||||||
endpoint_group_refs = query.all()
|
endpoint_group_refs = query.all()
|
||||||
return [ref.to_dict() for ref in endpoint_group_refs]
|
return [ref.to_dict() for ref in endpoint_group_refs]
|
||||||
|
|
||||||
def _delete_endpoint_group_association_by_endpoint_group(
|
def _delete_endpoint_group_association_by_endpoint_group(
|
||||||
self, session, endpoint_group_id):
|
self, session, endpoint_group_id):
|
||||||
@ -546,8 +528,7 @@ class Catalog(catalog.CatalogDriverV8):
|
|||||||
query.delete()
|
query.delete()
|
||||||
|
|
||||||
def delete_endpoint_group_association_by_project(self, project_id):
|
def delete_endpoint_group_association_by_project(self, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
query = session.query(ProjectEndpointGroupMembership)
|
query = session.query(ProjectEndpointGroupMembership)
|
||||||
query = query.filter_by(project_id=project_id)
|
query = query.filter_by(project_id=project_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
@ -18,14 +18,14 @@ Before using this module, call initialize(). This has to be done before
|
|||||||
CONF() because it sets up configuration options.
|
CONF() because it sets up configuration options.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import contextlib
|
|
||||||
import functools
|
import functools
|
||||||
|
import threading
|
||||||
|
|
||||||
from oslo_config import cfg
|
from oslo_config import cfg
|
||||||
from oslo_db import exception as db_exception
|
from oslo_db import exception as db_exception
|
||||||
from oslo_db import options as db_options
|
from oslo_db import options as db_options
|
||||||
|
from oslo_db.sqlalchemy import enginefacade
|
||||||
from oslo_db.sqlalchemy import models
|
from oslo_db.sqlalchemy import models
|
||||||
from oslo_db.sqlalchemy import session as db_session
|
|
||||||
from oslo_log import log
|
from oslo_log import log
|
||||||
from oslo_serialization import jsonutils
|
from oslo_serialization import jsonutils
|
||||||
import six
|
import six
|
||||||
@ -166,38 +166,41 @@ class ModelDictMixin(object):
|
|||||||
return {name: getattr(self, name) for name in names}
|
return {name: getattr(self, name) for name in names}
|
||||||
|
|
||||||
|
|
||||||
_engine_facade = None
|
_main_context_manager = None
|
||||||
|
|
||||||
|
|
||||||
def _get_engine_facade():
|
def _get_main_context_manager():
|
||||||
global _engine_facade
|
global _main_context_manager
|
||||||
|
|
||||||
if not _engine_facade:
|
if not _main_context_manager:
|
||||||
_engine_facade = db_session.EngineFacade.from_config(CONF)
|
_main_context_manager = enginefacade.transaction_context()
|
||||||
|
|
||||||
return _engine_facade
|
return _main_context_manager
|
||||||
|
|
||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
global _engine_facade
|
global _main_context_manager
|
||||||
|
|
||||||
_engine_facade = None
|
_main_context_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_engine():
|
def get_engine():
|
||||||
return _get_engine_facade().get_engine()
|
return _get_main_context_manager().get_legacy_facade().get_engine()
|
||||||
|
|
||||||
|
|
||||||
def get_session(expire_on_commit=False):
|
def get_session():
|
||||||
return _get_engine_facade().get_session(expire_on_commit=expire_on_commit)
|
return _get_main_context_manager().get_legacy_facade().get_session()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
_CONTEXT = threading.local()
|
||||||
def transaction(expire_on_commit=False):
|
|
||||||
"""Return a SQLAlchemy session in a scoped transaction."""
|
|
||||||
session = get_session(expire_on_commit=expire_on_commit)
|
def session_for_read():
|
||||||
with session.begin():
|
return _get_main_context_manager().reader.using(_CONTEXT)
|
||||||
yield session
|
|
||||||
|
|
||||||
|
def session_for_write():
|
||||||
|
return _get_main_context_manager().writer.using(_CONTEXT)
|
||||||
|
|
||||||
|
|
||||||
def truncated(f):
|
def truncated(f):
|
||||||
|
@ -178,7 +178,7 @@ def _sync_extension_repo(extension, version):
|
|||||||
try:
|
try:
|
||||||
abs_path = find_migrate_repo(package)
|
abs_path = find_migrate_repo(package)
|
||||||
try:
|
try:
|
||||||
migration.db_version_control(sql.get_engine(), abs_path)
|
migration.db_version_control(engine, abs_path)
|
||||||
# Register the repo with the version control API
|
# Register the repo with the version control API
|
||||||
# If it already knows about the repo, it will throw
|
# If it already knows about the repo, it will throw
|
||||||
# an exception that we can safely ignore
|
# an exception that we can safely ignore
|
||||||
|
@ -36,28 +36,27 @@ class Credential(credential.CredentialDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='credential')
|
@sql.handle_conflicts(conflict_type='credential')
|
||||||
def create_credential(self, credential_id, credential):
|
def create_credential(self, credential_id, credential):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = CredentialModel.from_dict(credential)
|
ref = CredentialModel.from_dict(credential)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_credentials(self, hints):
|
def list_credentials(self, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
credentials = session.query(CredentialModel)
|
credentials = session.query(CredentialModel)
|
||||||
credentials = sql.filter_limit_query(CredentialModel,
|
credentials = sql.filter_limit_query(CredentialModel,
|
||||||
credentials, hints)
|
credentials, hints)
|
||||||
return [s.to_dict() for s in credentials]
|
return [s.to_dict() for s in credentials]
|
||||||
|
|
||||||
def list_credentials_for_user(self, user_id, type=None):
|
def list_credentials_for_user(self, user_id, type=None):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(CredentialModel)
|
query = session.query(CredentialModel)
|
||||||
query = query.filter_by(user_id=user_id)
|
query = query.filter_by(user_id=user_id)
|
||||||
if type:
|
if type:
|
||||||
query = query.filter_by(type=type)
|
query = query.filter_by(type=type)
|
||||||
refs = query.all()
|
refs = query.all()
|
||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
|
|
||||||
def _get_credential(self, session, credential_id):
|
def _get_credential(self, session, credential_id):
|
||||||
ref = session.query(CredentialModel).get(credential_id)
|
ref = session.query(CredentialModel).get(credential_id)
|
||||||
@ -66,13 +65,12 @@ class Credential(credential.CredentialDriverV8):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_credential(self, credential_id):
|
def get_credential(self, credential_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
return self._get_credential(session, credential_id).to_dict()
|
return self._get_credential(session, credential_id).to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='credential')
|
@sql.handle_conflicts(conflict_type='credential')
|
||||||
def update_credential(self, credential_id, credential):
|
def update_credential(self, credential_id, credential):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = self._get_credential(session, credential_id)
|
ref = self._get_credential(session, credential_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
for k in credential:
|
for k in credential:
|
||||||
@ -82,27 +80,21 @@ class Credential(credential.CredentialDriverV8):
|
|||||||
if attr != 'id':
|
if attr != 'id':
|
||||||
setattr(ref, attr, getattr(new_credential, attr))
|
setattr(ref, attr, getattr(new_credential, attr))
|
||||||
ref.extra = new_credential.extra
|
ref.extra = new_credential.extra
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_credential(self, credential_id):
|
def delete_credential(self, credential_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_credential(session, credential_id)
|
ref = self._get_credential(session, credential_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
def delete_credentials_for_project(self, project_id):
|
def delete_credentials_for_project(self, project_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
query = session.query(CredentialModel)
|
query = session.query(CredentialModel)
|
||||||
query = query.filter_by(project_id=project_id)
|
query = query.filter_by(project_id=project_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
|
||||||
def delete_credentials_for_user(self, user_id):
|
def delete_credentials_for_user(self, user_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
query = session.query(CredentialModel)
|
query = session.query(CredentialModel)
|
||||||
query = query.filter_by(user_id=user_id)
|
query = query.filter_by(user_id=user_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
@ -51,7 +51,7 @@ class EndpointPolicy(object):
|
|||||||
|
|
||||||
def create_policy_association(self, policy_id, endpoint_id=None,
|
def create_policy_association(self, policy_id, endpoint_id=None,
|
||||||
service_id=None, region_id=None):
|
service_id=None, region_id=None):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
try:
|
try:
|
||||||
# See if there is already a row for this association, and if
|
# See if there is already a row for this association, and if
|
||||||
# so, update it with the new policy_id
|
# so, update it with the new policy_id
|
||||||
@ -79,14 +79,14 @@ class EndpointPolicy(object):
|
|||||||
|
|
||||||
# NOTE(henry-nash): Getting a single value to save object
|
# NOTE(henry-nash): Getting a single value to save object
|
||||||
# management overhead.
|
# management overhead.
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
if session.query(PolicyAssociation.id).filter(
|
if session.query(PolicyAssociation.id).filter(
|
||||||
sql_constraints).distinct().count() == 0:
|
sql_constraints).distinct().count() == 0:
|
||||||
raise exception.PolicyAssociationNotFound()
|
raise exception.PolicyAssociationNotFound()
|
||||||
|
|
||||||
def delete_policy_association(self, policy_id, endpoint_id=None,
|
def delete_policy_association(self, policy_id, endpoint_id=None,
|
||||||
service_id=None, region_id=None):
|
service_id=None, region_id=None):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(PolicyAssociation)
|
query = session.query(PolicyAssociation)
|
||||||
query = query.filter_by(policy_id=policy_id)
|
query = query.filter_by(policy_id=policy_id)
|
||||||
query = query.filter_by(endpoint_id=endpoint_id)
|
query = query.filter_by(endpoint_id=endpoint_id)
|
||||||
@ -102,7 +102,7 @@ class EndpointPolicy(object):
|
|||||||
PolicyAssociation.region_id == region_id)
|
PolicyAssociation.region_id == region_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
policy_id = session.query(PolicyAssociation.policy_id).filter(
|
policy_id = session.query(PolicyAssociation.policy_id).filter(
|
||||||
sql_constraints).distinct().one()
|
sql_constraints).distinct().one()
|
||||||
return {'policy_id': policy_id}
|
return {'policy_id': policy_id}
|
||||||
@ -110,31 +110,31 @@ class EndpointPolicy(object):
|
|||||||
raise exception.PolicyAssociationNotFound()
|
raise exception.PolicyAssociationNotFound()
|
||||||
|
|
||||||
def list_associations_for_policy(self, policy_id):
|
def list_associations_for_policy(self, policy_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(PolicyAssociation)
|
query = session.query(PolicyAssociation)
|
||||||
query = query.filter_by(policy_id=policy_id)
|
query = query.filter_by(policy_id=policy_id)
|
||||||
return [ref.to_dict() for ref in query.all()]
|
return [ref.to_dict() for ref in query.all()]
|
||||||
|
|
||||||
def delete_association_by_endpoint(self, endpoint_id):
|
def delete_association_by_endpoint(self, endpoint_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(PolicyAssociation)
|
query = session.query(PolicyAssociation)
|
||||||
query = query.filter_by(endpoint_id=endpoint_id)
|
query = query.filter_by(endpoint_id=endpoint_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
|
||||||
def delete_association_by_service(self, service_id):
|
def delete_association_by_service(self, service_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(PolicyAssociation)
|
query = session.query(PolicyAssociation)
|
||||||
query = query.filter_by(service_id=service_id)
|
query = query.filter_by(service_id=service_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
|
||||||
def delete_association_by_region(self, region_id):
|
def delete_association_by_region(self, region_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(PolicyAssociation)
|
query = session.query(PolicyAssociation)
|
||||||
query = query.filter_by(region_id=region_id)
|
query = query.filter_by(region_id=region_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
|
||||||
def delete_association_by_policy(self, policy_id):
|
def delete_association_by_policy(self, policy_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(PolicyAssociation)
|
query = session.query(PolicyAssociation)
|
||||||
query = query.filter_by(policy_id=policy_id)
|
query = query.filter_by(policy_id=policy_id)
|
||||||
query.delete()
|
query.delete()
|
||||||
|
@ -161,13 +161,13 @@ class Federation(core.FederationDriverV8):
|
|||||||
@sql.handle_conflicts(conflict_type='identity_provider')
|
@sql.handle_conflicts(conflict_type='identity_provider')
|
||||||
def create_idp(self, idp_id, idp):
|
def create_idp(self, idp_id, idp):
|
||||||
idp['id'] = idp_id
|
idp['id'] = idp_id
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
idp_ref = IdentityProviderModel.from_dict(idp)
|
idp_ref = IdentityProviderModel.from_dict(idp)
|
||||||
session.add(idp_ref)
|
session.add(idp_ref)
|
||||||
return idp_ref.to_dict()
|
return idp_ref.to_dict()
|
||||||
|
|
||||||
def delete_idp(self, idp_id):
|
def delete_idp(self, idp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
self._delete_assigned_protocols(session, idp_id)
|
self._delete_assigned_protocols(session, idp_id)
|
||||||
idp_ref = self._get_idp(session, idp_id)
|
idp_ref = self._get_idp(session, idp_id)
|
||||||
session.delete(idp_ref)
|
session.delete(idp_ref)
|
||||||
@ -187,30 +187,30 @@ class Federation(core.FederationDriverV8):
|
|||||||
raise exception.IdentityProviderNotFound(idp_id=remote_id)
|
raise exception.IdentityProviderNotFound(idp_id=remote_id)
|
||||||
|
|
||||||
def list_idps(self):
|
def list_idps(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
idps = session.query(IdentityProviderModel)
|
idps = session.query(IdentityProviderModel)
|
||||||
idps_list = [idp.to_dict() for idp in idps]
|
idps_list = [idp.to_dict() for idp in idps]
|
||||||
return idps_list
|
return idps_list
|
||||||
|
|
||||||
def get_idp(self, idp_id):
|
def get_idp(self, idp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
idp_ref = self._get_idp(session, idp_id)
|
idp_ref = self._get_idp(session, idp_id)
|
||||||
return idp_ref.to_dict()
|
return idp_ref.to_dict()
|
||||||
|
|
||||||
def get_idp_from_remote_id(self, remote_id):
|
def get_idp_from_remote_id(self, remote_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
ref = self._get_idp_from_remote_id(session, remote_id)
|
ref = self._get_idp_from_remote_id(session, remote_id)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def update_idp(self, idp_id, idp):
|
def update_idp(self, idp_id, idp):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
idp_ref = self._get_idp(session, idp_id)
|
idp_ref = self._get_idp(session, idp_id)
|
||||||
old_idp = idp_ref.to_dict()
|
old_idp = idp_ref.to_dict()
|
||||||
old_idp.update(idp)
|
old_idp.update(idp)
|
||||||
new_idp = IdentityProviderModel.from_dict(old_idp)
|
new_idp = IdentityProviderModel.from_dict(old_idp)
|
||||||
for attr in IdentityProviderModel.mutable_attributes:
|
for attr in IdentityProviderModel.mutable_attributes:
|
||||||
setattr(idp_ref, attr, getattr(new_idp, attr))
|
setattr(idp_ref, attr, getattr(new_idp, attr))
|
||||||
return idp_ref.to_dict()
|
return idp_ref.to_dict()
|
||||||
|
|
||||||
# Protocol CRUD
|
# Protocol CRUD
|
||||||
def _get_protocol(self, session, idp_id, protocol_id):
|
def _get_protocol(self, session, idp_id, protocol_id):
|
||||||
@ -227,36 +227,36 @@ class Federation(core.FederationDriverV8):
|
|||||||
def create_protocol(self, idp_id, protocol_id, protocol):
|
def create_protocol(self, idp_id, protocol_id, protocol):
|
||||||
protocol['id'] = protocol_id
|
protocol['id'] = protocol_id
|
||||||
protocol['idp_id'] = idp_id
|
protocol['idp_id'] = idp_id
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
self._get_idp(session, idp_id)
|
self._get_idp(session, idp_id)
|
||||||
protocol_ref = FederationProtocolModel.from_dict(protocol)
|
protocol_ref = FederationProtocolModel.from_dict(protocol)
|
||||||
session.add(protocol_ref)
|
session.add(protocol_ref)
|
||||||
return protocol_ref.to_dict()
|
return protocol_ref.to_dict()
|
||||||
|
|
||||||
def update_protocol(self, idp_id, protocol_id, protocol):
|
def update_protocol(self, idp_id, protocol_id, protocol):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
proto_ref = self._get_protocol(session, idp_id, protocol_id)
|
proto_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
old_proto = proto_ref.to_dict()
|
old_proto = proto_ref.to_dict()
|
||||||
old_proto.update(protocol)
|
old_proto.update(protocol)
|
||||||
new_proto = FederationProtocolModel.from_dict(old_proto)
|
new_proto = FederationProtocolModel.from_dict(old_proto)
|
||||||
for attr in FederationProtocolModel.mutable_attributes:
|
for attr in FederationProtocolModel.mutable_attributes:
|
||||||
setattr(proto_ref, attr, getattr(new_proto, attr))
|
setattr(proto_ref, attr, getattr(new_proto, attr))
|
||||||
return proto_ref.to_dict()
|
return proto_ref.to_dict()
|
||||||
|
|
||||||
def get_protocol(self, idp_id, protocol_id):
|
def get_protocol(self, idp_id, protocol_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
return protocol_ref.to_dict()
|
return protocol_ref.to_dict()
|
||||||
|
|
||||||
def list_protocols(self, idp_id):
|
def list_protocols(self, idp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
q = session.query(FederationProtocolModel)
|
q = session.query(FederationProtocolModel)
|
||||||
q = q.filter_by(idp_id=idp_id)
|
q = q.filter_by(idp_id=idp_id)
|
||||||
protocols = [protocol.to_dict() for protocol in q]
|
protocols = [protocol.to_dict() for protocol in q]
|
||||||
return protocols
|
return protocols
|
||||||
|
|
||||||
def delete_protocol(self, idp_id, protocol_id):
|
def delete_protocol(self, idp_id, protocol_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
key_ref = self._get_protocol(session, idp_id, protocol_id)
|
key_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
session.delete(key_ref)
|
session.delete(key_ref)
|
||||||
|
|
||||||
@ -277,58 +277,58 @@ class Federation(core.FederationDriverV8):
|
|||||||
ref = {}
|
ref = {}
|
||||||
ref['id'] = mapping_id
|
ref['id'] = mapping_id
|
||||||
ref['rules'] = mapping.get('rules')
|
ref['rules'] = mapping.get('rules')
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
mapping_ref = MappingModel.from_dict(ref)
|
mapping_ref = MappingModel.from_dict(ref)
|
||||||
session.add(mapping_ref)
|
session.add(mapping_ref)
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
def delete_mapping(self, mapping_id):
|
def delete_mapping(self, mapping_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
session.delete(mapping_ref)
|
session.delete(mapping_ref)
|
||||||
|
|
||||||
def list_mappings(self):
|
def list_mappings(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
mappings = session.query(MappingModel)
|
mappings = session.query(MappingModel)
|
||||||
return [x.to_dict() for x in mappings]
|
return [x.to_dict() for x in mappings]
|
||||||
|
|
||||||
def get_mapping(self, mapping_id):
|
def get_mapping(self, mapping_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='mapping')
|
@sql.handle_conflicts(conflict_type='mapping')
|
||||||
def update_mapping(self, mapping_id, mapping):
|
def update_mapping(self, mapping_id, mapping):
|
||||||
ref = {}
|
ref = {}
|
||||||
ref['id'] = mapping_id
|
ref['id'] = mapping_id
|
||||||
ref['rules'] = mapping.get('rules')
|
ref['rules'] = mapping.get('rules')
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
old_mapping = mapping_ref.to_dict()
|
old_mapping = mapping_ref.to_dict()
|
||||||
old_mapping.update(ref)
|
old_mapping.update(ref)
|
||||||
new_mapping = MappingModel.from_dict(old_mapping)
|
new_mapping = MappingModel.from_dict(old_mapping)
|
||||||
for attr in MappingModel.attributes:
|
for attr in MappingModel.attributes:
|
||||||
setattr(mapping_ref, attr, getattr(new_mapping, attr))
|
setattr(mapping_ref, attr, getattr(new_mapping, attr))
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
|
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
mapping_id = protocol_ref.mapping_id
|
mapping_id = protocol_ref.mapping_id
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
# Service Provider CRUD
|
# Service Provider CRUD
|
||||||
@sql.handle_conflicts(conflict_type='service_provider')
|
@sql.handle_conflicts(conflict_type='service_provider')
|
||||||
def create_sp(self, sp_id, sp):
|
def create_sp(self, sp_id, sp):
|
||||||
sp['id'] = sp_id
|
sp['id'] = sp_id
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
sp_ref = ServiceProviderModel.from_dict(sp)
|
sp_ref = ServiceProviderModel.from_dict(sp)
|
||||||
session.add(sp_ref)
|
session.add(sp_ref)
|
||||||
return sp_ref.to_dict()
|
return sp_ref.to_dict()
|
||||||
|
|
||||||
def delete_sp(self, sp_id):
|
def delete_sp(self, sp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
sp_ref = self._get_sp(session, sp_id)
|
sp_ref = self._get_sp(session, sp_id)
|
||||||
session.delete(sp_ref)
|
session.delete(sp_ref)
|
||||||
|
|
||||||
@ -339,28 +339,28 @@ class Federation(core.FederationDriverV8):
|
|||||||
return sp_ref
|
return sp_ref
|
||||||
|
|
||||||
def list_sps(self):
|
def list_sps(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
sps = session.query(ServiceProviderModel)
|
sps = session.query(ServiceProviderModel)
|
||||||
sps_list = [sp.to_dict() for sp in sps]
|
sps_list = [sp.to_dict() for sp in sps]
|
||||||
return sps_list
|
return sps_list
|
||||||
|
|
||||||
def get_sp(self, sp_id):
|
def get_sp(self, sp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
sp_ref = self._get_sp(session, sp_id)
|
sp_ref = self._get_sp(session, sp_id)
|
||||||
return sp_ref.to_dict()
|
return sp_ref.to_dict()
|
||||||
|
|
||||||
def update_sp(self, sp_id, sp):
|
def update_sp(self, sp_id, sp):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
sp_ref = self._get_sp(session, sp_id)
|
sp_ref = self._get_sp(session, sp_id)
|
||||||
old_sp = sp_ref.to_dict()
|
old_sp = sp_ref.to_dict()
|
||||||
old_sp.update(sp)
|
old_sp.update(sp)
|
||||||
new_sp = ServiceProviderModel.from_dict(old_sp)
|
new_sp = ServiceProviderModel.from_dict(old_sp)
|
||||||
for attr in ServiceProviderModel.mutable_attributes:
|
for attr in ServiceProviderModel.mutable_attributes:
|
||||||
setattr(sp_ref, attr, getattr(new_sp, attr))
|
setattr(sp_ref, attr, getattr(new_sp, attr))
|
||||||
return sp_ref.to_dict()
|
return sp_ref.to_dict()
|
||||||
|
|
||||||
def get_enabled_service_providers(self):
|
def get_enabled_service_providers(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
service_providers = session.query(ServiceProviderModel)
|
service_providers = session.query(ServiceProviderModel)
|
||||||
service_providers = service_providers.filter_by(enabled=True)
|
service_providers = service_providers.filter_by(enabled=True)
|
||||||
return service_providers
|
return service_providers
|
||||||
|
@ -169,10 +169,10 @@ class Federation(core.FederationDriverV9):
|
|||||||
def create_idp(self, idp_id, idp):
|
def create_idp(self, idp_id, idp):
|
||||||
idp['id'] = idp_id
|
idp['id'] = idp_id
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
idp_ref = IdentityProviderModel.from_dict(idp)
|
idp_ref = IdentityProviderModel.from_dict(idp)
|
||||||
session.add(idp_ref)
|
session.add(idp_ref)
|
||||||
return idp_ref.to_dict()
|
return idp_ref.to_dict()
|
||||||
except sql.DBDuplicateEntry as e:
|
except sql.DBDuplicateEntry as e:
|
||||||
conflict_type = 'identity_provider'
|
conflict_type = 'identity_provider'
|
||||||
details = six.text_type(e)
|
details = six.text_type(e)
|
||||||
@ -186,7 +186,7 @@ class Federation(core.FederationDriverV9):
|
|||||||
raise exception.Conflict(type=conflict_type, details=msg)
|
raise exception.Conflict(type=conflict_type, details=msg)
|
||||||
|
|
||||||
def delete_idp(self, idp_id):
|
def delete_idp(self, idp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
self._delete_assigned_protocols(session, idp_id)
|
self._delete_assigned_protocols(session, idp_id)
|
||||||
idp_ref = self._get_idp(session, idp_id)
|
idp_ref = self._get_idp(session, idp_id)
|
||||||
session.delete(idp_ref)
|
session.delete(idp_ref)
|
||||||
@ -206,31 +206,31 @@ class Federation(core.FederationDriverV9):
|
|||||||
raise exception.IdentityProviderNotFound(idp_id=remote_id)
|
raise exception.IdentityProviderNotFound(idp_id=remote_id)
|
||||||
|
|
||||||
def list_idps(self, hints=None):
|
def list_idps(self, hints=None):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(IdentityProviderModel)
|
query = session.query(IdentityProviderModel)
|
||||||
idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
|
idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
|
||||||
idps_list = [idp.to_dict() for idp in idps]
|
idps_list = [idp.to_dict() for idp in idps]
|
||||||
return idps_list
|
return idps_list
|
||||||
|
|
||||||
def get_idp(self, idp_id):
|
def get_idp(self, idp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
idp_ref = self._get_idp(session, idp_id)
|
idp_ref = self._get_idp(session, idp_id)
|
||||||
return idp_ref.to_dict()
|
return idp_ref.to_dict()
|
||||||
|
|
||||||
def get_idp_from_remote_id(self, remote_id):
|
def get_idp_from_remote_id(self, remote_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
ref = self._get_idp_from_remote_id(session, remote_id)
|
ref = self._get_idp_from_remote_id(session, remote_id)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def update_idp(self, idp_id, idp):
|
def update_idp(self, idp_id, idp):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
idp_ref = self._get_idp(session, idp_id)
|
idp_ref = self._get_idp(session, idp_id)
|
||||||
old_idp = idp_ref.to_dict()
|
old_idp = idp_ref.to_dict()
|
||||||
old_idp.update(idp)
|
old_idp.update(idp)
|
||||||
new_idp = IdentityProviderModel.from_dict(old_idp)
|
new_idp = IdentityProviderModel.from_dict(old_idp)
|
||||||
for attr in IdentityProviderModel.mutable_attributes:
|
for attr in IdentityProviderModel.mutable_attributes:
|
||||||
setattr(idp_ref, attr, getattr(new_idp, attr))
|
setattr(idp_ref, attr, getattr(new_idp, attr))
|
||||||
return idp_ref.to_dict()
|
return idp_ref.to_dict()
|
||||||
|
|
||||||
# Protocol CRUD
|
# Protocol CRUD
|
||||||
def _get_protocol(self, session, idp_id, protocol_id):
|
def _get_protocol(self, session, idp_id, protocol_id):
|
||||||
@ -247,36 +247,36 @@ class Federation(core.FederationDriverV9):
|
|||||||
def create_protocol(self, idp_id, protocol_id, protocol):
|
def create_protocol(self, idp_id, protocol_id, protocol):
|
||||||
protocol['id'] = protocol_id
|
protocol['id'] = protocol_id
|
||||||
protocol['idp_id'] = idp_id
|
protocol['idp_id'] = idp_id
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
self._get_idp(session, idp_id)
|
self._get_idp(session, idp_id)
|
||||||
protocol_ref = FederationProtocolModel.from_dict(protocol)
|
protocol_ref = FederationProtocolModel.from_dict(protocol)
|
||||||
session.add(protocol_ref)
|
session.add(protocol_ref)
|
||||||
return protocol_ref.to_dict()
|
return protocol_ref.to_dict()
|
||||||
|
|
||||||
def update_protocol(self, idp_id, protocol_id, protocol):
|
def update_protocol(self, idp_id, protocol_id, protocol):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
proto_ref = self._get_protocol(session, idp_id, protocol_id)
|
proto_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
old_proto = proto_ref.to_dict()
|
old_proto = proto_ref.to_dict()
|
||||||
old_proto.update(protocol)
|
old_proto.update(protocol)
|
||||||
new_proto = FederationProtocolModel.from_dict(old_proto)
|
new_proto = FederationProtocolModel.from_dict(old_proto)
|
||||||
for attr in FederationProtocolModel.mutable_attributes:
|
for attr in FederationProtocolModel.mutable_attributes:
|
||||||
setattr(proto_ref, attr, getattr(new_proto, attr))
|
setattr(proto_ref, attr, getattr(new_proto, attr))
|
||||||
return proto_ref.to_dict()
|
return proto_ref.to_dict()
|
||||||
|
|
||||||
def get_protocol(self, idp_id, protocol_id):
|
def get_protocol(self, idp_id, protocol_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
return protocol_ref.to_dict()
|
return protocol_ref.to_dict()
|
||||||
|
|
||||||
def list_protocols(self, idp_id):
|
def list_protocols(self, idp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
q = session.query(FederationProtocolModel)
|
q = session.query(FederationProtocolModel)
|
||||||
q = q.filter_by(idp_id=idp_id)
|
q = q.filter_by(idp_id=idp_id)
|
||||||
protocols = [protocol.to_dict() for protocol in q]
|
protocols = [protocol.to_dict() for protocol in q]
|
||||||
return protocols
|
return protocols
|
||||||
|
|
||||||
def delete_protocol(self, idp_id, protocol_id):
|
def delete_protocol(self, idp_id, protocol_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
key_ref = self._get_protocol(session, idp_id, protocol_id)
|
key_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
session.delete(key_ref)
|
session.delete(key_ref)
|
||||||
|
|
||||||
@ -297,58 +297,58 @@ class Federation(core.FederationDriverV9):
|
|||||||
ref = {}
|
ref = {}
|
||||||
ref['id'] = mapping_id
|
ref['id'] = mapping_id
|
||||||
ref['rules'] = mapping.get('rules')
|
ref['rules'] = mapping.get('rules')
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
mapping_ref = MappingModel.from_dict(ref)
|
mapping_ref = MappingModel.from_dict(ref)
|
||||||
session.add(mapping_ref)
|
session.add(mapping_ref)
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
def delete_mapping(self, mapping_id):
|
def delete_mapping(self, mapping_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
session.delete(mapping_ref)
|
session.delete(mapping_ref)
|
||||||
|
|
||||||
def list_mappings(self):
|
def list_mappings(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
mappings = session.query(MappingModel)
|
mappings = session.query(MappingModel)
|
||||||
return [x.to_dict() for x in mappings]
|
return [x.to_dict() for x in mappings]
|
||||||
|
|
||||||
def get_mapping(self, mapping_id):
|
def get_mapping(self, mapping_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='mapping')
|
@sql.handle_conflicts(conflict_type='mapping')
|
||||||
def update_mapping(self, mapping_id, mapping):
|
def update_mapping(self, mapping_id, mapping):
|
||||||
ref = {}
|
ref = {}
|
||||||
ref['id'] = mapping_id
|
ref['id'] = mapping_id
|
||||||
ref['rules'] = mapping.get('rules')
|
ref['rules'] = mapping.get('rules')
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
old_mapping = mapping_ref.to_dict()
|
old_mapping = mapping_ref.to_dict()
|
||||||
old_mapping.update(ref)
|
old_mapping.update(ref)
|
||||||
new_mapping = MappingModel.from_dict(old_mapping)
|
new_mapping = MappingModel.from_dict(old_mapping)
|
||||||
for attr in MappingModel.attributes:
|
for attr in MappingModel.attributes:
|
||||||
setattr(mapping_ref, attr, getattr(new_mapping, attr))
|
setattr(mapping_ref, attr, getattr(new_mapping, attr))
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
|
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||||
mapping_id = protocol_ref.mapping_id
|
mapping_id = protocol_ref.mapping_id
|
||||||
mapping_ref = self._get_mapping(session, mapping_id)
|
mapping_ref = self._get_mapping(session, mapping_id)
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
# Service Provider CRUD
|
# Service Provider CRUD
|
||||||
@sql.handle_conflicts(conflict_type='service_provider')
|
@sql.handle_conflicts(conflict_type='service_provider')
|
||||||
def create_sp(self, sp_id, sp):
|
def create_sp(self, sp_id, sp):
|
||||||
sp['id'] = sp_id
|
sp['id'] = sp_id
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
sp_ref = ServiceProviderModel.from_dict(sp)
|
sp_ref = ServiceProviderModel.from_dict(sp)
|
||||||
session.add(sp_ref)
|
session.add(sp_ref)
|
||||||
return sp_ref.to_dict()
|
return sp_ref.to_dict()
|
||||||
|
|
||||||
def delete_sp(self, sp_id):
|
def delete_sp(self, sp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
sp_ref = self._get_sp(session, sp_id)
|
sp_ref = self._get_sp(session, sp_id)
|
||||||
session.delete(sp_ref)
|
session.delete(sp_ref)
|
||||||
|
|
||||||
@ -359,28 +359,28 @@ class Federation(core.FederationDriverV9):
|
|||||||
return sp_ref
|
return sp_ref
|
||||||
|
|
||||||
def list_sps(self):
|
def list_sps(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
sps = session.query(ServiceProviderModel)
|
sps = session.query(ServiceProviderModel)
|
||||||
sps_list = [sp.to_dict() for sp in sps]
|
sps_list = [sp.to_dict() for sp in sps]
|
||||||
return sps_list
|
return sps_list
|
||||||
|
|
||||||
def get_sp(self, sp_id):
|
def get_sp(self, sp_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
sp_ref = self._get_sp(session, sp_id)
|
sp_ref = self._get_sp(session, sp_id)
|
||||||
return sp_ref.to_dict()
|
return sp_ref.to_dict()
|
||||||
|
|
||||||
def update_sp(self, sp_id, sp):
|
def update_sp(self, sp_id, sp):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
sp_ref = self._get_sp(session, sp_id)
|
sp_ref = self._get_sp(session, sp_id)
|
||||||
old_sp = sp_ref.to_dict()
|
old_sp = sp_ref.to_dict()
|
||||||
old_sp.update(sp)
|
old_sp.update(sp)
|
||||||
new_sp = ServiceProviderModel.from_dict(old_sp)
|
new_sp = ServiceProviderModel.from_dict(old_sp)
|
||||||
for attr in ServiceProviderModel.mutable_attributes:
|
for attr in ServiceProviderModel.mutable_attributes:
|
||||||
setattr(sp_ref, attr, getattr(new_sp, attr))
|
setattr(sp_ref, attr, getattr(new_sp, attr))
|
||||||
return sp_ref.to_dict()
|
return sp_ref.to_dict()
|
||||||
|
|
||||||
def get_enabled_service_providers(self):
|
def get_enabled_service_providers(self):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
service_providers = session.query(ServiceProviderModel)
|
service_providers = session.query(ServiceProviderModel)
|
||||||
service_providers = service_providers.filter_by(enabled=True)
|
service_providers = service_providers.filter_by(enabled=True)
|
||||||
return service_providers
|
return service_providers
|
||||||
|
@ -178,33 +178,32 @@ class Identity(identity.IdentityDriverV8):
|
|||||||
|
|
||||||
# Identity interface
|
# Identity interface
|
||||||
def authenticate(self, user_id, password):
|
def authenticate(self, user_id, password):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
user_ref = None
|
user_ref = None
|
||||||
try:
|
try:
|
||||||
user_ref = self._get_user(session, user_id)
|
user_ref = self._get_user(session, user_id)
|
||||||
except exception.UserNotFound:
|
except exception.UserNotFound:
|
||||||
raise AssertionError(_('Invalid user / password'))
|
raise AssertionError(_('Invalid user / password'))
|
||||||
if not self._check_password(password, user_ref):
|
if not self._check_password(password, user_ref):
|
||||||
raise AssertionError(_('Invalid user / password'))
|
raise AssertionError(_('Invalid user / password'))
|
||||||
return identity.filter_user(user_ref.to_dict())
|
return identity.filter_user(user_ref.to_dict())
|
||||||
|
|
||||||
# user crud
|
# user crud
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='user')
|
@sql.handle_conflicts(conflict_type='user')
|
||||||
def create_user(self, user_id, user):
|
def create_user(self, user_id, user):
|
||||||
user = utils.hash_user_password(user)
|
user = utils.hash_user_password(user)
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
user_ref = User.from_dict(user)
|
user_ref = User.from_dict(user)
|
||||||
session.add(user_ref)
|
session.add(user_ref)
|
||||||
return identity.filter_user(user_ref.to_dict())
|
return identity.filter_user(user_ref.to_dict())
|
||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_users(self, hints):
|
def list_users(self, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(User).outerjoin(LocalUser)
|
query = session.query(User).outerjoin(LocalUser)
|
||||||
user_refs = sql.filter_limit_query(User, query, hints)
|
user_refs = sql.filter_limit_query(User, query, hints)
|
||||||
return [identity.filter_user(x.to_dict()) for x in user_refs]
|
return [identity.filter_user(x.to_dict()) for x in user_refs]
|
||||||
|
|
||||||
def _get_user(self, session, user_id):
|
def _get_user(self, session, user_id):
|
||||||
user_ref = session.query(User).get(user_id)
|
user_ref = session.query(User).get(user_id)
|
||||||
@ -213,25 +212,24 @@ class Identity(identity.IdentityDriverV8):
|
|||||||
return user_ref
|
return user_ref
|
||||||
|
|
||||||
def get_user(self, user_id):
|
def get_user(self, user_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
return identity.filter_user(self._get_user(session, user_id).to_dict())
|
return identity.filter_user(
|
||||||
|
self._get_user(session, user_id).to_dict())
|
||||||
|
|
||||||
def get_user_by_name(self, user_name, domain_id):
|
def get_user_by_name(self, user_name, domain_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(User).join(LocalUser)
|
query = session.query(User).join(LocalUser)
|
||||||
query = query.filter(and_(LocalUser.name == user_name,
|
query = query.filter(and_(LocalUser.name == user_name,
|
||||||
LocalUser.domain_id == domain_id))
|
LocalUser.domain_id == domain_id))
|
||||||
try:
|
try:
|
||||||
user_ref = query.one()
|
user_ref = query.one()
|
||||||
except sql.NotFound:
|
except sql.NotFound:
|
||||||
raise exception.UserNotFound(user_id=user_name)
|
raise exception.UserNotFound(user_id=user_name)
|
||||||
return identity.filter_user(user_ref.to_dict())
|
return identity.filter_user(user_ref.to_dict())
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='user')
|
@sql.handle_conflicts(conflict_type='user')
|
||||||
def update_user(self, user_id, user):
|
def update_user(self, user_id, user):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
user_ref = self._get_user(session, user_id)
|
user_ref = self._get_user(session, user_id)
|
||||||
old_user_dict = user_ref.to_dict()
|
old_user_dict = user_ref.to_dict()
|
||||||
user = utils.hash_user_password(user)
|
user = utils.hash_user_password(user)
|
||||||
@ -242,77 +240,74 @@ class Identity(identity.IdentityDriverV8):
|
|||||||
if attr != 'id':
|
if attr != 'id':
|
||||||
setattr(user_ref, attr, getattr(new_user, attr))
|
setattr(user_ref, attr, getattr(new_user, attr))
|
||||||
user_ref.extra = new_user.extra
|
user_ref.extra = new_user.extra
|
||||||
return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
|
return identity.filter_user(
|
||||||
|
user_ref.to_dict(include_extra_dict=True))
|
||||||
|
|
||||||
def add_user_to_group(self, user_id, group_id):
|
def add_user_to_group(self, user_id, group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
self.get_group(group_id)
|
self.get_group(group_id)
|
||||||
self.get_user(user_id)
|
self.get_user(user_id)
|
||||||
query = session.query(UserGroupMembership)
|
query = session.query(UserGroupMembership)
|
||||||
query = query.filter_by(user_id=user_id)
|
query = query.filter_by(user_id=user_id)
|
||||||
query = query.filter_by(group_id=group_id)
|
query = query.filter_by(group_id=group_id)
|
||||||
rv = query.first()
|
rv = query.first()
|
||||||
if rv:
|
if rv:
|
||||||
return
|
return
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
session.add(UserGroupMembership(user_id=user_id,
|
session.add(UserGroupMembership(user_id=user_id,
|
||||||
group_id=group_id))
|
group_id=group_id))
|
||||||
|
|
||||||
def check_user_in_group(self, user_id, group_id):
|
def check_user_in_group(self, user_id, group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
self.get_group(group_id)
|
|
||||||
self.get_user(user_id)
|
|
||||||
query = session.query(UserGroupMembership)
|
|
||||||
query = query.filter_by(user_id=user_id)
|
|
||||||
query = query.filter_by(group_id=group_id)
|
|
||||||
if not query.first():
|
|
||||||
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
|
||||||
" group '%(group_id)s'") %
|
|
||||||
{'user_id': user_id,
|
|
||||||
'group_id': group_id})
|
|
||||||
|
|
||||||
def remove_user_from_group(self, user_id, group_id):
|
|
||||||
session = sql.get_session()
|
|
||||||
# We don't check if user or group are still valid and let the remove
|
|
||||||
# be tried anyway - in case this is some kind of clean-up operation
|
|
||||||
query = session.query(UserGroupMembership)
|
|
||||||
query = query.filter_by(user_id=user_id)
|
|
||||||
query = query.filter_by(group_id=group_id)
|
|
||||||
membership_ref = query.first()
|
|
||||||
if membership_ref is None:
|
|
||||||
# Check if the group and user exist to return descriptive
|
|
||||||
# exceptions.
|
|
||||||
self.get_group(group_id)
|
self.get_group(group_id)
|
||||||
self.get_user(user_id)
|
self.get_user(user_id)
|
||||||
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
query = session.query(UserGroupMembership)
|
||||||
" group '%(group_id)s'") %
|
query = query.filter_by(user_id=user_id)
|
||||||
{'user_id': user_id,
|
query = query.filter_by(group_id=group_id)
|
||||||
'group_id': group_id})
|
if not query.first():
|
||||||
with session.begin():
|
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
||||||
|
" group '%(group_id)s'") %
|
||||||
|
{'user_id': user_id,
|
||||||
|
'group_id': group_id})
|
||||||
|
|
||||||
|
def remove_user_from_group(self, user_id, group_id):
|
||||||
|
# We don't check if user or group are still valid and let the remove
|
||||||
|
# be tried anyway - in case this is some kind of clean-up operation
|
||||||
|
with sql.session_for_write() as session:
|
||||||
|
query = session.query(UserGroupMembership)
|
||||||
|
query = query.filter_by(user_id=user_id)
|
||||||
|
query = query.filter_by(group_id=group_id)
|
||||||
|
membership_ref = query.first()
|
||||||
|
if membership_ref is None:
|
||||||
|
# Check if the group and user exist to return descriptive
|
||||||
|
# exceptions.
|
||||||
|
self.get_group(group_id)
|
||||||
|
self.get_user(user_id)
|
||||||
|
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
||||||
|
" group '%(group_id)s'") %
|
||||||
|
{'user_id': user_id,
|
||||||
|
'group_id': group_id})
|
||||||
session.delete(membership_ref)
|
session.delete(membership_ref)
|
||||||
|
|
||||||
def list_groups_for_user(self, user_id, hints):
|
def list_groups_for_user(self, user_id, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
self.get_user(user_id)
|
self.get_user(user_id)
|
||||||
query = session.query(Group).join(UserGroupMembership)
|
query = session.query(Group).join(UserGroupMembership)
|
||||||
query = query.filter(UserGroupMembership.user_id == user_id)
|
query = query.filter(UserGroupMembership.user_id == user_id)
|
||||||
query = sql.filter_limit_query(Group, query, hints)
|
query = sql.filter_limit_query(Group, query, hints)
|
||||||
return [g.to_dict() for g in query]
|
return [g.to_dict() for g in query]
|
||||||
|
|
||||||
def list_users_in_group(self, group_id, hints):
|
def list_users_in_group(self, group_id, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
self.get_group(group_id)
|
self.get_group(group_id)
|
||||||
query = session.query(User).outerjoin(LocalUser)
|
query = session.query(User).outerjoin(LocalUser)
|
||||||
query = query.join(UserGroupMembership)
|
query = query.join(UserGroupMembership)
|
||||||
query = query.filter(UserGroupMembership.group_id == group_id)
|
query = query.filter(UserGroupMembership.group_id == group_id)
|
||||||
query = sql.filter_limit_query(User, query, hints)
|
query = sql.filter_limit_query(User, query, hints)
|
||||||
return [identity.filter_user(u.to_dict()) for u in query]
|
return [identity.filter_user(u.to_dict()) for u in query]
|
||||||
|
|
||||||
def delete_user(self, user_id):
|
def delete_user(self, user_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_user(session, user_id)
|
ref = self._get_user(session, user_id)
|
||||||
|
|
||||||
q = session.query(UserGroupMembership)
|
q = session.query(UserGroupMembership)
|
||||||
@ -325,18 +320,17 @@ class Identity(identity.IdentityDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='group')
|
@sql.handle_conflicts(conflict_type='group')
|
||||||
def create_group(self, group_id, group):
|
def create_group(self, group_id, group):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
ref = Group.from_dict(group)
|
ref = Group.from_dict(group)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_groups(self, hints):
|
def list_groups(self, hints):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Group)
|
query = session.query(Group)
|
||||||
refs = sql.filter_limit_query(Group, query, hints)
|
refs = sql.filter_limit_query(Group, query, hints)
|
||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
|
|
||||||
def _get_group(self, session, group_id):
|
def _get_group(self, session, group_id):
|
||||||
ref = session.query(Group).get(group_id)
|
ref = session.query(Group).get(group_id)
|
||||||
@ -345,25 +339,23 @@ class Identity(identity.IdentityDriverV8):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_group(self, group_id):
|
def get_group(self, group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
return self._get_group(session, group_id).to_dict()
|
return self._get_group(session, group_id).to_dict()
|
||||||
|
|
||||||
def get_group_by_name(self, group_name, domain_id):
|
def get_group_by_name(self, group_name, domain_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Group)
|
query = session.query(Group)
|
||||||
query = query.filter_by(name=group_name)
|
query = query.filter_by(name=group_name)
|
||||||
query = query.filter_by(domain_id=domain_id)
|
query = query.filter_by(domain_id=domain_id)
|
||||||
try:
|
try:
|
||||||
group_ref = query.one()
|
group_ref = query.one()
|
||||||
except sql.NotFound:
|
except sql.NotFound:
|
||||||
raise exception.GroupNotFound(group_id=group_name)
|
raise exception.GroupNotFound(group_id=group_name)
|
||||||
return group_ref.to_dict()
|
return group_ref.to_dict()
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='group')
|
@sql.handle_conflicts(conflict_type='group')
|
||||||
def update_group(self, group_id, group):
|
def update_group(self, group_id, group):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_group(session, group_id)
|
ref = self._get_group(session, group_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
for k in group:
|
for k in group:
|
||||||
@ -373,12 +365,10 @@ class Identity(identity.IdentityDriverV8):
|
|||||||
if attr != 'id':
|
if attr != 'id':
|
||||||
setattr(ref, attr, getattr(new_group, attr))
|
setattr(ref, attr, getattr(new_group, attr))
|
||||||
ref.extra = new_group.extra
|
ref.extra = new_group.extra
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_group(self, group_id):
|
def delete_group(self, group_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_group(session, group_id)
|
ref = self._get_group(session, group_id)
|
||||||
|
|
||||||
q = session.query(UserGroupMembership)
|
q = session.query(UserGroupMembership)
|
||||||
|
@ -45,27 +45,27 @@ class Mapping(identity.MappingDriverV8):
|
|||||||
# work if we hashed all the entries, even those that already generate
|
# work if we hashed all the entries, even those that already generate
|
||||||
# UUIDs, like SQL. Further, this would only work if the generation
|
# UUIDs, like SQL. Further, this would only work if the generation
|
||||||
# algorithm was immutable (e.g. it had always been sha256).
|
# algorithm was immutable (e.g. it had always been sha256).
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(IDMapping.public_id)
|
query = session.query(IDMapping.public_id)
|
||||||
query = query.filter_by(domain_id=local_entity['domain_id'])
|
query = query.filter_by(domain_id=local_entity['domain_id'])
|
||||||
query = query.filter_by(local_id=local_entity['local_id'])
|
query = query.filter_by(local_id=local_entity['local_id'])
|
||||||
query = query.filter_by(entity_type=local_entity['entity_type'])
|
query = query.filter_by(entity_type=local_entity['entity_type'])
|
||||||
try:
|
try:
|
||||||
public_ref = query.one()
|
public_ref = query.one()
|
||||||
public_id = public_ref.public_id
|
public_id = public_ref.public_id
|
||||||
return public_id
|
return public_id
|
||||||
except sql.NotFound:
|
except sql.NotFound:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_id_mapping(self, public_id):
|
def get_id_mapping(self, public_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
mapping_ref = session.query(IDMapping).get(public_id)
|
mapping_ref = session.query(IDMapping).get(public_id)
|
||||||
if mapping_ref:
|
if mapping_ref:
|
||||||
return mapping_ref.to_dict()
|
return mapping_ref.to_dict()
|
||||||
|
|
||||||
def create_id_mapping(self, local_entity, public_id=None):
|
def create_id_mapping(self, local_entity, public_id=None):
|
||||||
entity = local_entity.copy()
|
entity = local_entity.copy()
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
if public_id is None:
|
if public_id is None:
|
||||||
public_id = self.id_generator_api.generate_public_ID(entity)
|
public_id = self.id_generator_api.generate_public_ID(entity)
|
||||||
entity['public_id'] = public_id
|
entity['public_id'] = public_id
|
||||||
@ -74,7 +74,7 @@ class Mapping(identity.MappingDriverV8):
|
|||||||
return public_id
|
return public_id
|
||||||
|
|
||||||
def delete_id_mapping(self, public_id):
|
def delete_id_mapping(self, public_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
try:
|
try:
|
||||||
session.query(IDMapping).filter(
|
session.query(IDMapping).filter(
|
||||||
IDMapping.public_id == public_id).delete()
|
IDMapping.public_id == public_id).delete()
|
||||||
@ -84,14 +84,15 @@ class Mapping(identity.MappingDriverV8):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def purge_mappings(self, purge_filter):
|
def purge_mappings(self, purge_filter):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
query = session.query(IDMapping)
|
query = session.query(IDMapping)
|
||||||
if 'domain_id' in purge_filter:
|
if 'domain_id' in purge_filter:
|
||||||
query = query.filter_by(domain_id=purge_filter['domain_id'])
|
query = query.filter_by(domain_id=purge_filter['domain_id'])
|
||||||
if 'public_id' in purge_filter:
|
if 'public_id' in purge_filter:
|
||||||
query = query.filter_by(public_id=purge_filter['public_id'])
|
query = query.filter_by(public_id=purge_filter['public_id'])
|
||||||
if 'local_id' in purge_filter:
|
if 'local_id' in purge_filter:
|
||||||
query = query.filter_by(local_id=purge_filter['local_id'])
|
query = query.filter_by(local_id=purge_filter['local_id'])
|
||||||
if 'entity_type' in purge_filter:
|
if 'entity_type' in purge_filter:
|
||||||
query = query.filter_by(entity_type=purge_filter['entity_type'])
|
query = query.filter_by(
|
||||||
query.delete()
|
entity_type=purge_filter['entity_type'])
|
||||||
|
query.delete()
|
||||||
|
@ -92,17 +92,16 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
return consumer_ref
|
return consumer_ref
|
||||||
|
|
||||||
def get_consumer_with_secret(self, consumer_id):
|
def get_consumer_with_secret(self, consumer_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
consumer_ref = self._get_consumer(session, consumer_id)
|
consumer_ref = self._get_consumer(session, consumer_id)
|
||||||
return consumer_ref.to_dict()
|
return consumer_ref.to_dict()
|
||||||
|
|
||||||
def get_consumer(self, consumer_id):
|
def get_consumer(self, consumer_id):
|
||||||
return core.filter_consumer(
|
return core.filter_consumer(
|
||||||
self.get_consumer_with_secret(consumer_id))
|
self.get_consumer_with_secret(consumer_id))
|
||||||
|
|
||||||
def create_consumer(self, consumer_ref):
|
def create_consumer(self, consumer_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
consumer = Consumer.from_dict(consumer_ref)
|
consumer = Consumer.from_dict(consumer_ref)
|
||||||
session.add(consumer)
|
session.add(consumer)
|
||||||
return consumer.to_dict()
|
return consumer.to_dict()
|
||||||
@ -128,20 +127,18 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
session.delete(token_ref)
|
session.delete(token_ref)
|
||||||
|
|
||||||
def delete_consumer(self, consumer_id):
|
def delete_consumer(self, consumer_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
self._delete_request_tokens(session, consumer_id)
|
self._delete_request_tokens(session, consumer_id)
|
||||||
self._delete_access_tokens(session, consumer_id)
|
self._delete_access_tokens(session, consumer_id)
|
||||||
self._delete_consumer(session, consumer_id)
|
self._delete_consumer(session, consumer_id)
|
||||||
|
|
||||||
def list_consumers(self):
|
def list_consumers(self):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
cons = session.query(Consumer)
|
cons = session.query(Consumer)
|
||||||
return [core.filter_consumer(x.to_dict()) for x in cons]
|
return [core.filter_consumer(x.to_dict()) for x in cons]
|
||||||
|
|
||||||
def update_consumer(self, consumer_id, consumer_ref):
|
def update_consumer(self, consumer_id, consumer_ref):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
consumer = self._get_consumer(session, consumer_id)
|
consumer = self._get_consumer(session, consumer_id)
|
||||||
old_consumer_dict = consumer.to_dict()
|
old_consumer_dict = consumer.to_dict()
|
||||||
old_consumer_dict.update(consumer_ref)
|
old_consumer_dict.update(consumer_ref)
|
||||||
@ -169,11 +166,10 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
ref['role_ids'] = None
|
ref['role_ids'] = None
|
||||||
ref['consumer_id'] = consumer_id
|
ref['consumer_id'] = consumer_id
|
||||||
ref['expires_at'] = expiry_date
|
ref['expires_at'] = expiry_date
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
token_ref = RequestToken.from_dict(ref)
|
token_ref = RequestToken.from_dict(ref)
|
||||||
session.add(token_ref)
|
session.add(token_ref)
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def _get_request_token(self, session, request_token_id):
|
def _get_request_token(self, session, request_token_id):
|
||||||
token_ref = session.query(RequestToken).get(request_token_id)
|
token_ref = session.query(RequestToken).get(request_token_id)
|
||||||
@ -182,14 +178,13 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
return token_ref
|
return token_ref
|
||||||
|
|
||||||
def get_request_token(self, request_token_id):
|
def get_request_token(self, request_token_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
token_ref = self._get_request_token(session, request_token_id)
|
token_ref = self._get_request_token(session, request_token_id)
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def authorize_request_token(self, request_token_id, user_id,
|
def authorize_request_token(self, request_token_id, user_id,
|
||||||
role_ids):
|
role_ids):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
token_ref = self._get_request_token(session, request_token_id)
|
token_ref = self._get_request_token(session, request_token_id)
|
||||||
token_dict = token_ref.to_dict()
|
token_dict = token_ref.to_dict()
|
||||||
token_dict['authorizing_user_id'] = user_id
|
token_dict['authorizing_user_id'] = user_id
|
||||||
@ -203,13 +198,12 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
or attr == 'role_ids'):
|
or attr == 'role_ids'):
|
||||||
setattr(token_ref, attr, getattr(new_token, attr))
|
setattr(token_ref, attr, getattr(new_token, attr))
|
||||||
|
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def create_access_token(self, request_id, access_token_duration):
|
def create_access_token(self, request_id, access_token_duration):
|
||||||
access_token_id = uuid.uuid4().hex
|
access_token_id = uuid.uuid4().hex
|
||||||
access_token_secret = uuid.uuid4().hex
|
access_token_secret = uuid.uuid4().hex
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
req_token_ref = self._get_request_token(session, request_id)
|
req_token_ref = self._get_request_token(session, request_id)
|
||||||
token_dict = req_token_ref.to_dict()
|
token_dict = req_token_ref.to_dict()
|
||||||
|
|
||||||
@ -235,7 +229,7 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
# remove request token, it's been used
|
# remove request token, it's been used
|
||||||
session.delete(req_token_ref)
|
session.delete(req_token_ref)
|
||||||
|
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def _get_access_token(self, session, access_token_id):
|
def _get_access_token(self, session, access_token_id):
|
||||||
token_ref = session.query(AccessToken).get(access_token_id)
|
token_ref = session.query(AccessToken).get(access_token_id)
|
||||||
@ -244,19 +238,18 @@ class OAuth1(core.Oauth1DriverV8):
|
|||||||
return token_ref
|
return token_ref
|
||||||
|
|
||||||
def get_access_token(self, access_token_id):
|
def get_access_token(self, access_token_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
token_ref = self._get_access_token(session, access_token_id)
|
token_ref = self._get_access_token(session, access_token_id)
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def list_access_tokens(self, user_id):
|
def list_access_tokens(self, user_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
q = session.query(AccessToken)
|
q = session.query(AccessToken)
|
||||||
user_auths = q.filter_by(authorizing_user_id=user_id)
|
user_auths = q.filter_by(authorizing_user_id=user_id)
|
||||||
return [core.filter_token(x.to_dict()) for x in user_auths]
|
return [core.filter_token(x.to_dict()) for x in user_auths]
|
||||||
|
|
||||||
def delete_access_token(self, user_id, access_token_id):
|
def delete_access_token(self, user_id, access_token_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
token_ref = self._get_access_token(session, access_token_id)
|
token_ref = self._get_access_token(session, access_token_id)
|
||||||
token_dict = token_ref.to_dict()
|
token_dict = token_ref.to_dict()
|
||||||
if token_dict['authorizing_user_id'] != user_id:
|
if token_dict['authorizing_user_id'] != user_id:
|
||||||
|
@ -30,19 +30,16 @@ class Policy(rules.Policy):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='policy')
|
@sql.handle_conflicts(conflict_type='policy')
|
||||||
def create_policy(self, policy_id, policy):
|
def create_policy(self, policy_id, policy):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = PolicyModel.from_dict(policy)
|
ref = PolicyModel.from_dict(policy)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
|
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def list_policies(self):
|
def list_policies(self):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
|
refs = session.query(PolicyModel).all()
|
||||||
refs = session.query(PolicyModel).all()
|
return [ref.to_dict() for ref in refs]
|
||||||
return [ref.to_dict() for ref in refs]
|
|
||||||
|
|
||||||
def _get_policy(self, session, policy_id):
|
def _get_policy(self, session, policy_id):
|
||||||
"""Private method to get a policy model object (NOT a dictionary)."""
|
"""Private method to get a policy model object (NOT a dictionary)."""
|
||||||
@ -52,15 +49,12 @@ class Policy(rules.Policy):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_policy(self, policy_id):
|
def get_policy(self, policy_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
|
return self._get_policy(session, policy_id).to_dict()
|
||||||
return self._get_policy(session, policy_id).to_dict()
|
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='policy')
|
@sql.handle_conflicts(conflict_type='policy')
|
||||||
def update_policy(self, policy_id, policy):
|
def update_policy(self, policy_id, policy):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_policy(session, policy_id)
|
ref = self._get_policy(session, policy_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
old_dict.update(policy)
|
old_dict.update(policy)
|
||||||
@ -72,8 +66,6 @@ class Policy(rules.Policy):
|
|||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_policy(self, policy_id):
|
def delete_policy(self, policy_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
|
|
||||||
with session.begin():
|
|
||||||
ref = self._get_policy(session, policy_id)
|
ref = self._get_policy(session, policy_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
@ -35,11 +35,11 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
return project_ref
|
return project_ref
|
||||||
|
|
||||||
def get_project(self, tenant_id):
|
def get_project(self, tenant_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
return self._get_project(session, tenant_id).to_dict()
|
return self._get_project(session, tenant_id).to_dict()
|
||||||
|
|
||||||
def get_project_by_name(self, tenant_name, domain_id):
|
def get_project_by_name(self, tenant_name, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
query = query.filter_by(name=tenant_name)
|
query = query.filter_by(name=tenant_name)
|
||||||
query = query.filter_by(domain_id=domain_id)
|
query = query.filter_by(domain_id=domain_id)
|
||||||
@ -51,7 +51,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_projects(self, hints):
|
def list_projects(self, hints):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
project_refs = sql.filter_limit_query(Project, query, hints)
|
project_refs = sql.filter_limit_query(Project, query, hints)
|
||||||
return [project_ref.to_dict() for project_ref in project_refs]
|
return [project_ref.to_dict() for project_ref in project_refs]
|
||||||
@ -60,7 +60,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
query = query.filter(Project.id.in_(ids))
|
query = query.filter(Project.id.in_(ids))
|
||||||
return [project_ref.to_dict() for project_ref in query.all()]
|
return [project_ref.to_dict() for project_ref in query.all()]
|
||||||
@ -69,14 +69,14 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
if not domain_ids:
|
if not domain_ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project.id)
|
query = session.query(Project.id)
|
||||||
query = (
|
query = (
|
||||||
query.filter(Project.domain_id.in_(domain_ids)))
|
query.filter(Project.domain_id.in_(domain_ids)))
|
||||||
return [x.id for x in query.all()]
|
return [x.id for x in query.all()]
|
||||||
|
|
||||||
def list_projects_in_domain(self, domain_id):
|
def list_projects_in_domain(self, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
self._get_domain(session, domain_id)
|
self._get_domain(session, domain_id)
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
project_refs = query.filter_by(domain_id=domain_id)
|
project_refs = query.filter_by(domain_id=domain_id)
|
||||||
@ -89,7 +89,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
return [project_ref.to_dict() for project_ref in project_refs]
|
return [project_ref.to_dict() for project_ref in project_refs]
|
||||||
|
|
||||||
def list_projects_in_subtree(self, project_id):
|
def list_projects_in_subtree(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
children = self._get_children(session, [project_id])
|
children = self._get_children(session, [project_id])
|
||||||
subtree = []
|
subtree = []
|
||||||
examined = set([project_id])
|
examined = set([project_id])
|
||||||
@ -110,7 +110,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
return subtree
|
return subtree
|
||||||
|
|
||||||
def list_project_parents(self, project_id):
|
def list_project_parents(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
project = self._get_project(session, project_id).to_dict()
|
project = self._get_project(session, project_id).to_dict()
|
||||||
parents = []
|
parents = []
|
||||||
examined = set()
|
examined = set()
|
||||||
@ -130,7 +130,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
return parents
|
return parents
|
||||||
|
|
||||||
def is_leaf_project(self, project_id):
|
def is_leaf_project(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
project_refs = self._get_children(session, [project_id])
|
project_refs = self._get_children(session, [project_id])
|
||||||
return not project_refs
|
return not project_refs
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
@sql.handle_conflicts(conflict_type='project')
|
@sql.handle_conflicts(conflict_type='project')
|
||||||
def create_project(self, tenant_id, tenant):
|
def create_project(self, tenant_id, tenant):
|
||||||
tenant['name'] = clean.project_name(tenant['name'])
|
tenant['name'] = clean.project_name(tenant['name'])
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
tenant_ref = Project.from_dict(tenant)
|
tenant_ref = Project.from_dict(tenant)
|
||||||
session.add(tenant_ref)
|
session.add(tenant_ref)
|
||||||
return tenant_ref.to_dict()
|
return tenant_ref.to_dict()
|
||||||
@ -148,7 +148,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
if 'name' in tenant:
|
if 'name' in tenant:
|
||||||
tenant['name'] = clean.project_name(tenant['name'])
|
tenant['name'] = clean.project_name(tenant['name'])
|
||||||
|
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
tenant_ref = self._get_project(session, tenant_id)
|
tenant_ref = self._get_project(session, tenant_id)
|
||||||
old_project_dict = tenant_ref.to_dict()
|
old_project_dict = tenant_ref.to_dict()
|
||||||
for k in tenant:
|
for k in tenant:
|
||||||
@ -162,7 +162,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='project')
|
@sql.handle_conflicts(conflict_type='project')
|
||||||
def delete_project(self, tenant_id):
|
def delete_project(self, tenant_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
tenant_ref = self._get_project(session, tenant_id)
|
tenant_ref = self._get_project(session, tenant_id)
|
||||||
session.delete(tenant_ref)
|
session.delete(tenant_ref)
|
||||||
|
|
||||||
@ -170,14 +170,14 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='domain')
|
@sql.handle_conflicts(conflict_type='domain')
|
||||||
def create_domain(self, domain_id, domain):
|
def create_domain(self, domain_id, domain):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = Domain.from_dict(domain)
|
ref = Domain.from_dict(domain)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_domains(self, hints):
|
def list_domains(self, hints):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Domain)
|
query = session.query(Domain)
|
||||||
refs = sql.filter_limit_query(Domain, query, hints)
|
refs = sql.filter_limit_query(Domain, query, hints)
|
||||||
return [ref.to_dict() for ref in refs]
|
return [ref.to_dict() for ref in refs]
|
||||||
@ -186,7 +186,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Domain)
|
query = session.query(Domain)
|
||||||
query = query.filter(Domain.id.in_(ids))
|
query = query.filter(Domain.id.in_(ids))
|
||||||
domain_refs = query.all()
|
domain_refs = query.all()
|
||||||
@ -199,11 +199,11 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_domain(self, domain_id):
|
def get_domain(self, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
return self._get_domain(session, domain_id).to_dict()
|
return self._get_domain(session, domain_id).to_dict()
|
||||||
|
|
||||||
def get_domain_by_name(self, domain_name):
|
def get_domain_by_name(self, domain_name):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
try:
|
try:
|
||||||
ref = (session.query(Domain).
|
ref = (session.query(Domain).
|
||||||
filter_by(name=domain_name).one())
|
filter_by(name=domain_name).one())
|
||||||
@ -213,7 +213,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='domain')
|
@sql.handle_conflicts(conflict_type='domain')
|
||||||
def update_domain(self, domain_id, domain):
|
def update_domain(self, domain_id, domain):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_domain(session, domain_id)
|
ref = self._get_domain(session, domain_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
for k in domain:
|
for k in domain:
|
||||||
@ -226,7 +226,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
|||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_domain(self, domain_id):
|
def delete_domain(self, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_domain(session, domain_id)
|
ref = self._get_domain(session, domain_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
|
@ -38,11 +38,11 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
return project_ref
|
return project_ref
|
||||||
|
|
||||||
def get_project(self, project_id):
|
def get_project(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
return self._get_project(session, project_id).to_dict()
|
return self._get_project(session, project_id).to_dict()
|
||||||
|
|
||||||
def get_project_by_name(self, project_name, domain_id):
|
def get_project_by_name(self, project_name, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
query = query.filter_by(name=project_name)
|
query = query.filter_by(name=project_name)
|
||||||
if domain_id is None:
|
if domain_id is None:
|
||||||
@ -70,7 +70,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
for f in hints.filters:
|
for f in hints.filters:
|
||||||
if (f['name'] == 'domain_id' and f['value'] is None):
|
if (f['name'] == 'domain_id' and f['value'] is None):
|
||||||
f['value'] = keystone_resource.NULL_DOMAIN_ID
|
f['value'] = keystone_resource.NULL_DOMAIN_ID
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
project_refs = sql.filter_limit_query(Project, query, hints)
|
project_refs = sql.filter_limit_query(Project, query, hints)
|
||||||
return [project_ref.to_dict() for project_ref in project_refs
|
return [project_ref.to_dict() for project_ref in project_refs
|
||||||
@ -80,7 +80,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
query = query.filter(Project.id.in_(ids))
|
query = query.filter(Project.id.in_(ids))
|
||||||
return [project_ref.to_dict() for project_ref in query.all()
|
return [project_ref.to_dict() for project_ref in query.all()
|
||||||
@ -90,7 +90,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
if not domain_ids:
|
if not domain_ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Project.id)
|
query = session.query(Project.id)
|
||||||
query = (
|
query = (
|
||||||
query.filter(Project.domain_id.in_(domain_ids)))
|
query.filter(Project.domain_id.in_(domain_ids)))
|
||||||
@ -98,7 +98,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
if not self._is_hidden_ref(x)]
|
if not self._is_hidden_ref(x)]
|
||||||
|
|
||||||
def list_projects_in_domain(self, domain_id):
|
def list_projects_in_domain(self, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
self._get_domain(session, domain_id)
|
self._get_domain(session, domain_id)
|
||||||
query = session.query(Project)
|
query = session.query(Project)
|
||||||
project_refs = query.filter_by(domain_id=domain_id)
|
project_refs = query.filter_by(domain_id=domain_id)
|
||||||
@ -111,7 +111,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
return [project_ref.to_dict() for project_ref in project_refs]
|
return [project_ref.to_dict() for project_ref in project_refs]
|
||||||
|
|
||||||
def list_projects_in_subtree(self, project_id):
|
def list_projects_in_subtree(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
children = self._get_children(session, [project_id])
|
children = self._get_children(session, [project_id])
|
||||||
subtree = []
|
subtree = []
|
||||||
examined = set([project_id])
|
examined = set([project_id])
|
||||||
@ -132,7 +132,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
return subtree
|
return subtree
|
||||||
|
|
||||||
def list_project_parents(self, project_id):
|
def list_project_parents(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
project = self._get_project(session, project_id).to_dict()
|
project = self._get_project(session, project_id).to_dict()
|
||||||
parents = []
|
parents = []
|
||||||
examined = set()
|
examined = set()
|
||||||
@ -152,7 +152,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
return parents
|
return parents
|
||||||
|
|
||||||
def is_leaf_project(self, project_id):
|
def is_leaf_project(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
project_refs = self._get_children(session, [project_id])
|
project_refs = self._get_children(session, [project_id])
|
||||||
return not project_refs
|
return not project_refs
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
def create_project(self, project_id, project):
|
def create_project(self, project_id, project):
|
||||||
project['name'] = clean.project_name(project['name'])
|
project['name'] = clean.project_name(project['name'])
|
||||||
new_project = self._encode_domain_id(project)
|
new_project = self._encode_domain_id(project)
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
project_ref = Project.from_dict(new_project)
|
project_ref = Project.from_dict(new_project)
|
||||||
session.add(project_ref)
|
session.add(project_ref)
|
||||||
return project_ref.to_dict()
|
return project_ref.to_dict()
|
||||||
@ -172,7 +172,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
project['name'] = clean.project_name(project['name'])
|
project['name'] = clean.project_name(project['name'])
|
||||||
|
|
||||||
update_project = self._encode_domain_id(project)
|
update_project = self._encode_domain_id(project)
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
project_ref = self._get_project(session, project_id)
|
project_ref = self._get_project(session, project_id)
|
||||||
old_project_dict = project_ref.to_dict()
|
old_project_dict = project_ref.to_dict()
|
||||||
for k in update_project:
|
for k in update_project:
|
||||||
@ -189,7 +189,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='project')
|
@sql.handle_conflicts(conflict_type='project')
|
||||||
def delete_project(self, project_id):
|
def delete_project(self, project_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
project_ref = self._get_project(session, project_id)
|
project_ref = self._get_project(session, project_id)
|
||||||
session.delete(project_ref)
|
session.delete(project_ref)
|
||||||
|
|
||||||
@ -197,7 +197,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
def delete_projects_from_ids(self, project_ids):
|
def delete_projects_from_ids(self, project_ids):
|
||||||
if not project_ids:
|
if not project_ids:
|
||||||
return
|
return
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(Project).filter(Project.id.in_(
|
query = session.query(Project).filter(Project.id.in_(
|
||||||
project_ids))
|
project_ids))
|
||||||
project_ids_from_bd = [p['id'] for p in query.all()]
|
project_ids_from_bd = [p['id'] for p in query.all()]
|
||||||
@ -212,14 +212,14 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='domain')
|
@sql.handle_conflicts(conflict_type='domain')
|
||||||
def create_domain(self, domain_id, domain):
|
def create_domain(self, domain_id, domain):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = Domain.from_dict(domain)
|
ref = Domain.from_dict(domain)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
@driver_hints.truncated
|
@driver_hints.truncated
|
||||||
def list_domains(self, hints):
|
def list_domains(self, hints):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Domain)
|
query = session.query(Domain)
|
||||||
refs = sql.filter_limit_query(Domain, query, hints)
|
refs = sql.filter_limit_query(Domain, query, hints)
|
||||||
return [ref.to_dict() for ref in refs
|
return [ref.to_dict() for ref in refs
|
||||||
@ -229,7 +229,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
query = session.query(Domain)
|
query = session.query(Domain)
|
||||||
query = query.filter(Domain.id.in_(ids))
|
query = query.filter(Domain.id.in_(ids))
|
||||||
domain_refs = query.all()
|
domain_refs = query.all()
|
||||||
@ -243,11 +243,11 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_domain(self, domain_id):
|
def get_domain(self, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
return self._get_domain(session, domain_id).to_dict()
|
return self._get_domain(session, domain_id).to_dict()
|
||||||
|
|
||||||
def get_domain_by_name(self, domain_name):
|
def get_domain_by_name(self, domain_name):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
try:
|
try:
|
||||||
ref = (session.query(Domain).
|
ref = (session.query(Domain).
|
||||||
filter_by(name=domain_name).one())
|
filter_by(name=domain_name).one())
|
||||||
@ -260,7 +260,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='domain')
|
@sql.handle_conflicts(conflict_type='domain')
|
||||||
def update_domain(self, domain_id, domain):
|
def update_domain(self, domain_id, domain):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_domain(session, domain_id)
|
ref = self._get_domain(session, domain_id)
|
||||||
old_dict = ref.to_dict()
|
old_dict = ref.to_dict()
|
||||||
for k in domain:
|
for k in domain:
|
||||||
@ -273,7 +273,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
|||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_domain(self, domain_id):
|
def delete_domain(self, domain_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_domain(session, domain_id)
|
ref = self._get_domain(session, domain_id)
|
||||||
session.delete(ref)
|
session.delete(ref)
|
||||||
|
|
||||||
|
@ -59,12 +59,12 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
|||||||
@sql.handle_conflicts(conflict_type='domain_config')
|
@sql.handle_conflicts(conflict_type='domain_config')
|
||||||
def create_config_option(self, domain_id, group, option, value,
|
def create_config_option(self, domain_id, group, option, value,
|
||||||
sensitive=False):
|
sensitive=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
config_table = self.choose_table(sensitive)
|
config_table = self.choose_table(sensitive)
|
||||||
ref = config_table(domain_id=domain_id, group=group,
|
ref = config_table(domain_id=domain_id, group=group,
|
||||||
option=option, value=value)
|
option=option, value=value)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def _get_config_option(self, session, domain_id, group, option, sensitive):
|
def _get_config_option(self, session, domain_id, group, option, sensitive):
|
||||||
try:
|
try:
|
||||||
@ -80,14 +80,14 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
|||||||
return ref
|
return ref
|
||||||
|
|
||||||
def get_config_option(self, domain_id, group, option, sensitive=False):
|
def get_config_option(self, domain_id, group, option, sensitive=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
ref = self._get_config_option(session, domain_id, group, option,
|
ref = self._get_config_option(session, domain_id, group, option,
|
||||||
sensitive)
|
sensitive)
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def list_config_options(self, domain_id, group=None, option=None,
|
def list_config_options(self, domain_id, group=None, option=None,
|
||||||
sensitive=False):
|
sensitive=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
config_table = self.choose_table(sensitive)
|
config_table = self.choose_table(sensitive)
|
||||||
query = session.query(config_table)
|
query = session.query(config_table)
|
||||||
query = query.filter_by(domain_id=domain_id)
|
query = query.filter_by(domain_id=domain_id)
|
||||||
@ -99,11 +99,11 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
|||||||
|
|
||||||
def update_config_option(self, domain_id, group, option, value,
|
def update_config_option(self, domain_id, group, option, value,
|
||||||
sensitive=False):
|
sensitive=False):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = self._get_config_option(session, domain_id, group, option,
|
ref = self._get_config_option(session, domain_id, group, option,
|
||||||
sensitive)
|
sensitive)
|
||||||
ref.value = value
|
ref.value = value
|
||||||
return ref.to_dict()
|
return ref.to_dict()
|
||||||
|
|
||||||
def delete_config_options(self, domain_id, group=None, option=None,
|
def delete_config_options(self, domain_id, group=None, option=None,
|
||||||
sensitive=False):
|
sensitive=False):
|
||||||
@ -114,7 +114,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
|||||||
if there was nothing to delete.
|
if there was nothing to delete.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
config_table = self.choose_table(sensitive)
|
config_table = self.choose_table(sensitive)
|
||||||
query = session.query(config_table)
|
query = session.query(config_table)
|
||||||
query = query.filter_by(domain_id=domain_id)
|
query = query.filter_by(domain_id=domain_id)
|
||||||
@ -126,7 +126,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
|||||||
|
|
||||||
def obtain_registration(self, domain_id, type):
|
def obtain_registration(self, domain_id, type):
|
||||||
try:
|
try:
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = ConfigRegister(type=type, domain_id=domain_id)
|
ref = ConfigRegister(type=type, domain_id=domain_id)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
return True
|
return True
|
||||||
@ -136,15 +136,15 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def read_registration(self, type):
|
def read_registration(self, type):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_read() as session:
|
||||||
ref = session.query(ConfigRegister).get(type)
|
ref = session.query(ConfigRegister).get(type)
|
||||||
if not ref:
|
if not ref:
|
||||||
raise exception.ConfigRegistrationNotFound()
|
raise exception.ConfigRegistrationNotFound()
|
||||||
return ref.domain_id
|
return ref.domain_id
|
||||||
|
|
||||||
def release_registration(self, domain_id, type=None):
|
def release_registration(self, domain_id, type=None):
|
||||||
"""Silently delete anything registered for the domain specified."""
|
"""Silently delete anything registered for the domain specified."""
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
query = session.query(ConfigRegister)
|
query = session.query(ConfigRegister)
|
||||||
if type:
|
if type:
|
||||||
query = query.filter_by(type=type)
|
query = query.filter_by(type=type)
|
||||||
|
@ -60,37 +60,37 @@ class Revoke(revoke.RevokeDriverV8):
|
|||||||
def _prune_expired_events(self):
|
def _prune_expired_events(self):
|
||||||
oldest = revoke.revoked_before_cutoff_time()
|
oldest = revoke.revoked_before_cutoff_time()
|
||||||
|
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
dialect = session.bind.dialect.name
|
dialect = session.bind.dialect.name
|
||||||
batch_size = self._flush_batch_size(dialect)
|
batch_size = self._flush_batch_size(dialect)
|
||||||
if batch_size > 0:
|
if batch_size > 0:
|
||||||
query = session.query(RevocationEvent.id)
|
query = session.query(RevocationEvent.id)
|
||||||
query = query.filter(RevocationEvent.revoked_at < oldest)
|
query = query.filter(RevocationEvent.revoked_at < oldest)
|
||||||
query = query.limit(batch_size).subquery()
|
query = query.limit(batch_size).subquery()
|
||||||
delete_query = (session.query(RevocationEvent).
|
delete_query = (session.query(RevocationEvent).
|
||||||
filter(RevocationEvent.id.in_(query)))
|
filter(RevocationEvent.id.in_(query)))
|
||||||
while True:
|
while True:
|
||||||
rowcount = delete_query.delete(synchronize_session=False)
|
rowcount = delete_query.delete(synchronize_session=False)
|
||||||
if rowcount == 0:
|
if rowcount == 0:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
query = session.query(RevocationEvent)
|
query = session.query(RevocationEvent)
|
||||||
query = query.filter(RevocationEvent.revoked_at < oldest)
|
query = query.filter(RevocationEvent.revoked_at < oldest)
|
||||||
query.delete(synchronize_session=False)
|
query.delete(synchronize_session=False)
|
||||||
|
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
def list_events(self, last_fetch=None):
|
def list_events(self, last_fetch=None):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(RevocationEvent).order_by(
|
query = session.query(RevocationEvent).order_by(
|
||||||
RevocationEvent.revoked_at)
|
RevocationEvent.revoked_at)
|
||||||
|
|
||||||
if last_fetch:
|
if last_fetch:
|
||||||
query = query.filter(RevocationEvent.revoked_at > last_fetch)
|
query = query.filter(RevocationEvent.revoked_at > last_fetch)
|
||||||
|
|
||||||
events = [model.RevokeEvent(**e.to_dict()) for e in query]
|
events = [model.RevokeEvent(**e.to_dict()) for e in query]
|
||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
def revoke(self, event):
|
def revoke(self, event):
|
||||||
kwargs = dict()
|
kwargs = dict()
|
||||||
@ -98,7 +98,6 @@ class Revoke(revoke.RevokeDriverV8):
|
|||||||
kwargs[attr] = getattr(event, attr)
|
kwargs[attr] = getattr(event, attr)
|
||||||
kwargs['id'] = uuid.uuid4().hex
|
kwargs['id'] = uuid.uuid4().hex
|
||||||
record = RevocationEvent(**kwargs)
|
record = RevocationEvent(**kwargs)
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
session.add(record)
|
session.add(record)
|
||||||
self._prune_expired_events()
|
self._prune_expired_events()
|
||||||
|
@ -17,6 +17,6 @@ from keystone.identity.mapping_backends import sql as mapping_sql
|
|||||||
|
|
||||||
def list_id_mappings():
|
def list_id_mappings():
|
||||||
"""List all id_mappings for testing purposes."""
|
"""List all id_mappings for testing purposes."""
|
||||||
a_session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
refs = a_session.query(mapping_sql.IDMapping).all()
|
refs = session.query(mapping_sql.IDMapping).all()
|
||||||
return [x.to_dict() for x in refs]
|
return [x.to_dict() for x in refs]
|
||||||
|
@ -611,7 +611,7 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
|||||||
tok = token_sql.Token()
|
tok = token_sql.Token()
|
||||||
tok.list_revoked_tokens()
|
tok.list_revoked_tokens()
|
||||||
|
|
||||||
mock_query = mock_sql.get_session().query
|
mock_query = mock_sql.session_for_read().__enter__().query
|
||||||
mock_query.assert_called_with(*expected_query_args)
|
mock_query.assert_called_with(*expected_query_args)
|
||||||
|
|
||||||
def test_flush_expired_tokens_batch(self):
|
def test_flush_expired_tokens_batch(self):
|
||||||
@ -636,8 +636,12 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
|||||||
# other tests below test the differences between how they use the batch
|
# other tests below test the differences between how they use the batch
|
||||||
# strategy
|
# strategy
|
||||||
with mock.patch.object(token_sql, 'sql') as mock_sql:
|
with mock.patch.object(token_sql, 'sql') as mock_sql:
|
||||||
mock_sql.get_session().query().filter().delete.return_value = 0
|
mock_sql.session_for_write().__enter__(
|
||||||
mock_sql.get_session().bind.dialect.name = 'mysql'
|
).query().filter().delete.return_value = 0
|
||||||
|
|
||||||
|
mock_sql.session_for_write().__enter__(
|
||||||
|
).bind.dialect.name = 'mysql'
|
||||||
|
|
||||||
tok = token_sql.Token()
|
tok = token_sql.Token()
|
||||||
expiry_mock = mock.Mock()
|
expiry_mock = mock.Mock()
|
||||||
ITERS = [1, 2, 3]
|
ITERS = [1, 2, 3]
|
||||||
@ -648,7 +652,10 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
|||||||
# The expiry strategy is only invoked once, the other calls are via
|
# The expiry strategy is only invoked once, the other calls are via
|
||||||
# the yield return.
|
# the yield return.
|
||||||
self.assertEqual(1, expiry_mock.call_count)
|
self.assertEqual(1, expiry_mock.call_count)
|
||||||
mock_delete = mock_sql.get_session().query().filter().delete
|
|
||||||
|
mock_delete = mock_sql.session_for_write().__enter__(
|
||||||
|
).query().filter().delete
|
||||||
|
|
||||||
self.assertThat(mock_delete.call_args_list,
|
self.assertThat(mock_delete.call_args_list,
|
||||||
matchers.HasLength(len(ITERS)))
|
matchers.HasLength(len(ITERS)))
|
||||||
|
|
||||||
|
@ -161,7 +161,8 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase):
|
|||||||
self.repo_package())
|
self.repo_package())
|
||||||
self.schema = versioning_api.ControlledSchema.create(
|
self.schema = versioning_api.ControlledSchema.create(
|
||||||
self.engine,
|
self.engine,
|
||||||
self.repo_path, self.initial_db_version)
|
self.repo_path,
|
||||||
|
self.initial_db_version)
|
||||||
|
|
||||||
# auto-detect the highest available schema version in the migrate_repo
|
# auto-detect the highest available schema version in the migrate_repo
|
||||||
self.max_version = self.schema.repository.version().version
|
self.max_version = self.schema.repository.version().version
|
||||||
|
@ -86,11 +86,11 @@ class Token(token.persistence.TokenDriverV8):
|
|||||||
def get_token(self, token_id):
|
def get_token(self, token_id):
|
||||||
if token_id is None:
|
if token_id is None:
|
||||||
raise exception.TokenNotFound(token_id=token_id)
|
raise exception.TokenNotFound(token_id=token_id)
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
token_ref = session.query(TokenModel).get(token_id)
|
token_ref = session.query(TokenModel).get(token_id)
|
||||||
if not token_ref or not token_ref.valid:
|
if not token_ref or not token_ref.valid:
|
||||||
raise exception.TokenNotFound(token_id=token_id)
|
raise exception.TokenNotFound(token_id=token_id)
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def create_token(self, token_id, data):
|
def create_token(self, token_id, data):
|
||||||
data_copy = copy.deepcopy(data)
|
data_copy = copy.deepcopy(data)
|
||||||
@ -101,14 +101,12 @@ class Token(token.persistence.TokenDriverV8):
|
|||||||
|
|
||||||
token_ref = TokenModel.from_dict(data_copy)
|
token_ref = TokenModel.from_dict(data_copy)
|
||||||
token_ref.valid = True
|
token_ref.valid = True
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
session.add(token_ref)
|
session.add(token_ref)
|
||||||
return token_ref.to_dict()
|
return token_ref.to_dict()
|
||||||
|
|
||||||
def delete_token(self, token_id):
|
def delete_token(self, token_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
token_ref = session.query(TokenModel).get(token_id)
|
token_ref = session.query(TokenModel).get(token_id)
|
||||||
if not token_ref or not token_ref.valid:
|
if not token_ref or not token_ref.valid:
|
||||||
raise exception.TokenNotFound(token_id=token_id)
|
raise exception.TokenNotFound(token_id=token_id)
|
||||||
@ -124,9 +122,8 @@ class Token(token.persistence.TokenDriverV8):
|
|||||||
or the trustor's user ID, so will use trust_id to query the tokens.
|
or the trustor's user ID, so will use trust_id to query the tokens.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
session = sql.get_session()
|
|
||||||
token_list = []
|
token_list = []
|
||||||
with session.begin():
|
with sql.session_for_write() as session:
|
||||||
now = timeutils.utcnow()
|
now = timeutils.utcnow()
|
||||||
query = session.query(TokenModel)
|
query = session.query(TokenModel)
|
||||||
query = query.filter_by(valid=True)
|
query = query.filter_by(valid=True)
|
||||||
@ -167,38 +164,37 @@ class Token(token.persistence.TokenDriverV8):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _list_tokens_for_trust(self, trust_id):
|
def _list_tokens_for_trust(self, trust_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
tokens = []
|
tokens = []
|
||||||
now = timeutils.utcnow()
|
now = timeutils.utcnow()
|
||||||
query = session.query(TokenModel)
|
query = session.query(TokenModel)
|
||||||
query = query.filter(TokenModel.expires > now)
|
query = query.filter(TokenModel.expires > now)
|
||||||
query = query.filter(TokenModel.trust_id == trust_id)
|
query = query.filter(TokenModel.trust_id == trust_id)
|
||||||
|
|
||||||
token_references = query.filter_by(valid=True)
|
token_references = query.filter_by(valid=True)
|
||||||
for token_ref in token_references:
|
for token_ref in token_references:
|
||||||
token_ref_dict = token_ref.to_dict()
|
token_ref_dict = token_ref.to_dict()
|
||||||
tokens.append(token_ref_dict['id'])
|
tokens.append(token_ref_dict['id'])
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def _list_tokens_for_user(self, user_id, tenant_id=None):
|
def _list_tokens_for_user(self, user_id, tenant_id=None):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
tokens = []
|
tokens = []
|
||||||
now = timeutils.utcnow()
|
now = timeutils.utcnow()
|
||||||
query = session.query(TokenModel)
|
query = session.query(TokenModel)
|
||||||
query = query.filter(TokenModel.expires > now)
|
query = query.filter(TokenModel.expires > now)
|
||||||
query = query.filter(TokenModel.user_id == user_id)
|
query = query.filter(TokenModel.user_id == user_id)
|
||||||
|
|
||||||
token_references = query.filter_by(valid=True)
|
token_references = query.filter_by(valid=True)
|
||||||
for token_ref in token_references:
|
for token_ref in token_references:
|
||||||
token_ref_dict = token_ref.to_dict()
|
token_ref_dict = token_ref.to_dict()
|
||||||
if self._tenant_matches(tenant_id, token_ref_dict):
|
if self._tenant_matches(tenant_id, token_ref_dict):
|
||||||
tokens.append(token_ref['id'])
|
tokens.append(token_ref['id'])
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def _list_tokens_for_consumer(self, user_id, consumer_id):
|
def _list_tokens_for_consumer(self, user_id, consumer_id):
|
||||||
tokens = []
|
tokens = []
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
with session.begin():
|
|
||||||
now = timeutils.utcnow()
|
now = timeutils.utcnow()
|
||||||
query = session.query(TokenModel)
|
query = session.query(TokenModel)
|
||||||
query = query.filter(TokenModel.expires > now)
|
query = query.filter(TokenModel.expires > now)
|
||||||
@ -223,29 +219,29 @@ class Token(token.persistence.TokenDriverV8):
|
|||||||
return self._list_tokens_for_user(user_id, tenant_id)
|
return self._list_tokens_for_user(user_id, tenant_id)
|
||||||
|
|
||||||
def list_revoked_tokens(self):
|
def list_revoked_tokens(self):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
tokens = []
|
tokens = []
|
||||||
now = timeutils.utcnow()
|
now = timeutils.utcnow()
|
||||||
query = session.query(TokenModel.id, TokenModel.expires,
|
query = session.query(TokenModel.id, TokenModel.expires,
|
||||||
TokenModel.extra)
|
TokenModel.extra)
|
||||||
query = query.filter(TokenModel.expires > now)
|
query = query.filter(TokenModel.expires > now)
|
||||||
token_references = query.filter_by(valid=False)
|
token_references = query.filter_by(valid=False)
|
||||||
for token_ref in token_references:
|
for token_ref in token_references:
|
||||||
token_data = token_ref[2]['token_data']
|
token_data = token_ref[2]['token_data']
|
||||||
if 'access' in token_data:
|
if 'access' in token_data:
|
||||||
# It's a v2 token.
|
# It's a v2 token.
|
||||||
audit_ids = token_data['access']['token']['audit_ids']
|
audit_ids = token_data['access']['token']['audit_ids']
|
||||||
else:
|
else:
|
||||||
# It's a v3 token.
|
# It's a v3 token.
|
||||||
audit_ids = token_data['token']['audit_ids']
|
audit_ids = token_data['token']['audit_ids']
|
||||||
|
|
||||||
record = {
|
record = {
|
||||||
'id': token_ref[0],
|
'id': token_ref[0],
|
||||||
'expires': token_ref[1],
|
'expires': token_ref[1],
|
||||||
'audit_id': audit_ids[0],
|
'audit_id': audit_ids[0],
|
||||||
}
|
}
|
||||||
tokens.append(record)
|
tokens.append(record)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def _expiry_range_strategy(self, dialect):
|
def _expiry_range_strategy(self, dialect):
|
||||||
"""Choose a token range expiration strategy
|
"""Choose a token range expiration strategy
|
||||||
@ -273,18 +269,18 @@ class Token(token.persistence.TokenDriverV8):
|
|||||||
return _expiry_range_all
|
return _expiry_range_all
|
||||||
|
|
||||||
def flush_expired_tokens(self):
|
def flush_expired_tokens(self):
|
||||||
session = sql.get_session()
|
with sql.session_for_write() as session:
|
||||||
dialect = session.bind.dialect.name
|
dialect = session.bind.dialect.name
|
||||||
expiry_range_func = self._expiry_range_strategy(dialect)
|
expiry_range_func = self._expiry_range_strategy(dialect)
|
||||||
query = session.query(TokenModel.expires)
|
query = session.query(TokenModel.expires)
|
||||||
total_removed = 0
|
total_removed = 0
|
||||||
upper_bound_func = timeutils.utcnow
|
upper_bound_func = timeutils.utcnow
|
||||||
for expiry_time in expiry_range_func(session, upper_bound_func):
|
for expiry_time in expiry_range_func(session, upper_bound_func):
|
||||||
delete_query = query.filter(TokenModel.expires <=
|
delete_query = query.filter(TokenModel.expires <=
|
||||||
expiry_time)
|
expiry_time)
|
||||||
row_count = delete_query.delete(synchronize_session=False)
|
row_count = delete_query.delete(synchronize_session=False)
|
||||||
total_removed += row_count
|
total_removed += row_count
|
||||||
LOG.debug('Removed %d total expired tokens', total_removed)
|
LOG.debug('Removed %d total expired tokens', total_removed)
|
||||||
|
|
||||||
session.flush()
|
session.flush()
|
||||||
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
|
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
|
||||||
|
@ -59,7 +59,7 @@ class TrustRole(sql.ModelBase):
|
|||||||
class Trust(trust.TrustDriverV8):
|
class Trust(trust.TrustDriverV8):
|
||||||
@sql.handle_conflicts(conflict_type='trust')
|
@sql.handle_conflicts(conflict_type='trust')
|
||||||
def create_trust(self, trust_id, trust, roles):
|
def create_trust(self, trust_id, trust, roles):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
ref = TrustModel.from_dict(trust)
|
ref = TrustModel.from_dict(trust)
|
||||||
ref['id'] = trust_id
|
ref['id'] = trust_id
|
||||||
if ref.get('expires_at') and ref['expires_at'].tzinfo is not None:
|
if ref.get('expires_at') and ref['expires_at'].tzinfo is not None:
|
||||||
@ -72,9 +72,9 @@ class Trust(trust.TrustDriverV8):
|
|||||||
trust_role.role_id = role['id']
|
trust_role.role_id = role['id']
|
||||||
added_roles.append({'id': role['id']})
|
added_roles.append({'id': role['id']})
|
||||||
session.add(trust_role)
|
session.add(trust_role)
|
||||||
trust_dict = ref.to_dict()
|
trust_dict = ref.to_dict()
|
||||||
trust_dict['roles'] = added_roles
|
trust_dict['roles'] = added_roles
|
||||||
return trust_dict
|
return trust_dict
|
||||||
|
|
||||||
def _add_roles(self, trust_id, session, trust_dict):
|
def _add_roles(self, trust_id, session, trust_dict):
|
||||||
roles = []
|
roles = []
|
||||||
@ -86,7 +86,7 @@ class Trust(trust.TrustDriverV8):
|
|||||||
def consume_use(self, trust_id):
|
def consume_use(self, trust_id):
|
||||||
|
|
||||||
for attempt in range(MAXIMUM_CONSUME_ATTEMPTS):
|
for attempt in range(MAXIMUM_CONSUME_ATTEMPTS):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
try:
|
try:
|
||||||
query_result = (session.query(TrustModel.remaining_uses).
|
query_result = (session.query(TrustModel.remaining_uses).
|
||||||
filter_by(id=trust_id).
|
filter_by(id=trust_id).
|
||||||
@ -132,51 +132,51 @@ class Trust(trust.TrustDriverV8):
|
|||||||
raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id)
|
raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id)
|
||||||
|
|
||||||
def get_trust(self, trust_id, deleted=False):
|
def get_trust(self, trust_id, deleted=False):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
query = session.query(TrustModel).filter_by(id=trust_id)
|
query = session.query(TrustModel).filter_by(id=trust_id)
|
||||||
if not deleted:
|
if not deleted:
|
||||||
query = query.filter_by(deleted_at=None)
|
query = query.filter_by(deleted_at=None)
|
||||||
ref = query.first()
|
ref = query.first()
|
||||||
if ref is None:
|
if ref is None:
|
||||||
raise exception.TrustNotFound(trust_id=trust_id)
|
|
||||||
if ref.expires_at is not None and not deleted:
|
|
||||||
now = timeutils.utcnow()
|
|
||||||
if now > ref.expires_at:
|
|
||||||
raise exception.TrustNotFound(trust_id=trust_id)
|
raise exception.TrustNotFound(trust_id=trust_id)
|
||||||
# Do not return trusts that can't be used anymore
|
if ref.expires_at is not None and not deleted:
|
||||||
if ref.remaining_uses is not None and not deleted:
|
now = timeutils.utcnow()
|
||||||
if ref.remaining_uses <= 0:
|
if now > ref.expires_at:
|
||||||
raise exception.TrustNotFound(trust_id=trust_id)
|
raise exception.TrustNotFound(trust_id=trust_id)
|
||||||
trust_dict = ref.to_dict()
|
# Do not return trusts that can't be used anymore
|
||||||
|
if ref.remaining_uses is not None and not deleted:
|
||||||
|
if ref.remaining_uses <= 0:
|
||||||
|
raise exception.TrustNotFound(trust_id=trust_id)
|
||||||
|
trust_dict = ref.to_dict()
|
||||||
|
|
||||||
self._add_roles(trust_id, session, trust_dict)
|
self._add_roles(trust_id, session, trust_dict)
|
||||||
return trust_dict
|
return trust_dict
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='trust')
|
@sql.handle_conflicts(conflict_type='trust')
|
||||||
def list_trusts(self):
|
def list_trusts(self):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
trusts = session.query(TrustModel).filter_by(deleted_at=None)
|
trusts = session.query(TrustModel).filter_by(deleted_at=None)
|
||||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='trust')
|
@sql.handle_conflicts(conflict_type='trust')
|
||||||
def list_trusts_for_trustee(self, trustee_user_id):
|
def list_trusts_for_trustee(self, trustee_user_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
trusts = (session.query(TrustModel).
|
trusts = (session.query(TrustModel).
|
||||||
filter_by(deleted_at=None).
|
filter_by(deleted_at=None).
|
||||||
filter_by(trustee_user_id=trustee_user_id))
|
filter_by(trustee_user_id=trustee_user_id))
|
||||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='trust')
|
@sql.handle_conflicts(conflict_type='trust')
|
||||||
def list_trusts_for_trustor(self, trustor_user_id):
|
def list_trusts_for_trustor(self, trustor_user_id):
|
||||||
session = sql.get_session()
|
with sql.session_for_read() as session:
|
||||||
trusts = (session.query(TrustModel).
|
trusts = (session.query(TrustModel).
|
||||||
filter_by(deleted_at=None).
|
filter_by(deleted_at=None).
|
||||||
filter_by(trustor_user_id=trustor_user_id))
|
filter_by(trustor_user_id=trustor_user_id))
|
||||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||||
|
|
||||||
@sql.handle_conflicts(conflict_type='trust')
|
@sql.handle_conflicts(conflict_type='trust')
|
||||||
def delete_trust(self, trust_id):
|
def delete_trust(self, trust_id):
|
||||||
with sql.transaction() as session:
|
with sql.session_for_write() as session:
|
||||||
trust_ref = session.query(TrustModel).get(trust_id)
|
trust_ref = session.query(TrustModel).get(trust_id)
|
||||||
if not trust_ref:
|
if not trust_ref:
|
||||||
raise exception.TrustNotFound(trust_id=trust_id)
|
raise exception.TrustNotFound(trust_id=trust_id)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user