Merge "Use the new enginefacade from oslo.db"
This commit is contained in:
commit
d37af165d0
@ -56,7 +56,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
return 'sql'
|
||||
|
||||
def list_user_ids_for_project(self, tenant_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleAssignment.actor_id)
|
||||
query = query.filter_by(type=AssignmentType.USER_PROJECT)
|
||||
query = query.filter_by(target_id=tenant_id)
|
||||
@ -71,7 +71,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
assignment_type = AssignmentType.calculate_type(
|
||||
user_id, group_id, project_id, domain_id)
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
session.add(RoleAssignment(
|
||||
type=assignment_type,
|
||||
actor_id=user_id or group_id,
|
||||
@ -85,7 +85,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
def list_grant_role_ids(self, user_id=None, group_id=None,
|
||||
domain_id=None, project_id=None,
|
||||
inherited_to_projects=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
q = session.query(RoleAssignment.role_id)
|
||||
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
|
||||
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
|
||||
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
|
||||
domain_id=None, project_id=None,
|
||||
inherited_to_projects=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
try:
|
||||
q = self._build_grant_filter(
|
||||
session, role_id, user_id, group_id, domain_id, project_id,
|
||||
@ -120,7 +120,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
def delete_grant(self, role_id, user_id=None, group_id=None,
|
||||
domain_id=None, project_id=None,
|
||||
inherited_to_projects=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = self._build_grant_filter(
|
||||
session, role_id, user_id, group_id, domain_id, project_id,
|
||||
inherited_to_projects)
|
||||
@ -145,11 +145,11 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
RoleAssignment.inherited == inherited,
|
||||
RoleAssignment.actor_id.in_(actors))
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleAssignment.target_id).filter(
|
||||
sql_constraints).distinct()
|
||||
|
||||
return [x.target_id for x in query.all()]
|
||||
return [x.target_id for x in query.all()]
|
||||
|
||||
def list_project_ids_for_user(self, user_id, group_ids, hints,
|
||||
inherited=False):
|
||||
@ -161,7 +161,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
|
||||
def list_domain_ids_for_user(self, user_id, group_ids, hints,
|
||||
inherited=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleAssignment.target_id)
|
||||
filters = []
|
||||
|
||||
@ -197,10 +197,10 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
RoleAssignment.inherited == false(),
|
||||
RoleAssignment.actor_id.in_(group_ids))
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleAssignment.role_id).filter(
|
||||
sql_constraints).distinct()
|
||||
return [role.role_id for role in query.all()]
|
||||
return [role.role_id for role in query.all()]
|
||||
|
||||
def list_role_ids_for_groups_on_project(
|
||||
self, group_ids, project_id, project_domain_id, project_parents):
|
||||
@ -237,13 +237,13 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
sql_constraints = sqlalchemy.and_(
|
||||
sql_constraints, RoleAssignment.actor_id.in_(group_ids))
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
# NOTE(morganfainberg): Only select the columns we actually care
|
||||
# about here, in this case role_id.
|
||||
query = session.query(RoleAssignment.role_id).filter(
|
||||
sql_constraints).distinct()
|
||||
|
||||
return [result.role_id for result in query.all()]
|
||||
return [result.role_id for result in query.all()]
|
||||
|
||||
def list_project_ids_for_groups(self, group_ids, hints,
|
||||
inherited=False):
|
||||
@ -260,14 +260,14 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
RoleAssignment.inherited == inherited,
|
||||
RoleAssignment.actor_id.in_(group_ids))
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleAssignment.target_id).filter(
|
||||
group_sql_conditions).distinct()
|
||||
return [x.target_id for x in query.all()]
|
||||
return [x.target_id for x in query.all()]
|
||||
|
||||
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
session.add(RoleAssignment(
|
||||
type=AssignmentType.USER_PROJECT,
|
||||
actor_id=user_id, target_id=tenant_id,
|
||||
@ -278,7 +278,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
raise exception.Conflict(type='role grant', details=msg)
|
||||
|
||||
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(actor_id=user_id)
|
||||
q = q.filter_by(target_id=tenant_id)
|
||||
@ -368,7 +368,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
assignment['inherited_to_projects'] = 'projects'
|
||||
return assignment
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
assignment_types = self._get_assignment_types(
|
||||
user_id, group_ids, project_ids, domain_id)
|
||||
|
||||
@ -400,25 +400,25 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
|
||||
return [denormalize_role(ref) for ref in query.all()]
|
||||
|
||||
def delete_project_assignments(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(target_id=project_id)
|
||||
q.delete(False)
|
||||
|
||||
def delete_role_assignments(self, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(role_id=role_id)
|
||||
q.delete(False)
|
||||
|
||||
def delete_user_assignments(self, user_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(actor_id=user_id)
|
||||
q.delete(False)
|
||||
|
||||
def delete_group_assignments(self, group_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(actor_id=group_id)
|
||||
q.delete(False)
|
||||
|
@ -19,14 +19,14 @@ class Role(assignment.RoleDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='role')
|
||||
def create_role(self, role_id, role):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = RoleTable.from_dict(role)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
|
||||
@sql.truncated
|
||||
def list_roles(self, hints):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleTable)
|
||||
refs = sql.filter_limit_query(RoleTable, query, hints)
|
||||
return [ref.to_dict() for ref in refs]
|
||||
@ -35,7 +35,7 @@ class Role(assignment.RoleDriverV8):
|
||||
if not ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleTable)
|
||||
query = query.filter(RoleTable.id.in_(ids))
|
||||
role_refs = query.all()
|
||||
@ -48,12 +48,12 @@ class Role(assignment.RoleDriverV8):
|
||||
return ref
|
||||
|
||||
def get_role(self, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_role(session, role_id).to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='role')
|
||||
def update_role(self, role_id, role):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_role(session, role_id)
|
||||
old_dict = ref.to_dict()
|
||||
for k in role:
|
||||
@ -66,7 +66,7 @@ class Role(assignment.RoleDriverV8):
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_role(self, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_role(session, role_id)
|
||||
session.delete(ref)
|
||||
|
||||
|
@ -55,7 +55,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
assignment_type = AssignmentType.calculate_type(
|
||||
user_id, group_id, project_id, domain_id)
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
session.add(RoleAssignment(
|
||||
type=assignment_type,
|
||||
actor_id=user_id or group_id,
|
||||
@ -69,7 +69,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
def list_grant_role_ids(self, user_id=None, group_id=None,
|
||||
domain_id=None, project_id=None,
|
||||
inherited_to_projects=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
q = session.query(RoleAssignment.role_id)
|
||||
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
|
||||
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
|
||||
@ -88,7 +88,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
|
||||
domain_id=None, project_id=None,
|
||||
inherited_to_projects=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
try:
|
||||
q = self._build_grant_filter(
|
||||
session, role_id, user_id, group_id, domain_id, project_id,
|
||||
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
def delete_grant(self, role_id, user_id=None, group_id=None,
|
||||
domain_id=None, project_id=None,
|
||||
inherited_to_projects=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = self._build_grant_filter(
|
||||
session, role_id, user_id, group_id, domain_id, project_id,
|
||||
inherited_to_projects)
|
||||
@ -117,7 +117,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
|
||||
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
session.add(RoleAssignment(
|
||||
type=AssignmentType.USER_PROJECT,
|
||||
actor_id=user_id, target_id=tenant_id,
|
||||
@ -128,7 +128,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
raise exception.Conflict(type='role grant', details=msg)
|
||||
|
||||
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(actor_id=user_id)
|
||||
q = q.filter_by(target_id=tenant_id)
|
||||
@ -218,7 +218,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
assignment['inherited_to_projects'] = 'projects'
|
||||
return assignment
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
assignment_types = self._get_assignment_types(
|
||||
user_id, group_ids, project_ids, domain_id)
|
||||
|
||||
@ -250,7 +250,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
return [denormalize_role(ref) for ref in query.all()]
|
||||
|
||||
def delete_project_assignments(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(target_id=project_id).filter(
|
||||
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
|
||||
@ -259,13 +259,13 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
q.delete(False)
|
||||
|
||||
def delete_role_assignments(self, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(role_id=role_id)
|
||||
q.delete(False)
|
||||
|
||||
def delete_user_assignments(self, user_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(actor_id=user_id).filter(
|
||||
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
|
||||
@ -274,7 +274,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
|
||||
q.delete(False)
|
||||
|
||||
def delete_group_assignments(self, group_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
q = session.query(RoleAssignment)
|
||||
q = q.filter_by(actor_id=group_id).filter(
|
||||
RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT,
|
||||
|
@ -29,7 +29,7 @@ class Role(assignment.RoleDriverV9):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='role')
|
||||
def create_role(self, role_id, role):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = RoleTable.from_dict(role)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
@ -46,7 +46,7 @@ class Role(assignment.RoleDriverV9):
|
||||
if (f['name'] == 'domain_id' and f['value'] is None):
|
||||
f['value'] = NULL_DOMAIN_ID
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleTable)
|
||||
refs = sql.filter_limit_query(RoleTable, query, hints)
|
||||
return [ref.to_dict() for ref in refs]
|
||||
@ -55,7 +55,7 @@ class Role(assignment.RoleDriverV9):
|
||||
if not ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RoleTable)
|
||||
query = query.filter(RoleTable.id.in_(ids))
|
||||
role_refs = query.all()
|
||||
@ -68,12 +68,12 @@ class Role(assignment.RoleDriverV9):
|
||||
return ref
|
||||
|
||||
def get_role(self, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_role(session, role_id).to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='role')
|
||||
def update_role(self, role_id, role):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_role(session, role_id)
|
||||
old_dict = ref.to_dict()
|
||||
for k in role:
|
||||
@ -86,7 +86,7 @@ class Role(assignment.RoleDriverV9):
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_role(self, role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_role(session, role_id)
|
||||
session.delete(ref)
|
||||
|
||||
@ -105,7 +105,7 @@ class Role(assignment.RoleDriverV9):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='implied_role')
|
||||
def create_implied_role(self, prior_role_id, implied_role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
inference = {'prior_role_id': prior_role_id,
|
||||
'implied_role_id': implied_role_id}
|
||||
ref = ImpliedRoleTable.from_dict(inference)
|
||||
@ -119,13 +119,13 @@ class Role(assignment.RoleDriverV9):
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_implied_role(self, prior_role_id, implied_role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_implied_role(session, prior_role_id,
|
||||
implied_role_id)
|
||||
session.delete(ref)
|
||||
|
||||
def list_implied_roles(self, prior_role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(
|
||||
ImpliedRoleTable).filter(
|
||||
ImpliedRoleTable.prior_role_id == prior_role_id)
|
||||
@ -133,13 +133,13 @@ class Role(assignment.RoleDriverV9):
|
||||
return [ref.to_dict() for ref in refs]
|
||||
|
||||
def list_role_inference_rules(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(ImpliedRoleTable)
|
||||
refs = query.all()
|
||||
return [ref.to_dict() for ref in refs]
|
||||
|
||||
def get_implied_role(self, prior_role_id, implied_role_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
ref = self._get_implied_role(session, prior_role_id,
|
||||
implied_role_id)
|
||||
return ref.to_dict()
|
||||
|
@ -84,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase):
|
||||
class Catalog(catalog.CatalogDriverV8):
|
||||
# Regions
|
||||
def list_regions(self, hints):
|
||||
session = sql.get_session()
|
||||
regions = session.query(Region)
|
||||
regions = sql.filter_limit_query(Region, regions, hints)
|
||||
return [s.to_dict() for s in list(regions)]
|
||||
with sql.session_for_read() as session:
|
||||
regions = session.query(Region)
|
||||
regions = sql.filter_limit_query(Region, regions, hints)
|
||||
return [s.to_dict() for s in list(regions)]
|
||||
|
||||
def _get_region(self, session, region_id):
|
||||
ref = session.query(Region).get(region_id)
|
||||
@ -136,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
return False
|
||||
|
||||
def get_region(self, region_id):
|
||||
session = sql.get_session()
|
||||
return self._get_region(session, region_id).to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_region(session, region_id).to_dict()
|
||||
|
||||
def delete_region(self, region_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_region(session, region_id)
|
||||
if self._has_endpoints(session, ref, ref):
|
||||
raise exception.RegionDeletionError(region_id=region_id)
|
||||
@ -150,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='region')
|
||||
def create_region(self, region_ref):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
self._check_parent_region(session, region_ref)
|
||||
region = Region.from_dict(region_ref)
|
||||
session.add(region)
|
||||
return region.to_dict()
|
||||
return region.to_dict()
|
||||
|
||||
def update_region(self, region_id, region_ref):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
self._check_parent_region(session, region_ref)
|
||||
ref = self._get_region(session, region_id)
|
||||
old_dict = ref.to_dict()
|
||||
@ -169,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
for attr in Region.attributes:
|
||||
if attr != 'id':
|
||||
setattr(ref, attr, getattr(new_region, attr))
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
# Services
|
||||
@driver_hints.truncated
|
||||
def list_services(self, hints):
|
||||
session = sql.get_session()
|
||||
services = session.query(Service)
|
||||
services = sql.filter_limit_query(Service, services, hints)
|
||||
return [s.to_dict() for s in list(services)]
|
||||
with sql.session_for_read() as session:
|
||||
services = session.query(Service)
|
||||
services = sql.filter_limit_query(Service, services, hints)
|
||||
return [s.to_dict() for s in list(services)]
|
||||
|
||||
def _get_service(self, session, service_id):
|
||||
ref = session.query(Service).get(service_id)
|
||||
@ -186,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
return ref
|
||||
|
||||
def get_service(self, service_id):
|
||||
session = sql.get_session()
|
||||
return self._get_service(session, service_id).to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_service(session, service_id).to_dict()
|
||||
|
||||
def delete_service(self, service_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_service(session, service_id)
|
||||
session.query(Endpoint).filter_by(service_id=service_id).delete()
|
||||
session.delete(ref)
|
||||
|
||||
def create_service(self, service_id, service_ref):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
service = Service.from_dict(service_ref)
|
||||
session.add(service)
|
||||
return service.to_dict()
|
||||
return service.to_dict()
|
||||
|
||||
def update_service(self, service_id, service_ref):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_service(session, service_id)
|
||||
old_dict = ref.to_dict()
|
||||
old_dict.update(service_ref)
|
||||
@ -214,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
if attr != 'id':
|
||||
setattr(ref, attr, getattr(new_service, attr))
|
||||
ref.extra = new_service.extra
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
# Endpoints
|
||||
def create_endpoint(self, endpoint_id, endpoint_ref):
|
||||
session = sql.get_session()
|
||||
new_endpoint = Endpoint.from_dict(endpoint_ref)
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
session.add(new_endpoint)
|
||||
return new_endpoint.to_dict()
|
||||
|
||||
def delete_endpoint(self, endpoint_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_endpoint(session, endpoint_id)
|
||||
session.delete(ref)
|
||||
|
||||
@ -238,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
|
||||
|
||||
def get_endpoint(self, endpoint_id):
|
||||
session = sql.get_session()
|
||||
return self._get_endpoint(session, endpoint_id).to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_endpoint(session, endpoint_id).to_dict()
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_endpoints(self, hints):
|
||||
session = sql.get_session()
|
||||
endpoints = session.query(Endpoint)
|
||||
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
|
||||
return [e.to_dict() for e in list(endpoints)]
|
||||
with sql.session_for_read() as session:
|
||||
endpoints = session.query(Endpoint)
|
||||
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
|
||||
return [e.to_dict() for e in list(endpoints)]
|
||||
|
||||
def update_endpoint(self, endpoint_id, endpoint_ref):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_endpoint(session, endpoint_id)
|
||||
old_dict = ref.to_dict()
|
||||
old_dict.update(endpoint_ref)
|
||||
@ -260,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
if attr != 'id':
|
||||
setattr(ref, attr, getattr(new_endpoint, attr))
|
||||
ref.extra = new_endpoint.extra
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def get_catalog(self, user_id, tenant_id):
|
||||
"""Retrieve and format the V2 service catalog.
|
||||
@ -289,40 +278,40 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
else:
|
||||
silent_keyerror_failures = ['tenant_id', 'project_id', ]
|
||||
|
||||
session = sql.get_session()
|
||||
endpoints = (session.query(Endpoint).
|
||||
options(sql.joinedload(Endpoint.service)).
|
||||
filter(Endpoint.enabled == true()).all())
|
||||
with sql.session_for_read() as session:
|
||||
endpoints = (session.query(Endpoint).
|
||||
options(sql.joinedload(Endpoint.service)).
|
||||
filter(Endpoint.enabled == true()).all())
|
||||
|
||||
catalog = {}
|
||||
catalog = {}
|
||||
|
||||
for endpoint in endpoints:
|
||||
if not endpoint.service['enabled']:
|
||||
continue
|
||||
try:
|
||||
formatted_url = core.format_url(
|
||||
endpoint['url'], substitutions,
|
||||
silent_keyerror_failures=silent_keyerror_failures)
|
||||
if formatted_url is not None:
|
||||
url = formatted_url
|
||||
else:
|
||||
for endpoint in endpoints:
|
||||
if not endpoint.service['enabled']:
|
||||
continue
|
||||
except exception.MalformedEndpoint:
|
||||
continue # this failure is already logged in format_url()
|
||||
try:
|
||||
formatted_url = core.format_url(
|
||||
endpoint['url'], substitutions,
|
||||
silent_keyerror_failures=silent_keyerror_failures)
|
||||
if formatted_url is not None:
|
||||
url = formatted_url
|
||||
else:
|
||||
continue
|
||||
except exception.MalformedEndpoint:
|
||||
continue # this failure is already logged in format_url()
|
||||
|
||||
region = endpoint['region_id']
|
||||
service_type = endpoint.service['type']
|
||||
default_service = {
|
||||
'id': endpoint['id'],
|
||||
'name': endpoint.service.extra.get('name', ''),
|
||||
'publicURL': ''
|
||||
}
|
||||
catalog.setdefault(region, {})
|
||||
catalog[region].setdefault(service_type, default_service)
|
||||
interface_url = '%sURL' % endpoint['interface']
|
||||
catalog[region][service_type][interface_url] = url
|
||||
region = endpoint['region_id']
|
||||
service_type = endpoint.service['type']
|
||||
default_service = {
|
||||
'id': endpoint['id'],
|
||||
'name': endpoint.service.extra.get('name', ''),
|
||||
'publicURL': ''
|
||||
}
|
||||
catalog.setdefault(region, {})
|
||||
catalog[region].setdefault(service_type, default_service)
|
||||
interface_url = '%sURL' % endpoint['interface']
|
||||
catalog[region][service_type][interface_url] = url
|
||||
|
||||
return catalog
|
||||
return catalog
|
||||
|
||||
def get_v3_catalog(self, user_id, tenant_id):
|
||||
"""Retrieve and format the current V3 service catalog.
|
||||
@ -349,44 +338,46 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
else:
|
||||
silent_keyerror_failures = ['tenant_id', 'project_id', ]
|
||||
|
||||
session = sql.get_session()
|
||||
services = (session.query(Service).filter(Service.enabled == true()).
|
||||
options(sql.joinedload(Service.endpoints)).
|
||||
all())
|
||||
with sql.session_for_read() as session:
|
||||
services = (session.query(Service).filter(
|
||||
Service.enabled == true()).options(
|
||||
sql.joinedload(Service.endpoints)).all())
|
||||
|
||||
def make_v3_endpoints(endpoints):
|
||||
for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled):
|
||||
del endpoint['service_id']
|
||||
del endpoint['legacy_endpoint_id']
|
||||
del endpoint['enabled']
|
||||
endpoint['region'] = endpoint['region_id']
|
||||
try:
|
||||
formatted_url = core.format_url(
|
||||
endpoint['url'], d,
|
||||
silent_keyerror_failures=silent_keyerror_failures)
|
||||
if formatted_url:
|
||||
endpoint['url'] = formatted_url
|
||||
else:
|
||||
def make_v3_endpoints(endpoints):
|
||||
for endpoint in (ep.to_dict()
|
||||
for ep in endpoints if ep.enabled):
|
||||
del endpoint['service_id']
|
||||
del endpoint['legacy_endpoint_id']
|
||||
del endpoint['enabled']
|
||||
endpoint['region'] = endpoint['region_id']
|
||||
try:
|
||||
formatted_url = core.format_url(
|
||||
endpoint['url'], d,
|
||||
silent_keyerror_failures=silent_keyerror_failures)
|
||||
if formatted_url:
|
||||
endpoint['url'] = formatted_url
|
||||
else:
|
||||
continue
|
||||
except exception.MalformedEndpoint:
|
||||
# this failure is already logged in format_url()
|
||||
continue
|
||||
except exception.MalformedEndpoint:
|
||||
continue # this failure is already logged in format_url()
|
||||
|
||||
yield endpoint
|
||||
yield endpoint
|
||||
|
||||
# TODO(davechen): If there is service with no endpoints, we should skip
|
||||
# the service instead of keeping it in the catalog, see bug #1436704.
|
||||
def make_v3_service(svc):
|
||||
eps = list(make_v3_endpoints(svc.endpoints))
|
||||
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
|
||||
service['name'] = svc.extra.get('name', '')
|
||||
return service
|
||||
# TODO(davechen): If there is service with no endpoints, we should
|
||||
# skip the service instead of keeping it in the catalog,
|
||||
# see bug #1436704.
|
||||
def make_v3_service(svc):
|
||||
eps = list(make_v3_endpoints(svc.endpoints))
|
||||
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
|
||||
service['name'] = svc.extra.get('name', '')
|
||||
return service
|
||||
|
||||
return [make_v3_service(svc) for svc in services]
|
||||
return [make_v3_service(svc) for svc in services]
|
||||
|
||||
@sql.handle_conflicts(conflict_type='project_endpoint')
|
||||
def add_endpoint_to_project(self, endpoint_id, project_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
|
||||
project_id=project_id)
|
||||
session.add(endpoint_filter_ref)
|
||||
@ -402,50 +393,46 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
return endpoint_filter_ref
|
||||
|
||||
def check_endpoint_in_project(self, endpoint_id, project_id):
|
||||
session = sql.get_session()
|
||||
self._get_project_endpoint_ref(session, endpoint_id, project_id)
|
||||
with sql.session_for_read() as session:
|
||||
self._get_project_endpoint_ref(session, endpoint_id, project_id)
|
||||
|
||||
def remove_endpoint_from_project(self, endpoint_id, project_id):
|
||||
session = sql.get_session()
|
||||
endpoint_filter_ref = self._get_project_endpoint_ref(
|
||||
session, endpoint_id, project_id)
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
endpoint_filter_ref = self._get_project_endpoint_ref(
|
||||
session, endpoint_id, project_id)
|
||||
session.delete(endpoint_filter_ref)
|
||||
|
||||
def list_endpoints_for_project(self, project_id):
|
||||
session = sql.get_session()
|
||||
query = session.query(ProjectEndpoint)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
endpoint_filter_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_filter_refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(ProjectEndpoint)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
endpoint_filter_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_filter_refs]
|
||||
|
||||
def list_projects_for_endpoint(self, endpoint_id):
|
||||
session = sql.get_session()
|
||||
query = session.query(ProjectEndpoint)
|
||||
query = query.filter_by(endpoint_id=endpoint_id)
|
||||
endpoint_filter_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_filter_refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(ProjectEndpoint)
|
||||
query = query.filter_by(endpoint_id=endpoint_id)
|
||||
endpoint_filter_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_filter_refs]
|
||||
|
||||
def delete_association_by_endpoint(self, endpoint_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(ProjectEndpoint)
|
||||
query = query.filter_by(endpoint_id=endpoint_id)
|
||||
query.delete(synchronize_session=False)
|
||||
|
||||
def delete_association_by_project(self, project_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(ProjectEndpoint)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
query.delete(synchronize_session=False)
|
||||
|
||||
def create_endpoint_group(self, endpoint_group_id, endpoint_group):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
endpoint_group_ref = EndpointGroup.from_dict(endpoint_group)
|
||||
session.add(endpoint_group_ref)
|
||||
return endpoint_group_ref.to_dict()
|
||||
return endpoint_group_ref.to_dict()
|
||||
|
||||
def _get_endpoint_group(self, session, endpoint_group_id):
|
||||
endpoint_group_ref = session.query(EndpointGroup).get(
|
||||
@ -456,14 +443,13 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
return endpoint_group_ref
|
||||
|
||||
def get_endpoint_group(self, endpoint_group_id):
|
||||
session = sql.get_session()
|
||||
endpoint_group_ref = self._get_endpoint_group(session,
|
||||
endpoint_group_id)
|
||||
return endpoint_group_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
endpoint_group_ref = self._get_endpoint_group(session,
|
||||
endpoint_group_id)
|
||||
return endpoint_group_ref.to_dict()
|
||||
|
||||
def update_endpoint_group(self, endpoint_group_id, endpoint_group):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
endpoint_group_ref = self._get_endpoint_group(session,
|
||||
endpoint_group_id)
|
||||
old_endpoint_group = endpoint_group_ref.to_dict()
|
||||
@ -472,29 +458,26 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
for attr in EndpointGroup.mutable_attributes:
|
||||
setattr(endpoint_group_ref, attr,
|
||||
getattr(new_endpoint_group, attr))
|
||||
return endpoint_group_ref.to_dict()
|
||||
return endpoint_group_ref.to_dict()
|
||||
|
||||
def delete_endpoint_group(self, endpoint_group_id):
|
||||
session = sql.get_session()
|
||||
endpoint_group_ref = self._get_endpoint_group(session,
|
||||
endpoint_group_id)
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
endpoint_group_ref = self._get_endpoint_group(session,
|
||||
endpoint_group_id)
|
||||
self._delete_endpoint_group_association_by_endpoint_group(
|
||||
session, endpoint_group_id)
|
||||
session.delete(endpoint_group_ref)
|
||||
|
||||
def get_endpoint_group_in_project(self, endpoint_group_id, project_id):
|
||||
session = sql.get_session()
|
||||
ref = self._get_endpoint_group_in_project(session,
|
||||
endpoint_group_id,
|
||||
project_id)
|
||||
return ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
ref = self._get_endpoint_group_in_project(session,
|
||||
endpoint_group_id,
|
||||
project_id)
|
||||
return ref.to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='project_endpoint_group')
|
||||
def add_endpoint_group_to_project(self, endpoint_group_id, project_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
# Create a new Project Endpoint group entity
|
||||
endpoint_group_project_ref = ProjectEndpointGroupMembership(
|
||||
endpoint_group_id=endpoint_group_id, project_id=project_id)
|
||||
@ -512,32 +495,31 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
return endpoint_group_project_ref
|
||||
|
||||
def list_endpoint_groups(self):
|
||||
session = sql.get_session()
|
||||
query = session.query(EndpointGroup)
|
||||
endpoint_group_refs = query.all()
|
||||
return [e.to_dict() for e in endpoint_group_refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(EndpointGroup)
|
||||
endpoint_group_refs = query.all()
|
||||
return [e.to_dict() for e in endpoint_group_refs]
|
||||
|
||||
def list_endpoint_groups_for_project(self, project_id):
|
||||
session = sql.get_session()
|
||||
query = session.query(ProjectEndpointGroupMembership)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
endpoint_group_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_group_refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(ProjectEndpointGroupMembership)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
endpoint_group_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_group_refs]
|
||||
|
||||
def remove_endpoint_group_from_project(self, endpoint_group_id,
|
||||
project_id):
|
||||
session = sql.get_session()
|
||||
endpoint_group_project_ref = self._get_endpoint_group_in_project(
|
||||
session, endpoint_group_id, project_id)
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
endpoint_group_project_ref = self._get_endpoint_group_in_project(
|
||||
session, endpoint_group_id, project_id)
|
||||
session.delete(endpoint_group_project_ref)
|
||||
|
||||
def list_projects_associated_with_endpoint_group(self, endpoint_group_id):
|
||||
session = sql.get_session()
|
||||
query = session.query(ProjectEndpointGroupMembership)
|
||||
query = query.filter_by(endpoint_group_id=endpoint_group_id)
|
||||
endpoint_group_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_group_refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(ProjectEndpointGroupMembership)
|
||||
query = query.filter_by(endpoint_group_id=endpoint_group_id)
|
||||
endpoint_group_refs = query.all()
|
||||
return [ref.to_dict() for ref in endpoint_group_refs]
|
||||
|
||||
def _delete_endpoint_group_association_by_endpoint_group(
|
||||
self, session, endpoint_group_id):
|
||||
@ -546,8 +528,7 @@ class Catalog(catalog.CatalogDriverV8):
|
||||
query.delete()
|
||||
|
||||
def delete_endpoint_group_association_by_project(self, project_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(ProjectEndpointGroupMembership)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
query.delete()
|
||||
|
@ -18,14 +18,14 @@ Before using this module, call initialize(). This has to be done before
|
||||
CONF() because it sets up configuration options.
|
||||
|
||||
"""
|
||||
import contextlib
|
||||
import functools
|
||||
import threading
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_db import exception as db_exception
|
||||
from oslo_db import options as db_options
|
||||
from oslo_db.sqlalchemy import enginefacade
|
||||
from oslo_db.sqlalchemy import models
|
||||
from oslo_db.sqlalchemy import session as db_session
|
||||
from oslo_log import log
|
||||
from oslo_serialization import jsonutils
|
||||
import six
|
||||
@ -166,38 +166,41 @@ class ModelDictMixin(object):
|
||||
return {name: getattr(self, name) for name in names}
|
||||
|
||||
|
||||
_engine_facade = None
|
||||
_main_context_manager = None
|
||||
|
||||
|
||||
def _get_engine_facade():
|
||||
global _engine_facade
|
||||
def _get_main_context_manager():
|
||||
global _main_context_manager
|
||||
|
||||
if not _engine_facade:
|
||||
_engine_facade = db_session.EngineFacade.from_config(CONF)
|
||||
if not _main_context_manager:
|
||||
_main_context_manager = enginefacade.transaction_context()
|
||||
|
||||
return _engine_facade
|
||||
return _main_context_manager
|
||||
|
||||
|
||||
def cleanup():
|
||||
global _engine_facade
|
||||
global _main_context_manager
|
||||
|
||||
_engine_facade = None
|
||||
_main_context_manager = None
|
||||
|
||||
|
||||
def get_engine():
|
||||
return _get_engine_facade().get_engine()
|
||||
return _get_main_context_manager().get_legacy_facade().get_engine()
|
||||
|
||||
|
||||
def get_session(expire_on_commit=False):
|
||||
return _get_engine_facade().get_session(expire_on_commit=expire_on_commit)
|
||||
def get_session():
|
||||
return _get_main_context_manager().get_legacy_facade().get_session()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def transaction(expire_on_commit=False):
|
||||
"""Return a SQLAlchemy session in a scoped transaction."""
|
||||
session = get_session(expire_on_commit=expire_on_commit)
|
||||
with session.begin():
|
||||
yield session
|
||||
_CONTEXT = threading.local()
|
||||
|
||||
|
||||
def session_for_read():
|
||||
return _get_main_context_manager().reader.using(_CONTEXT)
|
||||
|
||||
|
||||
def session_for_write():
|
||||
return _get_main_context_manager().writer.using(_CONTEXT)
|
||||
|
||||
|
||||
def truncated(f):
|
||||
|
@ -178,7 +178,7 @@ def _sync_extension_repo(extension, version):
|
||||
try:
|
||||
abs_path = find_migrate_repo(package)
|
||||
try:
|
||||
migration.db_version_control(sql.get_engine(), abs_path)
|
||||
migration.db_version_control(engine, abs_path)
|
||||
# Register the repo with the version control API
|
||||
# If it already knows about the repo, it will throw
|
||||
# an exception that we can safely ignore
|
||||
|
@ -36,28 +36,27 @@ class Credential(credential.CredentialDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='credential')
|
||||
def create_credential(self, credential_id, credential):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = CredentialModel.from_dict(credential)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_credentials(self, hints):
|
||||
session = sql.get_session()
|
||||
credentials = session.query(CredentialModel)
|
||||
credentials = sql.filter_limit_query(CredentialModel,
|
||||
credentials, hints)
|
||||
return [s.to_dict() for s in credentials]
|
||||
with sql.session_for_read() as session:
|
||||
credentials = session.query(CredentialModel)
|
||||
credentials = sql.filter_limit_query(CredentialModel,
|
||||
credentials, hints)
|
||||
return [s.to_dict() for s in credentials]
|
||||
|
||||
def list_credentials_for_user(self, user_id, type=None):
|
||||
session = sql.get_session()
|
||||
query = session.query(CredentialModel)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if type:
|
||||
query = query.filter_by(type=type)
|
||||
refs = query.all()
|
||||
return [ref.to_dict() for ref in refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(CredentialModel)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if type:
|
||||
query = query.filter_by(type=type)
|
||||
refs = query.all()
|
||||
return [ref.to_dict() for ref in refs]
|
||||
|
||||
def _get_credential(self, session, credential_id):
|
||||
ref = session.query(CredentialModel).get(credential_id)
|
||||
@ -66,13 +65,12 @@ class Credential(credential.CredentialDriverV8):
|
||||
return ref
|
||||
|
||||
def get_credential(self, credential_id):
|
||||
session = sql.get_session()
|
||||
return self._get_credential(session, credential_id).to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_credential(session, credential_id).to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='credential')
|
||||
def update_credential(self, credential_id, credential):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_credential(session, credential_id)
|
||||
old_dict = ref.to_dict()
|
||||
for k in credential:
|
||||
@ -82,27 +80,21 @@ class Credential(credential.CredentialDriverV8):
|
||||
if attr != 'id':
|
||||
setattr(ref, attr, getattr(new_credential, attr))
|
||||
ref.extra = new_credential.extra
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_credential(self, credential_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_credential(session, credential_id)
|
||||
session.delete(ref)
|
||||
|
||||
def delete_credentials_for_project(self, project_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(CredentialModel)
|
||||
query = query.filter_by(project_id=project_id)
|
||||
query.delete()
|
||||
|
||||
def delete_credentials_for_user(self, user_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(CredentialModel)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query.delete()
|
||||
|
@ -51,7 +51,7 @@ class EndpointPolicy(object):
|
||||
|
||||
def create_policy_association(self, policy_id, endpoint_id=None,
|
||||
service_id=None, region_id=None):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
try:
|
||||
# See if there is already a row for this association, and if
|
||||
# so, update it with the new policy_id
|
||||
@ -79,14 +79,14 @@ class EndpointPolicy(object):
|
||||
|
||||
# NOTE(henry-nash): Getting a single value to save object
|
||||
# management overhead.
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
if session.query(PolicyAssociation.id).filter(
|
||||
sql_constraints).distinct().count() == 0:
|
||||
raise exception.PolicyAssociationNotFound()
|
||||
|
||||
def delete_policy_association(self, policy_id, endpoint_id=None,
|
||||
service_id=None, region_id=None):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(PolicyAssociation)
|
||||
query = query.filter_by(policy_id=policy_id)
|
||||
query = query.filter_by(endpoint_id=endpoint_id)
|
||||
@ -102,7 +102,7 @@ class EndpointPolicy(object):
|
||||
PolicyAssociation.region_id == region_id)
|
||||
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
policy_id = session.query(PolicyAssociation.policy_id).filter(
|
||||
sql_constraints).distinct().one()
|
||||
return {'policy_id': policy_id}
|
||||
@ -110,31 +110,31 @@ class EndpointPolicy(object):
|
||||
raise exception.PolicyAssociationNotFound()
|
||||
|
||||
def list_associations_for_policy(self, policy_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(PolicyAssociation)
|
||||
query = query.filter_by(policy_id=policy_id)
|
||||
return [ref.to_dict() for ref in query.all()]
|
||||
|
||||
def delete_association_by_endpoint(self, endpoint_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(PolicyAssociation)
|
||||
query = query.filter_by(endpoint_id=endpoint_id)
|
||||
query.delete()
|
||||
|
||||
def delete_association_by_service(self, service_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(PolicyAssociation)
|
||||
query = query.filter_by(service_id=service_id)
|
||||
query.delete()
|
||||
|
||||
def delete_association_by_region(self, region_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(PolicyAssociation)
|
||||
query = query.filter_by(region_id=region_id)
|
||||
query.delete()
|
||||
|
||||
def delete_association_by_policy(self, policy_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(PolicyAssociation)
|
||||
query = query.filter_by(policy_id=policy_id)
|
||||
query.delete()
|
||||
|
@ -161,13 +161,13 @@ class Federation(core.FederationDriverV8):
|
||||
@sql.handle_conflicts(conflict_type='identity_provider')
|
||||
def create_idp(self, idp_id, idp):
|
||||
idp['id'] = idp_id
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
idp_ref = IdentityProviderModel.from_dict(idp)
|
||||
session.add(idp_ref)
|
||||
return idp_ref.to_dict()
|
||||
return idp_ref.to_dict()
|
||||
|
||||
def delete_idp(self, idp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
self._delete_assigned_protocols(session, idp_id)
|
||||
idp_ref = self._get_idp(session, idp_id)
|
||||
session.delete(idp_ref)
|
||||
@ -187,30 +187,30 @@ class Federation(core.FederationDriverV8):
|
||||
raise exception.IdentityProviderNotFound(idp_id=remote_id)
|
||||
|
||||
def list_idps(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
idps = session.query(IdentityProviderModel)
|
||||
idps_list = [idp.to_dict() for idp in idps]
|
||||
return idps_list
|
||||
idps_list = [idp.to_dict() for idp in idps]
|
||||
return idps_list
|
||||
|
||||
def get_idp(self, idp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
idp_ref = self._get_idp(session, idp_id)
|
||||
return idp_ref.to_dict()
|
||||
return idp_ref.to_dict()
|
||||
|
||||
def get_idp_from_remote_id(self, remote_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
ref = self._get_idp_from_remote_id(session, remote_id)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def update_idp(self, idp_id, idp):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
idp_ref = self._get_idp(session, idp_id)
|
||||
old_idp = idp_ref.to_dict()
|
||||
old_idp.update(idp)
|
||||
new_idp = IdentityProviderModel.from_dict(old_idp)
|
||||
for attr in IdentityProviderModel.mutable_attributes:
|
||||
setattr(idp_ref, attr, getattr(new_idp, attr))
|
||||
return idp_ref.to_dict()
|
||||
return idp_ref.to_dict()
|
||||
|
||||
# Protocol CRUD
|
||||
def _get_protocol(self, session, idp_id, protocol_id):
|
||||
@ -227,36 +227,36 @@ class Federation(core.FederationDriverV8):
|
||||
def create_protocol(self, idp_id, protocol_id, protocol):
|
||||
protocol['id'] = protocol_id
|
||||
protocol['idp_id'] = idp_id
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
self._get_idp(session, idp_id)
|
||||
protocol_ref = FederationProtocolModel.from_dict(protocol)
|
||||
session.add(protocol_ref)
|
||||
return protocol_ref.to_dict()
|
||||
return protocol_ref.to_dict()
|
||||
|
||||
def update_protocol(self, idp_id, protocol_id, protocol):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
proto_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
old_proto = proto_ref.to_dict()
|
||||
old_proto.update(protocol)
|
||||
new_proto = FederationProtocolModel.from_dict(old_proto)
|
||||
for attr in FederationProtocolModel.mutable_attributes:
|
||||
setattr(proto_ref, attr, getattr(new_proto, attr))
|
||||
return proto_ref.to_dict()
|
||||
return proto_ref.to_dict()
|
||||
|
||||
def get_protocol(self, idp_id, protocol_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
return protocol_ref.to_dict()
|
||||
return protocol_ref.to_dict()
|
||||
|
||||
def list_protocols(self, idp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
q = session.query(FederationProtocolModel)
|
||||
q = q.filter_by(idp_id=idp_id)
|
||||
protocols = [protocol.to_dict() for protocol in q]
|
||||
return protocols
|
||||
protocols = [protocol.to_dict() for protocol in q]
|
||||
return protocols
|
||||
|
||||
def delete_protocol(self, idp_id, protocol_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
key_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
session.delete(key_ref)
|
||||
|
||||
@ -277,58 +277,58 @@ class Federation(core.FederationDriverV8):
|
||||
ref = {}
|
||||
ref['id'] = mapping_id
|
||||
ref['rules'] = mapping.get('rules')
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
mapping_ref = MappingModel.from_dict(ref)
|
||||
session.add(mapping_ref)
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
def delete_mapping(self, mapping_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
session.delete(mapping_ref)
|
||||
|
||||
def list_mappings(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
mappings = session.query(MappingModel)
|
||||
return [x.to_dict() for x in mappings]
|
||||
return [x.to_dict() for x in mappings]
|
||||
|
||||
def get_mapping(self, mapping_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='mapping')
|
||||
def update_mapping(self, mapping_id, mapping):
|
||||
ref = {}
|
||||
ref['id'] = mapping_id
|
||||
ref['rules'] = mapping.get('rules')
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
old_mapping = mapping_ref.to_dict()
|
||||
old_mapping.update(ref)
|
||||
new_mapping = MappingModel.from_dict(old_mapping)
|
||||
for attr in MappingModel.attributes:
|
||||
setattr(mapping_ref, attr, getattr(new_mapping, attr))
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
mapping_id = protocol_ref.mapping_id
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
# Service Provider CRUD
|
||||
@sql.handle_conflicts(conflict_type='service_provider')
|
||||
def create_sp(self, sp_id, sp):
|
||||
sp['id'] = sp_id
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
sp_ref = ServiceProviderModel.from_dict(sp)
|
||||
session.add(sp_ref)
|
||||
return sp_ref.to_dict()
|
||||
return sp_ref.to_dict()
|
||||
|
||||
def delete_sp(self, sp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
sp_ref = self._get_sp(session, sp_id)
|
||||
session.delete(sp_ref)
|
||||
|
||||
@ -339,28 +339,28 @@ class Federation(core.FederationDriverV8):
|
||||
return sp_ref
|
||||
|
||||
def list_sps(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
sps = session.query(ServiceProviderModel)
|
||||
sps_list = [sp.to_dict() for sp in sps]
|
||||
return sps_list
|
||||
sps_list = [sp.to_dict() for sp in sps]
|
||||
return sps_list
|
||||
|
||||
def get_sp(self, sp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
sp_ref = self._get_sp(session, sp_id)
|
||||
return sp_ref.to_dict()
|
||||
return sp_ref.to_dict()
|
||||
|
||||
def update_sp(self, sp_id, sp):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
sp_ref = self._get_sp(session, sp_id)
|
||||
old_sp = sp_ref.to_dict()
|
||||
old_sp.update(sp)
|
||||
new_sp = ServiceProviderModel.from_dict(old_sp)
|
||||
for attr in ServiceProviderModel.mutable_attributes:
|
||||
setattr(sp_ref, attr, getattr(new_sp, attr))
|
||||
return sp_ref.to_dict()
|
||||
return sp_ref.to_dict()
|
||||
|
||||
def get_enabled_service_providers(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
service_providers = session.query(ServiceProviderModel)
|
||||
service_providers = service_providers.filter_by(enabled=True)
|
||||
return service_providers
|
||||
return service_providers
|
||||
|
@ -169,10 +169,10 @@ class Federation(core.FederationDriverV9):
|
||||
def create_idp(self, idp_id, idp):
|
||||
idp['id'] = idp_id
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
idp_ref = IdentityProviderModel.from_dict(idp)
|
||||
session.add(idp_ref)
|
||||
return idp_ref.to_dict()
|
||||
return idp_ref.to_dict()
|
||||
except sql.DBDuplicateEntry as e:
|
||||
conflict_type = 'identity_provider'
|
||||
details = six.text_type(e)
|
||||
@ -186,7 +186,7 @@ class Federation(core.FederationDriverV9):
|
||||
raise exception.Conflict(type=conflict_type, details=msg)
|
||||
|
||||
def delete_idp(self, idp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
self._delete_assigned_protocols(session, idp_id)
|
||||
idp_ref = self._get_idp(session, idp_id)
|
||||
session.delete(idp_ref)
|
||||
@ -206,31 +206,31 @@ class Federation(core.FederationDriverV9):
|
||||
raise exception.IdentityProviderNotFound(idp_id=remote_id)
|
||||
|
||||
def list_idps(self, hints=None):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(IdentityProviderModel)
|
||||
idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
|
||||
idps_list = [idp.to_dict() for idp in idps]
|
||||
return idps_list
|
||||
idps_list = [idp.to_dict() for idp in idps]
|
||||
return idps_list
|
||||
|
||||
def get_idp(self, idp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
idp_ref = self._get_idp(session, idp_id)
|
||||
return idp_ref.to_dict()
|
||||
return idp_ref.to_dict()
|
||||
|
||||
def get_idp_from_remote_id(self, remote_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
ref = self._get_idp_from_remote_id(session, remote_id)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def update_idp(self, idp_id, idp):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
idp_ref = self._get_idp(session, idp_id)
|
||||
old_idp = idp_ref.to_dict()
|
||||
old_idp.update(idp)
|
||||
new_idp = IdentityProviderModel.from_dict(old_idp)
|
||||
for attr in IdentityProviderModel.mutable_attributes:
|
||||
setattr(idp_ref, attr, getattr(new_idp, attr))
|
||||
return idp_ref.to_dict()
|
||||
return idp_ref.to_dict()
|
||||
|
||||
# Protocol CRUD
|
||||
def _get_protocol(self, session, idp_id, protocol_id):
|
||||
@ -247,36 +247,36 @@ class Federation(core.FederationDriverV9):
|
||||
def create_protocol(self, idp_id, protocol_id, protocol):
|
||||
protocol['id'] = protocol_id
|
||||
protocol['idp_id'] = idp_id
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
self._get_idp(session, idp_id)
|
||||
protocol_ref = FederationProtocolModel.from_dict(protocol)
|
||||
session.add(protocol_ref)
|
||||
return protocol_ref.to_dict()
|
||||
return protocol_ref.to_dict()
|
||||
|
||||
def update_protocol(self, idp_id, protocol_id, protocol):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
proto_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
old_proto = proto_ref.to_dict()
|
||||
old_proto.update(protocol)
|
||||
new_proto = FederationProtocolModel.from_dict(old_proto)
|
||||
for attr in FederationProtocolModel.mutable_attributes:
|
||||
setattr(proto_ref, attr, getattr(new_proto, attr))
|
||||
return proto_ref.to_dict()
|
||||
return proto_ref.to_dict()
|
||||
|
||||
def get_protocol(self, idp_id, protocol_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
return protocol_ref.to_dict()
|
||||
return protocol_ref.to_dict()
|
||||
|
||||
def list_protocols(self, idp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
q = session.query(FederationProtocolModel)
|
||||
q = q.filter_by(idp_id=idp_id)
|
||||
protocols = [protocol.to_dict() for protocol in q]
|
||||
return protocols
|
||||
protocols = [protocol.to_dict() for protocol in q]
|
||||
return protocols
|
||||
|
||||
def delete_protocol(self, idp_id, protocol_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
key_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
session.delete(key_ref)
|
||||
|
||||
@ -297,58 +297,58 @@ class Federation(core.FederationDriverV9):
|
||||
ref = {}
|
||||
ref['id'] = mapping_id
|
||||
ref['rules'] = mapping.get('rules')
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
mapping_ref = MappingModel.from_dict(ref)
|
||||
session.add(mapping_ref)
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
def delete_mapping(self, mapping_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
session.delete(mapping_ref)
|
||||
|
||||
def list_mappings(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
mappings = session.query(MappingModel)
|
||||
return [x.to_dict() for x in mappings]
|
||||
return [x.to_dict() for x in mappings]
|
||||
|
||||
def get_mapping(self, mapping_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='mapping')
|
||||
def update_mapping(self, mapping_id, mapping):
|
||||
ref = {}
|
||||
ref['id'] = mapping_id
|
||||
ref['rules'] = mapping.get('rules')
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
old_mapping = mapping_ref.to_dict()
|
||||
old_mapping.update(ref)
|
||||
new_mapping = MappingModel.from_dict(old_mapping)
|
||||
for attr in MappingModel.attributes:
|
||||
setattr(mapping_ref, attr, getattr(new_mapping, attr))
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
|
||||
mapping_id = protocol_ref.mapping_id
|
||||
mapping_ref = self._get_mapping(session, mapping_id)
|
||||
return mapping_ref.to_dict()
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
# Service Provider CRUD
|
||||
@sql.handle_conflicts(conflict_type='service_provider')
|
||||
def create_sp(self, sp_id, sp):
|
||||
sp['id'] = sp_id
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
sp_ref = ServiceProviderModel.from_dict(sp)
|
||||
session.add(sp_ref)
|
||||
return sp_ref.to_dict()
|
||||
return sp_ref.to_dict()
|
||||
|
||||
def delete_sp(self, sp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
sp_ref = self._get_sp(session, sp_id)
|
||||
session.delete(sp_ref)
|
||||
|
||||
@ -359,28 +359,28 @@ class Federation(core.FederationDriverV9):
|
||||
return sp_ref
|
||||
|
||||
def list_sps(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
sps = session.query(ServiceProviderModel)
|
||||
sps_list = [sp.to_dict() for sp in sps]
|
||||
return sps_list
|
||||
sps_list = [sp.to_dict() for sp in sps]
|
||||
return sps_list
|
||||
|
||||
def get_sp(self, sp_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
sp_ref = self._get_sp(session, sp_id)
|
||||
return sp_ref.to_dict()
|
||||
return sp_ref.to_dict()
|
||||
|
||||
def update_sp(self, sp_id, sp):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
sp_ref = self._get_sp(session, sp_id)
|
||||
old_sp = sp_ref.to_dict()
|
||||
old_sp.update(sp)
|
||||
new_sp = ServiceProviderModel.from_dict(old_sp)
|
||||
for attr in ServiceProviderModel.mutable_attributes:
|
||||
setattr(sp_ref, attr, getattr(new_sp, attr))
|
||||
return sp_ref.to_dict()
|
||||
return sp_ref.to_dict()
|
||||
|
||||
def get_enabled_service_providers(self):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
service_providers = session.query(ServiceProviderModel)
|
||||
service_providers = service_providers.filter_by(enabled=True)
|
||||
return service_providers
|
||||
return service_providers
|
||||
|
@ -178,33 +178,32 @@ class Identity(identity.IdentityDriverV8):
|
||||
|
||||
# Identity interface
|
||||
def authenticate(self, user_id, password):
|
||||
session = sql.get_session()
|
||||
user_ref = None
|
||||
try:
|
||||
user_ref = self._get_user(session, user_id)
|
||||
except exception.UserNotFound:
|
||||
raise AssertionError(_('Invalid user / password'))
|
||||
if not self._check_password(password, user_ref):
|
||||
raise AssertionError(_('Invalid user / password'))
|
||||
return identity.filter_user(user_ref.to_dict())
|
||||
with sql.session_for_read() as session:
|
||||
user_ref = None
|
||||
try:
|
||||
user_ref = self._get_user(session, user_id)
|
||||
except exception.UserNotFound:
|
||||
raise AssertionError(_('Invalid user / password'))
|
||||
if not self._check_password(password, user_ref):
|
||||
raise AssertionError(_('Invalid user / password'))
|
||||
return identity.filter_user(user_ref.to_dict())
|
||||
|
||||
# user crud
|
||||
|
||||
@sql.handle_conflicts(conflict_type='user')
|
||||
def create_user(self, user_id, user):
|
||||
user = utils.hash_user_password(user)
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
user_ref = User.from_dict(user)
|
||||
session.add(user_ref)
|
||||
return identity.filter_user(user_ref.to_dict())
|
||||
return identity.filter_user(user_ref.to_dict())
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_users(self, hints):
|
||||
session = sql.get_session()
|
||||
query = session.query(User).outerjoin(LocalUser)
|
||||
user_refs = sql.filter_limit_query(User, query, hints)
|
||||
return [identity.filter_user(x.to_dict()) for x in user_refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(User).outerjoin(LocalUser)
|
||||
user_refs = sql.filter_limit_query(User, query, hints)
|
||||
return [identity.filter_user(x.to_dict()) for x in user_refs]
|
||||
|
||||
def _get_user(self, session, user_id):
|
||||
user_ref = session.query(User).get(user_id)
|
||||
@ -213,25 +212,24 @@ class Identity(identity.IdentityDriverV8):
|
||||
return user_ref
|
||||
|
||||
def get_user(self, user_id):
|
||||
session = sql.get_session()
|
||||
return identity.filter_user(self._get_user(session, user_id).to_dict())
|
||||
with sql.session_for_read() as session:
|
||||
return identity.filter_user(
|
||||
self._get_user(session, user_id).to_dict())
|
||||
|
||||
def get_user_by_name(self, user_name, domain_id):
|
||||
session = sql.get_session()
|
||||
query = session.query(User).join(LocalUser)
|
||||
query = query.filter(and_(LocalUser.name == user_name,
|
||||
LocalUser.domain_id == domain_id))
|
||||
try:
|
||||
user_ref = query.one()
|
||||
except sql.NotFound:
|
||||
raise exception.UserNotFound(user_id=user_name)
|
||||
return identity.filter_user(user_ref.to_dict())
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(User).join(LocalUser)
|
||||
query = query.filter(and_(LocalUser.name == user_name,
|
||||
LocalUser.domain_id == domain_id))
|
||||
try:
|
||||
user_ref = query.one()
|
||||
except sql.NotFound:
|
||||
raise exception.UserNotFound(user_id=user_name)
|
||||
return identity.filter_user(user_ref.to_dict())
|
||||
|
||||
@sql.handle_conflicts(conflict_type='user')
|
||||
def update_user(self, user_id, user):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
user_ref = self._get_user(session, user_id)
|
||||
old_user_dict = user_ref.to_dict()
|
||||
user = utils.hash_user_password(user)
|
||||
@ -242,77 +240,74 @@ class Identity(identity.IdentityDriverV8):
|
||||
if attr != 'id':
|
||||
setattr(user_ref, attr, getattr(new_user, attr))
|
||||
user_ref.extra = new_user.extra
|
||||
return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
|
||||
return identity.filter_user(
|
||||
user_ref.to_dict(include_extra_dict=True))
|
||||
|
||||
def add_user_to_group(self, user_id, group_id):
|
||||
session = sql.get_session()
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
query = session.query(UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
rv = query.first()
|
||||
if rv:
|
||||
return
|
||||
with sql.session_for_write() as session:
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
query = session.query(UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
rv = query.first()
|
||||
if rv:
|
||||
return
|
||||
|
||||
with session.begin():
|
||||
session.add(UserGroupMembership(user_id=user_id,
|
||||
group_id=group_id))
|
||||
|
||||
def check_user_in_group(self, user_id, group_id):
|
||||
session = sql.get_session()
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
query = session.query(UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
if not query.first():
|
||||
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
||||
" group '%(group_id)s'") %
|
||||
{'user_id': user_id,
|
||||
'group_id': group_id})
|
||||
|
||||
def remove_user_from_group(self, user_id, group_id):
|
||||
session = sql.get_session()
|
||||
# We don't check if user or group are still valid and let the remove
|
||||
# be tried anyway - in case this is some kind of clean-up operation
|
||||
query = session.query(UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
membership_ref = query.first()
|
||||
if membership_ref is None:
|
||||
# Check if the group and user exist to return descriptive
|
||||
# exceptions.
|
||||
with sql.session_for_read() as session:
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
||||
" group '%(group_id)s'") %
|
||||
{'user_id': user_id,
|
||||
'group_id': group_id})
|
||||
with session.begin():
|
||||
query = session.query(UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
if not query.first():
|
||||
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
||||
" group '%(group_id)s'") %
|
||||
{'user_id': user_id,
|
||||
'group_id': group_id})
|
||||
|
||||
def remove_user_from_group(self, user_id, group_id):
|
||||
# We don't check if user or group are still valid and let the remove
|
||||
# be tried anyway - in case this is some kind of clean-up operation
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(UserGroupMembership)
|
||||
query = query.filter_by(user_id=user_id)
|
||||
query = query.filter_by(group_id=group_id)
|
||||
membership_ref = query.first()
|
||||
if membership_ref is None:
|
||||
# Check if the group and user exist to return descriptive
|
||||
# exceptions.
|
||||
self.get_group(group_id)
|
||||
self.get_user(user_id)
|
||||
raise exception.NotFound(_("User '%(user_id)s' not found in"
|
||||
" group '%(group_id)s'") %
|
||||
{'user_id': user_id,
|
||||
'group_id': group_id})
|
||||
session.delete(membership_ref)
|
||||
|
||||
def list_groups_for_user(self, user_id, hints):
|
||||
session = sql.get_session()
|
||||
self.get_user(user_id)
|
||||
query = session.query(Group).join(UserGroupMembership)
|
||||
query = query.filter(UserGroupMembership.user_id == user_id)
|
||||
query = sql.filter_limit_query(Group, query, hints)
|
||||
return [g.to_dict() for g in query]
|
||||
with sql.session_for_read() as session:
|
||||
self.get_user(user_id)
|
||||
query = session.query(Group).join(UserGroupMembership)
|
||||
query = query.filter(UserGroupMembership.user_id == user_id)
|
||||
query = sql.filter_limit_query(Group, query, hints)
|
||||
return [g.to_dict() for g in query]
|
||||
|
||||
def list_users_in_group(self, group_id, hints):
|
||||
session = sql.get_session()
|
||||
self.get_group(group_id)
|
||||
query = session.query(User).outerjoin(LocalUser)
|
||||
query = query.join(UserGroupMembership)
|
||||
query = query.filter(UserGroupMembership.group_id == group_id)
|
||||
query = sql.filter_limit_query(User, query, hints)
|
||||
return [identity.filter_user(u.to_dict()) for u in query]
|
||||
with sql.session_for_read() as session:
|
||||
self.get_group(group_id)
|
||||
query = session.query(User).outerjoin(LocalUser)
|
||||
query = query.join(UserGroupMembership)
|
||||
query = query.filter(UserGroupMembership.group_id == group_id)
|
||||
query = sql.filter_limit_query(User, query, hints)
|
||||
return [identity.filter_user(u.to_dict()) for u in query]
|
||||
|
||||
def delete_user(self, user_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_user(session, user_id)
|
||||
|
||||
q = session.query(UserGroupMembership)
|
||||
@ -325,18 +320,17 @@ class Identity(identity.IdentityDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='group')
|
||||
def create_group(self, group_id, group):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = Group.from_dict(group)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_groups(self, hints):
|
||||
session = sql.get_session()
|
||||
query = session.query(Group)
|
||||
refs = sql.filter_limit_query(Group, query, hints)
|
||||
return [ref.to_dict() for ref in refs]
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Group)
|
||||
refs = sql.filter_limit_query(Group, query, hints)
|
||||
return [ref.to_dict() for ref in refs]
|
||||
|
||||
def _get_group(self, session, group_id):
|
||||
ref = session.query(Group).get(group_id)
|
||||
@ -345,25 +339,23 @@ class Identity(identity.IdentityDriverV8):
|
||||
return ref
|
||||
|
||||
def get_group(self, group_id):
|
||||
session = sql.get_session()
|
||||
return self._get_group(session, group_id).to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_group(session, group_id).to_dict()
|
||||
|
||||
def get_group_by_name(self, group_name, domain_id):
|
||||
session = sql.get_session()
|
||||
query = session.query(Group)
|
||||
query = query.filter_by(name=group_name)
|
||||
query = query.filter_by(domain_id=domain_id)
|
||||
try:
|
||||
group_ref = query.one()
|
||||
except sql.NotFound:
|
||||
raise exception.GroupNotFound(group_id=group_name)
|
||||
return group_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Group)
|
||||
query = query.filter_by(name=group_name)
|
||||
query = query.filter_by(domain_id=domain_id)
|
||||
try:
|
||||
group_ref = query.one()
|
||||
except sql.NotFound:
|
||||
raise exception.GroupNotFound(group_id=group_name)
|
||||
return group_ref.to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='group')
|
||||
def update_group(self, group_id, group):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_group(session, group_id)
|
||||
old_dict = ref.to_dict()
|
||||
for k in group:
|
||||
@ -373,12 +365,10 @@ class Identity(identity.IdentityDriverV8):
|
||||
if attr != 'id':
|
||||
setattr(ref, attr, getattr(new_group, attr))
|
||||
ref.extra = new_group.extra
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_group(self, group_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_group(session, group_id)
|
||||
|
||||
q = session.query(UserGroupMembership)
|
||||
|
@ -45,27 +45,27 @@ class Mapping(identity.MappingDriverV8):
|
||||
# work if we hashed all the entries, even those that already generate
|
||||
# UUIDs, like SQL. Further, this would only work if the generation
|
||||
# algorithm was immutable (e.g. it had always been sha256).
|
||||
session = sql.get_session()
|
||||
query = session.query(IDMapping.public_id)
|
||||
query = query.filter_by(domain_id=local_entity['domain_id'])
|
||||
query = query.filter_by(local_id=local_entity['local_id'])
|
||||
query = query.filter_by(entity_type=local_entity['entity_type'])
|
||||
try:
|
||||
public_ref = query.one()
|
||||
public_id = public_ref.public_id
|
||||
return public_id
|
||||
except sql.NotFound:
|
||||
return None
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(IDMapping.public_id)
|
||||
query = query.filter_by(domain_id=local_entity['domain_id'])
|
||||
query = query.filter_by(local_id=local_entity['local_id'])
|
||||
query = query.filter_by(entity_type=local_entity['entity_type'])
|
||||
try:
|
||||
public_ref = query.one()
|
||||
public_id = public_ref.public_id
|
||||
return public_id
|
||||
except sql.NotFound:
|
||||
return None
|
||||
|
||||
def get_id_mapping(self, public_id):
|
||||
session = sql.get_session()
|
||||
mapping_ref = session.query(IDMapping).get(public_id)
|
||||
if mapping_ref:
|
||||
return mapping_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
mapping_ref = session.query(IDMapping).get(public_id)
|
||||
if mapping_ref:
|
||||
return mapping_ref.to_dict()
|
||||
|
||||
def create_id_mapping(self, local_entity, public_id=None):
|
||||
entity = local_entity.copy()
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
if public_id is None:
|
||||
public_id = self.id_generator_api.generate_public_ID(entity)
|
||||
entity['public_id'] = public_id
|
||||
@ -74,7 +74,7 @@ class Mapping(identity.MappingDriverV8):
|
||||
return public_id
|
||||
|
||||
def delete_id_mapping(self, public_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
try:
|
||||
session.query(IDMapping).filter(
|
||||
IDMapping.public_id == public_id).delete()
|
||||
@ -84,14 +84,15 @@ class Mapping(identity.MappingDriverV8):
|
||||
pass
|
||||
|
||||
def purge_mappings(self, purge_filter):
|
||||
session = sql.get_session()
|
||||
query = session.query(IDMapping)
|
||||
if 'domain_id' in purge_filter:
|
||||
query = query.filter_by(domain_id=purge_filter['domain_id'])
|
||||
if 'public_id' in purge_filter:
|
||||
query = query.filter_by(public_id=purge_filter['public_id'])
|
||||
if 'local_id' in purge_filter:
|
||||
query = query.filter_by(local_id=purge_filter['local_id'])
|
||||
if 'entity_type' in purge_filter:
|
||||
query = query.filter_by(entity_type=purge_filter['entity_type'])
|
||||
query.delete()
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(IDMapping)
|
||||
if 'domain_id' in purge_filter:
|
||||
query = query.filter_by(domain_id=purge_filter['domain_id'])
|
||||
if 'public_id' in purge_filter:
|
||||
query = query.filter_by(public_id=purge_filter['public_id'])
|
||||
if 'local_id' in purge_filter:
|
||||
query = query.filter_by(local_id=purge_filter['local_id'])
|
||||
if 'entity_type' in purge_filter:
|
||||
query = query.filter_by(
|
||||
entity_type=purge_filter['entity_type'])
|
||||
query.delete()
|
||||
|
@ -92,17 +92,16 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
return consumer_ref
|
||||
|
||||
def get_consumer_with_secret(self, consumer_id):
|
||||
session = sql.get_session()
|
||||
consumer_ref = self._get_consumer(session, consumer_id)
|
||||
return consumer_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
consumer_ref = self._get_consumer(session, consumer_id)
|
||||
return consumer_ref.to_dict()
|
||||
|
||||
def get_consumer(self, consumer_id):
|
||||
return core.filter_consumer(
|
||||
self.get_consumer_with_secret(consumer_id))
|
||||
|
||||
def create_consumer(self, consumer_ref):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
consumer = Consumer.from_dict(consumer_ref)
|
||||
session.add(consumer)
|
||||
return consumer.to_dict()
|
||||
@ -128,20 +127,18 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
session.delete(token_ref)
|
||||
|
||||
def delete_consumer(self, consumer_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
self._delete_request_tokens(session, consumer_id)
|
||||
self._delete_access_tokens(session, consumer_id)
|
||||
self._delete_consumer(session, consumer_id)
|
||||
|
||||
def list_consumers(self):
|
||||
session = sql.get_session()
|
||||
cons = session.query(Consumer)
|
||||
return [core.filter_consumer(x.to_dict()) for x in cons]
|
||||
with sql.session_for_read() as session:
|
||||
cons = session.query(Consumer)
|
||||
return [core.filter_consumer(x.to_dict()) for x in cons]
|
||||
|
||||
def update_consumer(self, consumer_id, consumer_ref):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
consumer = self._get_consumer(session, consumer_id)
|
||||
old_consumer_dict = consumer.to_dict()
|
||||
old_consumer_dict.update(consumer_ref)
|
||||
@ -169,11 +166,10 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
ref['role_ids'] = None
|
||||
ref['consumer_id'] = consumer_id
|
||||
ref['expires_at'] = expiry_date
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
token_ref = RequestToken.from_dict(ref)
|
||||
session.add(token_ref)
|
||||
return token_ref.to_dict()
|
||||
return token_ref.to_dict()
|
||||
|
||||
def _get_request_token(self, session, request_token_id):
|
||||
token_ref = session.query(RequestToken).get(request_token_id)
|
||||
@ -182,14 +178,13 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
return token_ref
|
||||
|
||||
def get_request_token(self, request_token_id):
|
||||
session = sql.get_session()
|
||||
token_ref = self._get_request_token(session, request_token_id)
|
||||
return token_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
token_ref = self._get_request_token(session, request_token_id)
|
||||
return token_ref.to_dict()
|
||||
|
||||
def authorize_request_token(self, request_token_id, user_id,
|
||||
role_ids):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
token_ref = self._get_request_token(session, request_token_id)
|
||||
token_dict = token_ref.to_dict()
|
||||
token_dict['authorizing_user_id'] = user_id
|
||||
@ -203,13 +198,12 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
or attr == 'role_ids'):
|
||||
setattr(token_ref, attr, getattr(new_token, attr))
|
||||
|
||||
return token_ref.to_dict()
|
||||
return token_ref.to_dict()
|
||||
|
||||
def create_access_token(self, request_id, access_token_duration):
|
||||
access_token_id = uuid.uuid4().hex
|
||||
access_token_secret = uuid.uuid4().hex
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
req_token_ref = self._get_request_token(session, request_id)
|
||||
token_dict = req_token_ref.to_dict()
|
||||
|
||||
@ -235,7 +229,7 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
# remove request token, it's been used
|
||||
session.delete(req_token_ref)
|
||||
|
||||
return token_ref.to_dict()
|
||||
return token_ref.to_dict()
|
||||
|
||||
def _get_access_token(self, session, access_token_id):
|
||||
token_ref = session.query(AccessToken).get(access_token_id)
|
||||
@ -244,19 +238,18 @@ class OAuth1(core.Oauth1DriverV8):
|
||||
return token_ref
|
||||
|
||||
def get_access_token(self, access_token_id):
|
||||
session = sql.get_session()
|
||||
token_ref = self._get_access_token(session, access_token_id)
|
||||
return token_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
token_ref = self._get_access_token(session, access_token_id)
|
||||
return token_ref.to_dict()
|
||||
|
||||
def list_access_tokens(self, user_id):
|
||||
session = sql.get_session()
|
||||
q = session.query(AccessToken)
|
||||
user_auths = q.filter_by(authorizing_user_id=user_id)
|
||||
return [core.filter_token(x.to_dict()) for x in user_auths]
|
||||
with sql.session_for_read() as session:
|
||||
q = session.query(AccessToken)
|
||||
user_auths = q.filter_by(authorizing_user_id=user_id)
|
||||
return [core.filter_token(x.to_dict()) for x in user_auths]
|
||||
|
||||
def delete_access_token(self, user_id, access_token_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
token_ref = self._get_access_token(session, access_token_id)
|
||||
token_dict = token_ref.to_dict()
|
||||
if token_dict['authorizing_user_id'] != user_id:
|
||||
|
@ -30,19 +30,16 @@ class Policy(rules.Policy):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='policy')
|
||||
def create_policy(self, policy_id, policy):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = PolicyModel.from_dict(policy)
|
||||
session.add(ref)
|
||||
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def list_policies(self):
|
||||
session = sql.get_session()
|
||||
|
||||
refs = session.query(PolicyModel).all()
|
||||
return [ref.to_dict() for ref in refs]
|
||||
with sql.session_for_read() as session:
|
||||
refs = session.query(PolicyModel).all()
|
||||
return [ref.to_dict() for ref in refs]
|
||||
|
||||
def _get_policy(self, session, policy_id):
|
||||
"""Private method to get a policy model object (NOT a dictionary)."""
|
||||
@ -52,15 +49,12 @@ class Policy(rules.Policy):
|
||||
return ref
|
||||
|
||||
def get_policy(self, policy_id):
|
||||
session = sql.get_session()
|
||||
|
||||
return self._get_policy(session, policy_id).to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_policy(session, policy_id).to_dict()
|
||||
|
||||
@sql.handle_conflicts(conflict_type='policy')
|
||||
def update_policy(self, policy_id, policy):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_policy(session, policy_id)
|
||||
old_dict = ref.to_dict()
|
||||
old_dict.update(policy)
|
||||
@ -72,8 +66,6 @@ class Policy(rules.Policy):
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_policy(self, policy_id):
|
||||
session = sql.get_session()
|
||||
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_policy(session, policy_id)
|
||||
session.delete(ref)
|
||||
|
@ -35,11 +35,11 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
return project_ref
|
||||
|
||||
def get_project(self, tenant_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_project(session, tenant_id).to_dict()
|
||||
|
||||
def get_project_by_name(self, tenant_name, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project)
|
||||
query = query.filter_by(name=tenant_name)
|
||||
query = query.filter_by(domain_id=domain_id)
|
||||
@ -51,7 +51,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_projects(self, hints):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project)
|
||||
project_refs = sql.filter_limit_query(Project, query, hints)
|
||||
return [project_ref.to_dict() for project_ref in project_refs]
|
||||
@ -60,7 +60,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
if not ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project)
|
||||
query = query.filter(Project.id.in_(ids))
|
||||
return [project_ref.to_dict() for project_ref in query.all()]
|
||||
@ -69,14 +69,14 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
if not domain_ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project.id)
|
||||
query = (
|
||||
query.filter(Project.domain_id.in_(domain_ids)))
|
||||
return [x.id for x in query.all()]
|
||||
|
||||
def list_projects_in_domain(self, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
self._get_domain(session, domain_id)
|
||||
query = session.query(Project)
|
||||
project_refs = query.filter_by(domain_id=domain_id)
|
||||
@ -89,7 +89,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
return [project_ref.to_dict() for project_ref in project_refs]
|
||||
|
||||
def list_projects_in_subtree(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
children = self._get_children(session, [project_id])
|
||||
subtree = []
|
||||
examined = set([project_id])
|
||||
@ -110,7 +110,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
return subtree
|
||||
|
||||
def list_project_parents(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
project = self._get_project(session, project_id).to_dict()
|
||||
parents = []
|
||||
examined = set()
|
||||
@ -130,7 +130,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
return parents
|
||||
|
||||
def is_leaf_project(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
project_refs = self._get_children(session, [project_id])
|
||||
return not project_refs
|
||||
|
||||
@ -138,7 +138,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
@sql.handle_conflicts(conflict_type='project')
|
||||
def create_project(self, tenant_id, tenant):
|
||||
tenant['name'] = clean.project_name(tenant['name'])
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
tenant_ref = Project.from_dict(tenant)
|
||||
session.add(tenant_ref)
|
||||
return tenant_ref.to_dict()
|
||||
@ -148,7 +148,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
if 'name' in tenant:
|
||||
tenant['name'] = clean.project_name(tenant['name'])
|
||||
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
tenant_ref = self._get_project(session, tenant_id)
|
||||
old_project_dict = tenant_ref.to_dict()
|
||||
for k in tenant:
|
||||
@ -162,7 +162,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='project')
|
||||
def delete_project(self, tenant_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
tenant_ref = self._get_project(session, tenant_id)
|
||||
session.delete(tenant_ref)
|
||||
|
||||
@ -170,14 +170,14 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='domain')
|
||||
def create_domain(self, domain_id, domain):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = Domain.from_dict(domain)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_domains(self, hints):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Domain)
|
||||
refs = sql.filter_limit_query(Domain, query, hints)
|
||||
return [ref.to_dict() for ref in refs]
|
||||
@ -186,7 +186,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
if not ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Domain)
|
||||
query = query.filter(Domain.id.in_(ids))
|
||||
domain_refs = query.all()
|
||||
@ -199,11 +199,11 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
return ref
|
||||
|
||||
def get_domain(self, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_domain(session, domain_id).to_dict()
|
||||
|
||||
def get_domain_by_name(self, domain_name):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
try:
|
||||
ref = (session.query(Domain).
|
||||
filter_by(name=domain_name).one())
|
||||
@ -213,7 +213,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='domain')
|
||||
def update_domain(self, domain_id, domain):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_domain(session, domain_id)
|
||||
old_dict = ref.to_dict()
|
||||
for k in domain:
|
||||
@ -226,7 +226,7 @@ class Resource(keystone_resource.ResourceDriverV8):
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_domain(self, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_domain(session, domain_id)
|
||||
session.delete(ref)
|
||||
|
||||
|
@ -38,11 +38,11 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
return project_ref
|
||||
|
||||
def get_project(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_project(session, project_id).to_dict()
|
||||
|
||||
def get_project_by_name(self, project_name, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project)
|
||||
query = query.filter_by(name=project_name)
|
||||
if domain_id is None:
|
||||
@ -70,7 +70,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
for f in hints.filters:
|
||||
if (f['name'] == 'domain_id' and f['value'] is None):
|
||||
f['value'] = keystone_resource.NULL_DOMAIN_ID
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project)
|
||||
project_refs = sql.filter_limit_query(Project, query, hints)
|
||||
return [project_ref.to_dict() for project_ref in project_refs
|
||||
@ -80,7 +80,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
if not ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project)
|
||||
query = query.filter(Project.id.in_(ids))
|
||||
return [project_ref.to_dict() for project_ref in query.all()
|
||||
@ -90,7 +90,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
if not domain_ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Project.id)
|
||||
query = (
|
||||
query.filter(Project.domain_id.in_(domain_ids)))
|
||||
@ -98,7 +98,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
if not self._is_hidden_ref(x)]
|
||||
|
||||
def list_projects_in_domain(self, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
self._get_domain(session, domain_id)
|
||||
query = session.query(Project)
|
||||
project_refs = query.filter_by(domain_id=domain_id)
|
||||
@ -111,7 +111,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
return [project_ref.to_dict() for project_ref in project_refs]
|
||||
|
||||
def list_projects_in_subtree(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
children = self._get_children(session, [project_id])
|
||||
subtree = []
|
||||
examined = set([project_id])
|
||||
@ -132,7 +132,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
return subtree
|
||||
|
||||
def list_project_parents(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
project = self._get_project(session, project_id).to_dict()
|
||||
parents = []
|
||||
examined = set()
|
||||
@ -152,7 +152,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
return parents
|
||||
|
||||
def is_leaf_project(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
project_refs = self._get_children(session, [project_id])
|
||||
return not project_refs
|
||||
|
||||
@ -161,7 +161,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
def create_project(self, project_id, project):
|
||||
project['name'] = clean.project_name(project['name'])
|
||||
new_project = self._encode_domain_id(project)
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
project_ref = Project.from_dict(new_project)
|
||||
session.add(project_ref)
|
||||
return project_ref.to_dict()
|
||||
@ -172,7 +172,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
project['name'] = clean.project_name(project['name'])
|
||||
|
||||
update_project = self._encode_domain_id(project)
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
project_ref = self._get_project(session, project_id)
|
||||
old_project_dict = project_ref.to_dict()
|
||||
for k in update_project:
|
||||
@ -189,7 +189,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='project')
|
||||
def delete_project(self, project_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
project_ref = self._get_project(session, project_id)
|
||||
session.delete(project_ref)
|
||||
|
||||
@ -197,7 +197,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
def delete_projects_from_ids(self, project_ids):
|
||||
if not project_ids:
|
||||
return
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(Project).filter(Project.id.in_(
|
||||
project_ids))
|
||||
project_ids_from_bd = [p['id'] for p in query.all()]
|
||||
@ -212,14 +212,14 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='domain')
|
||||
def create_domain(self, domain_id, domain):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = Domain.from_dict(domain)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
@driver_hints.truncated
|
||||
def list_domains(self, hints):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Domain)
|
||||
refs = sql.filter_limit_query(Domain, query, hints)
|
||||
return [ref.to_dict() for ref in refs
|
||||
@ -229,7 +229,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
if not ids:
|
||||
return []
|
||||
else:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(Domain)
|
||||
query = query.filter(Domain.id.in_(ids))
|
||||
domain_refs = query.all()
|
||||
@ -243,11 +243,11 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
return ref
|
||||
|
||||
def get_domain(self, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
return self._get_domain(session, domain_id).to_dict()
|
||||
|
||||
def get_domain_by_name(self, domain_name):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
try:
|
||||
ref = (session.query(Domain).
|
||||
filter_by(name=domain_name).one())
|
||||
@ -260,7 +260,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
|
||||
@sql.handle_conflicts(conflict_type='domain')
|
||||
def update_domain(self, domain_id, domain):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_domain(session, domain_id)
|
||||
old_dict = ref.to_dict()
|
||||
for k in domain:
|
||||
@ -273,7 +273,7 @@ class Resource(keystone_resource.ResourceDriverV9):
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_domain(self, domain_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_domain(session, domain_id)
|
||||
session.delete(ref)
|
||||
|
||||
|
@ -59,12 +59,12 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
||||
@sql.handle_conflicts(conflict_type='domain_config')
|
||||
def create_config_option(self, domain_id, group, option, value,
|
||||
sensitive=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
config_table = self.choose_table(sensitive)
|
||||
ref = config_table(domain_id=domain_id, group=group,
|
||||
option=option, value=value)
|
||||
session.add(ref)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def _get_config_option(self, session, domain_id, group, option, sensitive):
|
||||
try:
|
||||
@ -80,14 +80,14 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
||||
return ref
|
||||
|
||||
def get_config_option(self, domain_id, group, option, sensitive=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
ref = self._get_config_option(session, domain_id, group, option,
|
||||
sensitive)
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def list_config_options(self, domain_id, group=None, option=None,
|
||||
sensitive=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
config_table = self.choose_table(sensitive)
|
||||
query = session.query(config_table)
|
||||
query = query.filter_by(domain_id=domain_id)
|
||||
@ -99,11 +99,11 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
||||
|
||||
def update_config_option(self, domain_id, group, option, value,
|
||||
sensitive=False):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = self._get_config_option(session, domain_id, group, option,
|
||||
sensitive)
|
||||
ref.value = value
|
||||
return ref.to_dict()
|
||||
return ref.to_dict()
|
||||
|
||||
def delete_config_options(self, domain_id, group=None, option=None,
|
||||
sensitive=False):
|
||||
@ -114,7 +114,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
||||
if there was nothing to delete.
|
||||
|
||||
"""
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
config_table = self.choose_table(sensitive)
|
||||
query = session.query(config_table)
|
||||
query = query.filter_by(domain_id=domain_id)
|
||||
@ -126,7 +126,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
||||
|
||||
def obtain_registration(self, domain_id, type):
|
||||
try:
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = ConfigRegister(type=type, domain_id=domain_id)
|
||||
session.add(ref)
|
||||
return True
|
||||
@ -136,15 +136,15 @@ class DomainConfig(resource.DomainConfigDriverV8):
|
||||
return False
|
||||
|
||||
def read_registration(self, type):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_read() as session:
|
||||
ref = session.query(ConfigRegister).get(type)
|
||||
if not ref:
|
||||
raise exception.ConfigRegistrationNotFound()
|
||||
return ref.domain_id
|
||||
return ref.domain_id
|
||||
|
||||
def release_registration(self, domain_id, type=None):
|
||||
"""Silently delete anything registered for the domain specified."""
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
query = session.query(ConfigRegister)
|
||||
if type:
|
||||
query = query.filter_by(type=type)
|
||||
|
@ -60,37 +60,37 @@ class Revoke(revoke.RevokeDriverV8):
|
||||
def _prune_expired_events(self):
|
||||
oldest = revoke.revoked_before_cutoff_time()
|
||||
|
||||
session = sql.get_session()
|
||||
dialect = session.bind.dialect.name
|
||||
batch_size = self._flush_batch_size(dialect)
|
||||
if batch_size > 0:
|
||||
query = session.query(RevocationEvent.id)
|
||||
query = query.filter(RevocationEvent.revoked_at < oldest)
|
||||
query = query.limit(batch_size).subquery()
|
||||
delete_query = (session.query(RevocationEvent).
|
||||
filter(RevocationEvent.id.in_(query)))
|
||||
while True:
|
||||
rowcount = delete_query.delete(synchronize_session=False)
|
||||
if rowcount == 0:
|
||||
break
|
||||
else:
|
||||
query = session.query(RevocationEvent)
|
||||
query = query.filter(RevocationEvent.revoked_at < oldest)
|
||||
query.delete(synchronize_session=False)
|
||||
with sql.session_for_write() as session:
|
||||
dialect = session.bind.dialect.name
|
||||
batch_size = self._flush_batch_size(dialect)
|
||||
if batch_size > 0:
|
||||
query = session.query(RevocationEvent.id)
|
||||
query = query.filter(RevocationEvent.revoked_at < oldest)
|
||||
query = query.limit(batch_size).subquery()
|
||||
delete_query = (session.query(RevocationEvent).
|
||||
filter(RevocationEvent.id.in_(query)))
|
||||
while True:
|
||||
rowcount = delete_query.delete(synchronize_session=False)
|
||||
if rowcount == 0:
|
||||
break
|
||||
else:
|
||||
query = session.query(RevocationEvent)
|
||||
query = query.filter(RevocationEvent.revoked_at < oldest)
|
||||
query.delete(synchronize_session=False)
|
||||
|
||||
session.flush()
|
||||
session.flush()
|
||||
|
||||
def list_events(self, last_fetch=None):
|
||||
session = sql.get_session()
|
||||
query = session.query(RevocationEvent).order_by(
|
||||
RevocationEvent.revoked_at)
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(RevocationEvent).order_by(
|
||||
RevocationEvent.revoked_at)
|
||||
|
||||
if last_fetch:
|
||||
query = query.filter(RevocationEvent.revoked_at > last_fetch)
|
||||
if last_fetch:
|
||||
query = query.filter(RevocationEvent.revoked_at > last_fetch)
|
||||
|
||||
events = [model.RevokeEvent(**e.to_dict()) for e in query]
|
||||
events = [model.RevokeEvent(**e.to_dict()) for e in query]
|
||||
|
||||
return events
|
||||
return events
|
||||
|
||||
def revoke(self, event):
|
||||
kwargs = dict()
|
||||
@ -98,7 +98,6 @@ class Revoke(revoke.RevokeDriverV8):
|
||||
kwargs[attr] = getattr(event, attr)
|
||||
kwargs['id'] = uuid.uuid4().hex
|
||||
record = RevocationEvent(**kwargs)
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
session.add(record)
|
||||
self._prune_expired_events()
|
||||
self._prune_expired_events()
|
||||
|
@ -17,6 +17,6 @@ from keystone.identity.mapping_backends import sql as mapping_sql
|
||||
|
||||
def list_id_mappings():
|
||||
"""List all id_mappings for testing purposes."""
|
||||
a_session = sql.get_session()
|
||||
refs = a_session.query(mapping_sql.IDMapping).all()
|
||||
return [x.to_dict() for x in refs]
|
||||
with sql.session_for_read() as session:
|
||||
refs = session.query(mapping_sql.IDMapping).all()
|
||||
return [x.to_dict() for x in refs]
|
||||
|
@ -611,7 +611,7 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
||||
tok = token_sql.Token()
|
||||
tok.list_revoked_tokens()
|
||||
|
||||
mock_query = mock_sql.get_session().query
|
||||
mock_query = mock_sql.session_for_read().__enter__().query
|
||||
mock_query.assert_called_with(*expected_query_args)
|
||||
|
||||
def test_flush_expired_tokens_batch(self):
|
||||
@ -636,8 +636,12 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
||||
# other tests below test the differences between how they use the batch
|
||||
# strategy
|
||||
with mock.patch.object(token_sql, 'sql') as mock_sql:
|
||||
mock_sql.get_session().query().filter().delete.return_value = 0
|
||||
mock_sql.get_session().bind.dialect.name = 'mysql'
|
||||
mock_sql.session_for_write().__enter__(
|
||||
).query().filter().delete.return_value = 0
|
||||
|
||||
mock_sql.session_for_write().__enter__(
|
||||
).bind.dialect.name = 'mysql'
|
||||
|
||||
tok = token_sql.Token()
|
||||
expiry_mock = mock.Mock()
|
||||
ITERS = [1, 2, 3]
|
||||
@ -648,7 +652,10 @@ class SqlToken(SqlTests, test_backend.TokenTests):
|
||||
# The expiry strategy is only invoked once, the other calls are via
|
||||
# the yield return.
|
||||
self.assertEqual(1, expiry_mock.call_count)
|
||||
mock_delete = mock_sql.get_session().query().filter().delete
|
||||
|
||||
mock_delete = mock_sql.session_for_write().__enter__(
|
||||
).query().filter().delete
|
||||
|
||||
self.assertThat(mock_delete.call_args_list,
|
||||
matchers.HasLength(len(ITERS)))
|
||||
|
||||
|
@ -161,7 +161,8 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase):
|
||||
self.repo_package())
|
||||
self.schema = versioning_api.ControlledSchema.create(
|
||||
self.engine,
|
||||
self.repo_path, self.initial_db_version)
|
||||
self.repo_path,
|
||||
self.initial_db_version)
|
||||
|
||||
# auto-detect the highest available schema version in the migrate_repo
|
||||
self.max_version = self.schema.repository.version().version
|
||||
|
@ -86,11 +86,11 @@ class Token(token.persistence.TokenDriverV8):
|
||||
def get_token(self, token_id):
|
||||
if token_id is None:
|
||||
raise exception.TokenNotFound(token_id=token_id)
|
||||
session = sql.get_session()
|
||||
token_ref = session.query(TokenModel).get(token_id)
|
||||
if not token_ref or not token_ref.valid:
|
||||
raise exception.TokenNotFound(token_id=token_id)
|
||||
return token_ref.to_dict()
|
||||
with sql.session_for_read() as session:
|
||||
token_ref = session.query(TokenModel).get(token_id)
|
||||
if not token_ref or not token_ref.valid:
|
||||
raise exception.TokenNotFound(token_id=token_id)
|
||||
return token_ref.to_dict()
|
||||
|
||||
def create_token(self, token_id, data):
|
||||
data_copy = copy.deepcopy(data)
|
||||
@ -101,14 +101,12 @@ class Token(token.persistence.TokenDriverV8):
|
||||
|
||||
token_ref = TokenModel.from_dict(data_copy)
|
||||
token_ref.valid = True
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
session.add(token_ref)
|
||||
return token_ref.to_dict()
|
||||
|
||||
def delete_token(self, token_id):
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
token_ref = session.query(TokenModel).get(token_id)
|
||||
if not token_ref or not token_ref.valid:
|
||||
raise exception.TokenNotFound(token_id=token_id)
|
||||
@ -124,9 +122,8 @@ class Token(token.persistence.TokenDriverV8):
|
||||
or the trustor's user ID, so will use trust_id to query the tokens.
|
||||
|
||||
"""
|
||||
session = sql.get_session()
|
||||
token_list = []
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter_by(valid=True)
|
||||
@ -167,38 +164,37 @@ class Token(token.persistence.TokenDriverV8):
|
||||
return False
|
||||
|
||||
def _list_tokens_for_trust(self, trust_id):
|
||||
session = sql.get_session()
|
||||
tokens = []
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
query = query.filter(TokenModel.trust_id == trust_id)
|
||||
with sql.session_for_read() as session:
|
||||
tokens = []
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
query = query.filter(TokenModel.trust_id == trust_id)
|
||||
|
||||
token_references = query.filter_by(valid=True)
|
||||
for token_ref in token_references:
|
||||
token_ref_dict = token_ref.to_dict()
|
||||
tokens.append(token_ref_dict['id'])
|
||||
return tokens
|
||||
token_references = query.filter_by(valid=True)
|
||||
for token_ref in token_references:
|
||||
token_ref_dict = token_ref.to_dict()
|
||||
tokens.append(token_ref_dict['id'])
|
||||
return tokens
|
||||
|
||||
def _list_tokens_for_user(self, user_id, tenant_id=None):
|
||||
session = sql.get_session()
|
||||
tokens = []
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
query = query.filter(TokenModel.user_id == user_id)
|
||||
with sql.session_for_read() as session:
|
||||
tokens = []
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
query = query.filter(TokenModel.user_id == user_id)
|
||||
|
||||
token_references = query.filter_by(valid=True)
|
||||
for token_ref in token_references:
|
||||
token_ref_dict = token_ref.to_dict()
|
||||
if self._tenant_matches(tenant_id, token_ref_dict):
|
||||
tokens.append(token_ref['id'])
|
||||
return tokens
|
||||
token_references = query.filter_by(valid=True)
|
||||
for token_ref in token_references:
|
||||
token_ref_dict = token_ref.to_dict()
|
||||
if self._tenant_matches(tenant_id, token_ref_dict):
|
||||
tokens.append(token_ref['id'])
|
||||
return tokens
|
||||
|
||||
def _list_tokens_for_consumer(self, user_id, consumer_id):
|
||||
tokens = []
|
||||
session = sql.get_session()
|
||||
with session.begin():
|
||||
with sql.session_for_write() as session:
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
@ -223,29 +219,29 @@ class Token(token.persistence.TokenDriverV8):
|
||||
return self._list_tokens_for_user(user_id, tenant_id)
|
||||
|
||||
def list_revoked_tokens(self):
|
||||
session = sql.get_session()
|
||||
tokens = []
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel.id, TokenModel.expires,
|
||||
TokenModel.extra)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
token_references = query.filter_by(valid=False)
|
||||
for token_ref in token_references:
|
||||
token_data = token_ref[2]['token_data']
|
||||
if 'access' in token_data:
|
||||
# It's a v2 token.
|
||||
audit_ids = token_data['access']['token']['audit_ids']
|
||||
else:
|
||||
# It's a v3 token.
|
||||
audit_ids = token_data['token']['audit_ids']
|
||||
with sql.session_for_read() as session:
|
||||
tokens = []
|
||||
now = timeutils.utcnow()
|
||||
query = session.query(TokenModel.id, TokenModel.expires,
|
||||
TokenModel.extra)
|
||||
query = query.filter(TokenModel.expires > now)
|
||||
token_references = query.filter_by(valid=False)
|
||||
for token_ref in token_references:
|
||||
token_data = token_ref[2]['token_data']
|
||||
if 'access' in token_data:
|
||||
# It's a v2 token.
|
||||
audit_ids = token_data['access']['token']['audit_ids']
|
||||
else:
|
||||
# It's a v3 token.
|
||||
audit_ids = token_data['token']['audit_ids']
|
||||
|
||||
record = {
|
||||
'id': token_ref[0],
|
||||
'expires': token_ref[1],
|
||||
'audit_id': audit_ids[0],
|
||||
}
|
||||
tokens.append(record)
|
||||
return tokens
|
||||
record = {
|
||||
'id': token_ref[0],
|
||||
'expires': token_ref[1],
|
||||
'audit_id': audit_ids[0],
|
||||
}
|
||||
tokens.append(record)
|
||||
return tokens
|
||||
|
||||
def _expiry_range_strategy(self, dialect):
|
||||
"""Choose a token range expiration strategy
|
||||
@ -273,18 +269,18 @@ class Token(token.persistence.TokenDriverV8):
|
||||
return _expiry_range_all
|
||||
|
||||
def flush_expired_tokens(self):
|
||||
session = sql.get_session()
|
||||
dialect = session.bind.dialect.name
|
||||
expiry_range_func = self._expiry_range_strategy(dialect)
|
||||
query = session.query(TokenModel.expires)
|
||||
total_removed = 0
|
||||
upper_bound_func = timeutils.utcnow
|
||||
for expiry_time in expiry_range_func(session, upper_bound_func):
|
||||
delete_query = query.filter(TokenModel.expires <=
|
||||
expiry_time)
|
||||
row_count = delete_query.delete(synchronize_session=False)
|
||||
total_removed += row_count
|
||||
LOG.debug('Removed %d total expired tokens', total_removed)
|
||||
with sql.session_for_write() as session:
|
||||
dialect = session.bind.dialect.name
|
||||
expiry_range_func = self._expiry_range_strategy(dialect)
|
||||
query = session.query(TokenModel.expires)
|
||||
total_removed = 0
|
||||
upper_bound_func = timeutils.utcnow
|
||||
for expiry_time in expiry_range_func(session, upper_bound_func):
|
||||
delete_query = query.filter(TokenModel.expires <=
|
||||
expiry_time)
|
||||
row_count = delete_query.delete(synchronize_session=False)
|
||||
total_removed += row_count
|
||||
LOG.debug('Removed %d total expired tokens', total_removed)
|
||||
|
||||
session.flush()
|
||||
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
|
||||
session.flush()
|
||||
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
|
||||
|
@ -59,7 +59,7 @@ class TrustRole(sql.ModelBase):
|
||||
class Trust(trust.TrustDriverV8):
|
||||
@sql.handle_conflicts(conflict_type='trust')
|
||||
def create_trust(self, trust_id, trust, roles):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
ref = TrustModel.from_dict(trust)
|
||||
ref['id'] = trust_id
|
||||
if ref.get('expires_at') and ref['expires_at'].tzinfo is not None:
|
||||
@ -72,9 +72,9 @@ class Trust(trust.TrustDriverV8):
|
||||
trust_role.role_id = role['id']
|
||||
added_roles.append({'id': role['id']})
|
||||
session.add(trust_role)
|
||||
trust_dict = ref.to_dict()
|
||||
trust_dict['roles'] = added_roles
|
||||
return trust_dict
|
||||
trust_dict = ref.to_dict()
|
||||
trust_dict['roles'] = added_roles
|
||||
return trust_dict
|
||||
|
||||
def _add_roles(self, trust_id, session, trust_dict):
|
||||
roles = []
|
||||
@ -86,7 +86,7 @@ class Trust(trust.TrustDriverV8):
|
||||
def consume_use(self, trust_id):
|
||||
|
||||
for attempt in range(MAXIMUM_CONSUME_ATTEMPTS):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
try:
|
||||
query_result = (session.query(TrustModel.remaining_uses).
|
||||
filter_by(id=trust_id).
|
||||
@ -132,51 +132,51 @@ class Trust(trust.TrustDriverV8):
|
||||
raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id)
|
||||
|
||||
def get_trust(self, trust_id, deleted=False):
|
||||
session = sql.get_session()
|
||||
query = session.query(TrustModel).filter_by(id=trust_id)
|
||||
if not deleted:
|
||||
query = query.filter_by(deleted_at=None)
|
||||
ref = query.first()
|
||||
if ref is None:
|
||||
raise exception.TrustNotFound(trust_id=trust_id)
|
||||
if ref.expires_at is not None and not deleted:
|
||||
now = timeutils.utcnow()
|
||||
if now > ref.expires_at:
|
||||
with sql.session_for_read() as session:
|
||||
query = session.query(TrustModel).filter_by(id=trust_id)
|
||||
if not deleted:
|
||||
query = query.filter_by(deleted_at=None)
|
||||
ref = query.first()
|
||||
if ref is None:
|
||||
raise exception.TrustNotFound(trust_id=trust_id)
|
||||
# Do not return trusts that can't be used anymore
|
||||
if ref.remaining_uses is not None and not deleted:
|
||||
if ref.remaining_uses <= 0:
|
||||
raise exception.TrustNotFound(trust_id=trust_id)
|
||||
trust_dict = ref.to_dict()
|
||||
if ref.expires_at is not None and not deleted:
|
||||
now = timeutils.utcnow()
|
||||
if now > ref.expires_at:
|
||||
raise exception.TrustNotFound(trust_id=trust_id)
|
||||
# Do not return trusts that can't be used anymore
|
||||
if ref.remaining_uses is not None and not deleted:
|
||||
if ref.remaining_uses <= 0:
|
||||
raise exception.TrustNotFound(trust_id=trust_id)
|
||||
trust_dict = ref.to_dict()
|
||||
|
||||
self._add_roles(trust_id, session, trust_dict)
|
||||
return trust_dict
|
||||
self._add_roles(trust_id, session, trust_dict)
|
||||
return trust_dict
|
||||
|
||||
@sql.handle_conflicts(conflict_type='trust')
|
||||
def list_trusts(self):
|
||||
session = sql.get_session()
|
||||
trusts = session.query(TrustModel).filter_by(deleted_at=None)
|
||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||
with sql.session_for_read() as session:
|
||||
trusts = session.query(TrustModel).filter_by(deleted_at=None)
|
||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||
|
||||
@sql.handle_conflicts(conflict_type='trust')
|
||||
def list_trusts_for_trustee(self, trustee_user_id):
|
||||
session = sql.get_session()
|
||||
trusts = (session.query(TrustModel).
|
||||
filter_by(deleted_at=None).
|
||||
filter_by(trustee_user_id=trustee_user_id))
|
||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||
with sql.session_for_read() as session:
|
||||
trusts = (session.query(TrustModel).
|
||||
filter_by(deleted_at=None).
|
||||
filter_by(trustee_user_id=trustee_user_id))
|
||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||
|
||||
@sql.handle_conflicts(conflict_type='trust')
|
||||
def list_trusts_for_trustor(self, trustor_user_id):
|
||||
session = sql.get_session()
|
||||
trusts = (session.query(TrustModel).
|
||||
filter_by(deleted_at=None).
|
||||
filter_by(trustor_user_id=trustor_user_id))
|
||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||
with sql.session_for_read() as session:
|
||||
trusts = (session.query(TrustModel).
|
||||
filter_by(deleted_at=None).
|
||||
filter_by(trustor_user_id=trustor_user_id))
|
||||
return [trust_ref.to_dict() for trust_ref in trusts]
|
||||
|
||||
@sql.handle_conflicts(conflict_type='trust')
|
||||
def delete_trust(self, trust_id):
|
||||
with sql.transaction() as session:
|
||||
with sql.session_for_write() as session:
|
||||
trust_ref = session.query(TrustModel).get(trust_id)
|
||||
if not trust_ref:
|
||||
raise exception.TrustNotFound(trust_id=trust_id)
|
||||
|
Loading…
Reference in New Issue
Block a user