Merge "Use the new enginefacade from oslo.db"

This commit is contained in:
Jenkins 2016-02-24 19:14:46 +00:00 committed by Gerrit Code Review
commit d37af165d0
24 changed files with 700 additions and 745 deletions

View File

@ -56,7 +56,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
return 'sql' return 'sql'
def list_user_ids_for_project(self, tenant_id): def list_user_ids_for_project(self, tenant_id):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleAssignment.actor_id) query = session.query(RoleAssignment.actor_id)
query = query.filter_by(type=AssignmentType.USER_PROJECT) query = query.filter_by(type=AssignmentType.USER_PROJECT)
query = query.filter_by(target_id=tenant_id) query = query.filter_by(target_id=tenant_id)
@ -71,7 +71,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
assignment_type = AssignmentType.calculate_type( assignment_type = AssignmentType.calculate_type(
user_id, group_id, project_id, domain_id) user_id, group_id, project_id, domain_id)
try: try:
with sql.transaction() as session: with sql.session_for_write() as session:
session.add(RoleAssignment( session.add(RoleAssignment(
type=assignment_type, type=assignment_type,
actor_id=user_id or group_id, actor_id=user_id or group_id,
@ -85,7 +85,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def list_grant_role_ids(self, user_id=None, group_id=None, def list_grant_role_ids(self, user_id=None, group_id=None,
domain_id=None, project_id=None, domain_id=None, project_id=None,
inherited_to_projects=False): inherited_to_projects=False):
with sql.transaction() as session: with sql.session_for_read() as session:
q = session.query(RoleAssignment.role_id) q = session.query(RoleAssignment.role_id)
q = q.filter(RoleAssignment.actor_id == (user_id or group_id)) q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
q = q.filter(RoleAssignment.target_id == (project_id or domain_id)) q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def check_grant_role_id(self, role_id, user_id=None, group_id=None, def check_grant_role_id(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None, domain_id=None, project_id=None,
inherited_to_projects=False): inherited_to_projects=False):
with sql.transaction() as session: with sql.session_for_read() as session:
try: try:
q = self._build_grant_filter( q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id, session, role_id, user_id, group_id, domain_id, project_id,
@ -120,7 +120,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def delete_grant(self, role_id, user_id=None, group_id=None, def delete_grant(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None, domain_id=None, project_id=None,
inherited_to_projects=False): inherited_to_projects=False):
with sql.transaction() as session: with sql.session_for_write() as session:
q = self._build_grant_filter( q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id, session, role_id, user_id, group_id, domain_id, project_id,
inherited_to_projects) inherited_to_projects)
@ -145,11 +145,11 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == inherited, RoleAssignment.inherited == inherited,
RoleAssignment.actor_id.in_(actors)) RoleAssignment.actor_id.in_(actors))
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id).filter( query = session.query(RoleAssignment.target_id).filter(
sql_constraints).distinct() sql_constraints).distinct()
return [x.target_id for x in query.all()] return [x.target_id for x in query.all()]
def list_project_ids_for_user(self, user_id, group_ids, hints, def list_project_ids_for_user(self, user_id, group_ids, hints,
inherited=False): inherited=False):
@ -161,7 +161,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def list_domain_ids_for_user(self, user_id, group_ids, hints, def list_domain_ids_for_user(self, user_id, group_ids, hints,
inherited=False): inherited=False):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id) query = session.query(RoleAssignment.target_id)
filters = [] filters = []
@ -197,10 +197,10 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == false(), RoleAssignment.inherited == false(),
RoleAssignment.actor_id.in_(group_ids)) RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleAssignment.role_id).filter( query = session.query(RoleAssignment.role_id).filter(
sql_constraints).distinct() sql_constraints).distinct()
return [role.role_id for role in query.all()] return [role.role_id for role in query.all()]
def list_role_ids_for_groups_on_project( def list_role_ids_for_groups_on_project(
self, group_ids, project_id, project_domain_id, project_parents): self, group_ids, project_id, project_domain_id, project_parents):
@ -237,13 +237,13 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
sql_constraints = sqlalchemy.and_( sql_constraints = sqlalchemy.and_(
sql_constraints, RoleAssignment.actor_id.in_(group_ids)) sql_constraints, RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session: with sql.session_for_read() as session:
# NOTE(morganfainberg): Only select the columns we actually care # NOTE(morganfainberg): Only select the columns we actually care
# about here, in this case role_id. # about here, in this case role_id.
query = session.query(RoleAssignment.role_id).filter( query = session.query(RoleAssignment.role_id).filter(
sql_constraints).distinct() sql_constraints).distinct()
return [result.role_id for result in query.all()] return [result.role_id for result in query.all()]
def list_project_ids_for_groups(self, group_ids, hints, def list_project_ids_for_groups(self, group_ids, hints,
inherited=False): inherited=False):
@ -260,14 +260,14 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == inherited, RoleAssignment.inherited == inherited,
RoleAssignment.actor_id.in_(group_ids)) RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id).filter( query = session.query(RoleAssignment.target_id).filter(
group_sql_conditions).distinct() group_sql_conditions).distinct()
return [x.target_id for x in query.all()] return [x.target_id for x in query.all()]
def add_role_to_user_and_project(self, user_id, tenant_id, role_id): def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
try: try:
with sql.transaction() as session: with sql.session_for_write() as session:
session.add(RoleAssignment( session.add(RoleAssignment(
type=AssignmentType.USER_PROJECT, type=AssignmentType.USER_PROJECT,
actor_id=user_id, target_id=tenant_id, actor_id=user_id, target_id=tenant_id,
@ -278,7 +278,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
raise exception.Conflict(type='role grant', details=msg) raise exception.Conflict(type='role grant', details=msg)
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id): def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id) q = q.filter_by(actor_id=user_id)
q = q.filter_by(target_id=tenant_id) q = q.filter_by(target_id=tenant_id)
@ -368,7 +368,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
assignment['inherited_to_projects'] = 'projects' assignment['inherited_to_projects'] = 'projects'
return assignment return assignment
with sql.transaction() as session: with sql.session_for_read() as session:
assignment_types = self._get_assignment_types( assignment_types = self._get_assignment_types(
user_id, group_ids, project_ids, domain_id) user_id, group_ids, project_ids, domain_id)
@ -400,25 +400,25 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
return [denormalize_role(ref) for ref in query.all()] return [denormalize_role(ref) for ref in query.all()]
def delete_project_assignments(self, project_id): def delete_project_assignments(self, project_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(target_id=project_id) q = q.filter_by(target_id=project_id)
q.delete(False) q.delete(False)
def delete_role_assignments(self, role_id): def delete_role_assignments(self, role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(role_id=role_id) q = q.filter_by(role_id=role_id)
q.delete(False) q.delete(False)
def delete_user_assignments(self, user_id): def delete_user_assignments(self, user_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id) q = q.filter_by(actor_id=user_id)
q.delete(False) q.delete(False)
def delete_group_assignments(self, group_id): def delete_group_assignments(self, group_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(actor_id=group_id) q = q.filter_by(actor_id=group_id)
q.delete(False) q.delete(False)

View File

@ -19,14 +19,14 @@ class Role(assignment.RoleDriverV8):
@sql.handle_conflicts(conflict_type='role') @sql.handle_conflicts(conflict_type='role')
def create_role(self, role_id, role): def create_role(self, role_id, role):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = RoleTable.from_dict(role) ref = RoleTable.from_dict(role)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
@sql.truncated @sql.truncated
def list_roles(self, hints): def list_roles(self, hints):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleTable) query = session.query(RoleTable)
refs = sql.filter_limit_query(RoleTable, query, hints) refs = sql.filter_limit_query(RoleTable, query, hints)
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
@ -35,7 +35,7 @@ class Role(assignment.RoleDriverV8):
if not ids: if not ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleTable) query = session.query(RoleTable)
query = query.filter(RoleTable.id.in_(ids)) query = query.filter(RoleTable.id.in_(ids))
role_refs = query.all() role_refs = query.all()
@ -48,12 +48,12 @@ class Role(assignment.RoleDriverV8):
return ref return ref
def get_role(self, role_id): def get_role(self, role_id):
with sql.transaction() as session: with sql.session_for_read() as session:
return self._get_role(session, role_id).to_dict() return self._get_role(session, role_id).to_dict()
@sql.handle_conflicts(conflict_type='role') @sql.handle_conflicts(conflict_type='role')
def update_role(self, role_id, role): def update_role(self, role_id, role):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_role(session, role_id) ref = self._get_role(session, role_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
for k in role: for k in role:
@ -66,7 +66,7 @@ class Role(assignment.RoleDriverV8):
return ref.to_dict() return ref.to_dict()
def delete_role(self, role_id): def delete_role(self, role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_role(session, role_id) ref = self._get_role(session, role_id)
session.delete(ref) session.delete(ref)

View File

@ -55,7 +55,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
assignment_type = AssignmentType.calculate_type( assignment_type = AssignmentType.calculate_type(
user_id, group_id, project_id, domain_id) user_id, group_id, project_id, domain_id)
try: try:
with sql.transaction() as session: with sql.session_for_write() as session:
session.add(RoleAssignment( session.add(RoleAssignment(
type=assignment_type, type=assignment_type,
actor_id=user_id or group_id, actor_id=user_id or group_id,
@ -69,7 +69,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def list_grant_role_ids(self, user_id=None, group_id=None, def list_grant_role_ids(self, user_id=None, group_id=None,
domain_id=None, project_id=None, domain_id=None, project_id=None,
inherited_to_projects=False): inherited_to_projects=False):
with sql.transaction() as session: with sql.session_for_read() as session:
q = session.query(RoleAssignment.role_id) q = session.query(RoleAssignment.role_id)
q = q.filter(RoleAssignment.actor_id == (user_id or group_id)) q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
q = q.filter(RoleAssignment.target_id == (project_id or domain_id)) q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
@ -88,7 +88,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def check_grant_role_id(self, role_id, user_id=None, group_id=None, def check_grant_role_id(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None, domain_id=None, project_id=None,
inherited_to_projects=False): inherited_to_projects=False):
with sql.transaction() as session: with sql.session_for_read() as session:
try: try:
q = self._build_grant_filter( q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id, session, role_id, user_id, group_id, domain_id, project_id,
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def delete_grant(self, role_id, user_id=None, group_id=None, def delete_grant(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None, domain_id=None, project_id=None,
inherited_to_projects=False): inherited_to_projects=False):
with sql.transaction() as session: with sql.session_for_write() as session:
q = self._build_grant_filter( q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id, session, role_id, user_id, group_id, domain_id, project_id,
inherited_to_projects) inherited_to_projects)
@ -117,7 +117,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def add_role_to_user_and_project(self, user_id, tenant_id, role_id): def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
try: try:
with sql.transaction() as session: with sql.session_for_write() as session:
session.add(RoleAssignment( session.add(RoleAssignment(
type=AssignmentType.USER_PROJECT, type=AssignmentType.USER_PROJECT,
actor_id=user_id, target_id=tenant_id, actor_id=user_id, target_id=tenant_id,
@ -128,7 +128,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
raise exception.Conflict(type='role grant', details=msg) raise exception.Conflict(type='role grant', details=msg)
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id): def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id) q = q.filter_by(actor_id=user_id)
q = q.filter_by(target_id=tenant_id) q = q.filter_by(target_id=tenant_id)
@ -218,7 +218,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
assignment['inherited_to_projects'] = 'projects' assignment['inherited_to_projects'] = 'projects'
return assignment return assignment
with sql.transaction() as session: with sql.session_for_read() as session:
assignment_types = self._get_assignment_types( assignment_types = self._get_assignment_types(
user_id, group_ids, project_ids, domain_id) user_id, group_ids, project_ids, domain_id)
@ -250,7 +250,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
return [denormalize_role(ref) for ref in query.all()] return [denormalize_role(ref) for ref in query.all()]
def delete_project_assignments(self, project_id): def delete_project_assignments(self, project_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(target_id=project_id).filter( q = q.filter_by(target_id=project_id).filter(
RoleAssignment.type.in_((AssignmentType.USER_PROJECT, RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
@ -259,13 +259,13 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
q.delete(False) q.delete(False)
def delete_role_assignments(self, role_id): def delete_role_assignments(self, role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(role_id=role_id) q = q.filter_by(role_id=role_id)
q.delete(False) q.delete(False)
def delete_user_assignments(self, user_id): def delete_user_assignments(self, user_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id).filter( q = q.filter_by(actor_id=user_id).filter(
RoleAssignment.type.in_((AssignmentType.USER_PROJECT, RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
@ -274,7 +274,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
q.delete(False) q.delete(False)
def delete_group_assignments(self, group_id): def delete_group_assignments(self, group_id):
with sql.transaction() as session: with sql.session_for_write() as session:
q = session.query(RoleAssignment) q = session.query(RoleAssignment)
q = q.filter_by(actor_id=group_id).filter( q = q.filter_by(actor_id=group_id).filter(
RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT, RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT,

View File

@ -29,7 +29,7 @@ class Role(assignment.RoleDriverV9):
@sql.handle_conflicts(conflict_type='role') @sql.handle_conflicts(conflict_type='role')
def create_role(self, role_id, role): def create_role(self, role_id, role):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = RoleTable.from_dict(role) ref = RoleTable.from_dict(role)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
@ -46,7 +46,7 @@ class Role(assignment.RoleDriverV9):
if (f['name'] == 'domain_id' and f['value'] is None): if (f['name'] == 'domain_id' and f['value'] is None):
f['value'] = NULL_DOMAIN_ID f['value'] = NULL_DOMAIN_ID
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleTable) query = session.query(RoleTable)
refs = sql.filter_limit_query(RoleTable, query, hints) refs = sql.filter_limit_query(RoleTable, query, hints)
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
@ -55,7 +55,7 @@ class Role(assignment.RoleDriverV9):
if not ids: if not ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(RoleTable) query = session.query(RoleTable)
query = query.filter(RoleTable.id.in_(ids)) query = query.filter(RoleTable.id.in_(ids))
role_refs = query.all() role_refs = query.all()
@ -68,12 +68,12 @@ class Role(assignment.RoleDriverV9):
return ref return ref
def get_role(self, role_id): def get_role(self, role_id):
with sql.transaction() as session: with sql.session_for_read() as session:
return self._get_role(session, role_id).to_dict() return self._get_role(session, role_id).to_dict()
@sql.handle_conflicts(conflict_type='role') @sql.handle_conflicts(conflict_type='role')
def update_role(self, role_id, role): def update_role(self, role_id, role):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_role(session, role_id) ref = self._get_role(session, role_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
for k in role: for k in role:
@ -86,7 +86,7 @@ class Role(assignment.RoleDriverV9):
return ref.to_dict() return ref.to_dict()
def delete_role(self, role_id): def delete_role(self, role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_role(session, role_id) ref = self._get_role(session, role_id)
session.delete(ref) session.delete(ref)
@ -105,7 +105,7 @@ class Role(assignment.RoleDriverV9):
@sql.handle_conflicts(conflict_type='implied_role') @sql.handle_conflicts(conflict_type='implied_role')
def create_implied_role(self, prior_role_id, implied_role_id): def create_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
inference = {'prior_role_id': prior_role_id, inference = {'prior_role_id': prior_role_id,
'implied_role_id': implied_role_id} 'implied_role_id': implied_role_id}
ref = ImpliedRoleTable.from_dict(inference) ref = ImpliedRoleTable.from_dict(inference)
@ -119,13 +119,13 @@ class Role(assignment.RoleDriverV9):
return ref.to_dict() return ref.to_dict()
def delete_implied_role(self, prior_role_id, implied_role_id): def delete_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_implied_role(session, prior_role_id, ref = self._get_implied_role(session, prior_role_id,
implied_role_id) implied_role_id)
session.delete(ref) session.delete(ref)
def list_implied_roles(self, prior_role_id): def list_implied_roles(self, prior_role_id):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query( query = session.query(
ImpliedRoleTable).filter( ImpliedRoleTable).filter(
ImpliedRoleTable.prior_role_id == prior_role_id) ImpliedRoleTable.prior_role_id == prior_role_id)
@ -133,13 +133,13 @@ class Role(assignment.RoleDriverV9):
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
def list_role_inference_rules(self): def list_role_inference_rules(self):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(ImpliedRoleTable) query = session.query(ImpliedRoleTable)
refs = query.all() refs = query.all()
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
def get_implied_role(self, prior_role_id, implied_role_id): def get_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session: with sql.session_for_read() as session:
ref = self._get_implied_role(session, prior_role_id, ref = self._get_implied_role(session, prior_role_id,
implied_role_id) implied_role_id)
return ref.to_dict() return ref.to_dict()

View File

@ -84,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase):
class Catalog(catalog.CatalogDriverV8): class Catalog(catalog.CatalogDriverV8):
# Regions # Regions
def list_regions(self, hints): def list_regions(self, hints):
session = sql.get_session() with sql.session_for_read() as session:
regions = session.query(Region) regions = session.query(Region)
regions = sql.filter_limit_query(Region, regions, hints) regions = sql.filter_limit_query(Region, regions, hints)
return [s.to_dict() for s in list(regions)] return [s.to_dict() for s in list(regions)]
def _get_region(self, session, region_id): def _get_region(self, session, region_id):
ref = session.query(Region).get(region_id) ref = session.query(Region).get(region_id)
@ -136,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8):
return False return False
def get_region(self, region_id): def get_region(self, region_id):
session = sql.get_session() with sql.session_for_read() as session:
return self._get_region(session, region_id).to_dict() return self._get_region(session, region_id).to_dict()
def delete_region(self, region_id): def delete_region(self, region_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_region(session, region_id) ref = self._get_region(session, region_id)
if self._has_endpoints(session, ref, ref): if self._has_endpoints(session, ref, ref):
raise exception.RegionDeletionError(region_id=region_id) raise exception.RegionDeletionError(region_id=region_id)
@ -150,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8):
@sql.handle_conflicts(conflict_type='region') @sql.handle_conflicts(conflict_type='region')
def create_region(self, region_ref): def create_region(self, region_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
self._check_parent_region(session, region_ref) self._check_parent_region(session, region_ref)
region = Region.from_dict(region_ref) region = Region.from_dict(region_ref)
session.add(region) session.add(region)
return region.to_dict() return region.to_dict()
def update_region(self, region_id, region_ref): def update_region(self, region_id, region_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
self._check_parent_region(session, region_ref) self._check_parent_region(session, region_ref)
ref = self._get_region(session, region_id) ref = self._get_region(session, region_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
@ -169,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8):
for attr in Region.attributes: for attr in Region.attributes:
if attr != 'id': if attr != 'id':
setattr(ref, attr, getattr(new_region, attr)) setattr(ref, attr, getattr(new_region, attr))
return ref.to_dict() return ref.to_dict()
# Services # Services
@driver_hints.truncated @driver_hints.truncated
def list_services(self, hints): def list_services(self, hints):
session = sql.get_session() with sql.session_for_read() as session:
services = session.query(Service) services = session.query(Service)
services = sql.filter_limit_query(Service, services, hints) services = sql.filter_limit_query(Service, services, hints)
return [s.to_dict() for s in list(services)] return [s.to_dict() for s in list(services)]
def _get_service(self, session, service_id): def _get_service(self, session, service_id):
ref = session.query(Service).get(service_id) ref = session.query(Service).get(service_id)
@ -186,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8):
return ref return ref
def get_service(self, service_id): def get_service(self, service_id):
session = sql.get_session() with sql.session_for_read() as session:
return self._get_service(session, service_id).to_dict() return self._get_service(session, service_id).to_dict()
def delete_service(self, service_id): def delete_service(self, service_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_service(session, service_id) ref = self._get_service(session, service_id)
session.query(Endpoint).filter_by(service_id=service_id).delete() session.query(Endpoint).filter_by(service_id=service_id).delete()
session.delete(ref) session.delete(ref)
def create_service(self, service_id, service_ref): def create_service(self, service_id, service_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
service = Service.from_dict(service_ref) service = Service.from_dict(service_ref)
session.add(service) session.add(service)
return service.to_dict() return service.to_dict()
def update_service(self, service_id, service_ref): def update_service(self, service_id, service_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_service(session, service_id) ref = self._get_service(session, service_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
old_dict.update(service_ref) old_dict.update(service_ref)
@ -214,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8):
if attr != 'id': if attr != 'id':
setattr(ref, attr, getattr(new_service, attr)) setattr(ref, attr, getattr(new_service, attr))
ref.extra = new_service.extra ref.extra = new_service.extra
return ref.to_dict() return ref.to_dict()
# Endpoints # Endpoints
def create_endpoint(self, endpoint_id, endpoint_ref): def create_endpoint(self, endpoint_id, endpoint_ref):
session = sql.get_session()
new_endpoint = Endpoint.from_dict(endpoint_ref) new_endpoint = Endpoint.from_dict(endpoint_ref)
with sql.session_for_write() as session:
with session.begin():
session.add(new_endpoint) session.add(new_endpoint)
return new_endpoint.to_dict() return new_endpoint.to_dict()
def delete_endpoint(self, endpoint_id): def delete_endpoint(self, endpoint_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_endpoint(session, endpoint_id) ref = self._get_endpoint(session, endpoint_id)
session.delete(ref) session.delete(ref)
@ -238,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8):
raise exception.EndpointNotFound(endpoint_id=endpoint_id) raise exception.EndpointNotFound(endpoint_id=endpoint_id)
def get_endpoint(self, endpoint_id): def get_endpoint(self, endpoint_id):
session = sql.get_session() with sql.session_for_read() as session:
return self._get_endpoint(session, endpoint_id).to_dict() return self._get_endpoint(session, endpoint_id).to_dict()
@driver_hints.truncated @driver_hints.truncated
def list_endpoints(self, hints): def list_endpoints(self, hints):
session = sql.get_session() with sql.session_for_read() as session:
endpoints = session.query(Endpoint) endpoints = session.query(Endpoint)
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints) endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
return [e.to_dict() for e in list(endpoints)] return [e.to_dict() for e in list(endpoints)]
def update_endpoint(self, endpoint_id, endpoint_ref): def update_endpoint(self, endpoint_id, endpoint_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_endpoint(session, endpoint_id) ref = self._get_endpoint(session, endpoint_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
old_dict.update(endpoint_ref) old_dict.update(endpoint_ref)
@ -260,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8):
if attr != 'id': if attr != 'id':
setattr(ref, attr, getattr(new_endpoint, attr)) setattr(ref, attr, getattr(new_endpoint, attr))
ref.extra = new_endpoint.extra ref.extra = new_endpoint.extra
return ref.to_dict() return ref.to_dict()
def get_catalog(self, user_id, tenant_id): def get_catalog(self, user_id, tenant_id):
"""Retrieve and format the V2 service catalog. """Retrieve and format the V2 service catalog.
@ -289,40 +278,40 @@ class Catalog(catalog.CatalogDriverV8):
else: else:
silent_keyerror_failures = ['tenant_id', 'project_id', ] silent_keyerror_failures = ['tenant_id', 'project_id', ]
session = sql.get_session() with sql.session_for_read() as session:
endpoints = (session.query(Endpoint). endpoints = (session.query(Endpoint).
options(sql.joinedload(Endpoint.service)). options(sql.joinedload(Endpoint.service)).
filter(Endpoint.enabled == true()).all()) filter(Endpoint.enabled == true()).all())
catalog = {} catalog = {}
for endpoint in endpoints: for endpoint in endpoints:
if not endpoint.service['enabled']: if not endpoint.service['enabled']:
continue
try:
formatted_url = core.format_url(
endpoint['url'], substitutions,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url is not None:
url = formatted_url
else:
continue continue
except exception.MalformedEndpoint: try:
continue # this failure is already logged in format_url() formatted_url = core.format_url(
endpoint['url'], substitutions,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url is not None:
url = formatted_url
else:
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
region = endpoint['region_id'] region = endpoint['region_id']
service_type = endpoint.service['type'] service_type = endpoint.service['type']
default_service = { default_service = {
'id': endpoint['id'], 'id': endpoint['id'],
'name': endpoint.service.extra.get('name', ''), 'name': endpoint.service.extra.get('name', ''),
'publicURL': '' 'publicURL': ''
} }
catalog.setdefault(region, {}) catalog.setdefault(region, {})
catalog[region].setdefault(service_type, default_service) catalog[region].setdefault(service_type, default_service)
interface_url = '%sURL' % endpoint['interface'] interface_url = '%sURL' % endpoint['interface']
catalog[region][service_type][interface_url] = url catalog[region][service_type][interface_url] = url
return catalog return catalog
def get_v3_catalog(self, user_id, tenant_id): def get_v3_catalog(self, user_id, tenant_id):
"""Retrieve and format the current V3 service catalog. """Retrieve and format the current V3 service catalog.
@ -349,44 +338,46 @@ class Catalog(catalog.CatalogDriverV8):
else: else:
silent_keyerror_failures = ['tenant_id', 'project_id', ] silent_keyerror_failures = ['tenant_id', 'project_id', ]
session = sql.get_session() with sql.session_for_read() as session:
services = (session.query(Service).filter(Service.enabled == true()). services = (session.query(Service).filter(
options(sql.joinedload(Service.endpoints)). Service.enabled == true()).options(
all()) sql.joinedload(Service.endpoints)).all())
def make_v3_endpoints(endpoints): def make_v3_endpoints(endpoints):
for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled): for endpoint in (ep.to_dict()
del endpoint['service_id'] for ep in endpoints if ep.enabled):
del endpoint['legacy_endpoint_id'] del endpoint['service_id']
del endpoint['enabled'] del endpoint['legacy_endpoint_id']
endpoint['region'] = endpoint['region_id'] del endpoint['enabled']
try: endpoint['region'] = endpoint['region_id']
formatted_url = core.format_url( try:
endpoint['url'], d, formatted_url = core.format_url(
silent_keyerror_failures=silent_keyerror_failures) endpoint['url'], d,
if formatted_url: silent_keyerror_failures=silent_keyerror_failures)
endpoint['url'] = formatted_url if formatted_url:
else: endpoint['url'] = formatted_url
else:
continue
except exception.MalformedEndpoint:
# this failure is already logged in format_url()
continue continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
yield endpoint yield endpoint
# TODO(davechen): If there is service with no endpoints, we should skip # TODO(davechen): If there is service with no endpoints, we should
# the service instead of keeping it in the catalog, see bug #1436704. # skip the service instead of keeping it in the catalog,
def make_v3_service(svc): # see bug #1436704.
eps = list(make_v3_endpoints(svc.endpoints)) def make_v3_service(svc):
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type} eps = list(make_v3_endpoints(svc.endpoints))
service['name'] = svc.extra.get('name', '') service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
return service service['name'] = svc.extra.get('name', '')
return service
return [make_v3_service(svc) for svc in services] return [make_v3_service(svc) for svc in services]
@sql.handle_conflicts(conflict_type='project_endpoint') @sql.handle_conflicts(conflict_type='project_endpoint')
def add_endpoint_to_project(self, endpoint_id, project_id): def add_endpoint_to_project(self, endpoint_id, project_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id, endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
project_id=project_id) project_id=project_id)
session.add(endpoint_filter_ref) session.add(endpoint_filter_ref)
@ -402,50 +393,46 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_filter_ref return endpoint_filter_ref
def check_endpoint_in_project(self, endpoint_id, project_id): def check_endpoint_in_project(self, endpoint_id, project_id):
session = sql.get_session() with sql.session_for_read() as session:
self._get_project_endpoint_ref(session, endpoint_id, project_id) self._get_project_endpoint_ref(session, endpoint_id, project_id)
def remove_endpoint_from_project(self, endpoint_id, project_id): def remove_endpoint_from_project(self, endpoint_id, project_id):
session = sql.get_session() with sql.session_for_write() as session:
endpoint_filter_ref = self._get_project_endpoint_ref( endpoint_filter_ref = self._get_project_endpoint_ref(
session, endpoint_id, project_id) session, endpoint_id, project_id)
with session.begin():
session.delete(endpoint_filter_ref) session.delete(endpoint_filter_ref)
def list_endpoints_for_project(self, project_id): def list_endpoints_for_project(self, project_id):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(ProjectEndpoint) query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id) query = query.filter_by(project_id=project_id)
endpoint_filter_refs = query.all() endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs] return [ref.to_dict() for ref in endpoint_filter_refs]
def list_projects_for_endpoint(self, endpoint_id): def list_projects_for_endpoint(self, endpoint_id):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(ProjectEndpoint) query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id) query = query.filter_by(endpoint_id=endpoint_id)
endpoint_filter_refs = query.all() endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs] return [ref.to_dict() for ref in endpoint_filter_refs]
def delete_association_by_endpoint(self, endpoint_id): def delete_association_by_endpoint(self, endpoint_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
query = session.query(ProjectEndpoint) query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id) query = query.filter_by(endpoint_id=endpoint_id)
query.delete(synchronize_session=False) query.delete(synchronize_session=False)
def delete_association_by_project(self, project_id): def delete_association_by_project(self, project_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
query = session.query(ProjectEndpoint) query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id) query = query.filter_by(project_id=project_id)
query.delete(synchronize_session=False) query.delete(synchronize_session=False)
def create_endpoint_group(self, endpoint_group_id, endpoint_group): def create_endpoint_group(self, endpoint_group_id, endpoint_group):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
endpoint_group_ref = EndpointGroup.from_dict(endpoint_group) endpoint_group_ref = EndpointGroup.from_dict(endpoint_group)
session.add(endpoint_group_ref) session.add(endpoint_group_ref)
return endpoint_group_ref.to_dict() return endpoint_group_ref.to_dict()
def _get_endpoint_group(self, session, endpoint_group_id): def _get_endpoint_group(self, session, endpoint_group_id):
endpoint_group_ref = session.query(EndpointGroup).get( endpoint_group_ref = session.query(EndpointGroup).get(
@ -456,14 +443,13 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_group_ref return endpoint_group_ref
def get_endpoint_group(self, endpoint_group_id): def get_endpoint_group(self, endpoint_group_id):
session = sql.get_session() with sql.session_for_read() as session:
endpoint_group_ref = self._get_endpoint_group(session, endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id) endpoint_group_id)
return endpoint_group_ref.to_dict() return endpoint_group_ref.to_dict()
def update_endpoint_group(self, endpoint_group_id, endpoint_group): def update_endpoint_group(self, endpoint_group_id, endpoint_group):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
endpoint_group_ref = self._get_endpoint_group(session, endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id) endpoint_group_id)
old_endpoint_group = endpoint_group_ref.to_dict() old_endpoint_group = endpoint_group_ref.to_dict()
@ -472,29 +458,26 @@ class Catalog(catalog.CatalogDriverV8):
for attr in EndpointGroup.mutable_attributes: for attr in EndpointGroup.mutable_attributes:
setattr(endpoint_group_ref, attr, setattr(endpoint_group_ref, attr,
getattr(new_endpoint_group, attr)) getattr(new_endpoint_group, attr))
return endpoint_group_ref.to_dict() return endpoint_group_ref.to_dict()
def delete_endpoint_group(self, endpoint_group_id): def delete_endpoint_group(self, endpoint_group_id):
session = sql.get_session() with sql.session_for_write() as session:
endpoint_group_ref = self._get_endpoint_group(session, endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id) endpoint_group_id)
with session.begin():
self._delete_endpoint_group_association_by_endpoint_group( self._delete_endpoint_group_association_by_endpoint_group(
session, endpoint_group_id) session, endpoint_group_id)
session.delete(endpoint_group_ref) session.delete(endpoint_group_ref)
def get_endpoint_group_in_project(self, endpoint_group_id, project_id): def get_endpoint_group_in_project(self, endpoint_group_id, project_id):
session = sql.get_session() with sql.session_for_read() as session:
ref = self._get_endpoint_group_in_project(session, ref = self._get_endpoint_group_in_project(session,
endpoint_group_id, endpoint_group_id,
project_id) project_id)
return ref.to_dict() return ref.to_dict()
@sql.handle_conflicts(conflict_type='project_endpoint_group') @sql.handle_conflicts(conflict_type='project_endpoint_group')
def add_endpoint_group_to_project(self, endpoint_group_id, project_id): def add_endpoint_group_to_project(self, endpoint_group_id, project_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
# Create a new Project Endpoint group entity # Create a new Project Endpoint group entity
endpoint_group_project_ref = ProjectEndpointGroupMembership( endpoint_group_project_ref = ProjectEndpointGroupMembership(
endpoint_group_id=endpoint_group_id, project_id=project_id) endpoint_group_id=endpoint_group_id, project_id=project_id)
@ -512,32 +495,31 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_group_project_ref return endpoint_group_project_ref
def list_endpoint_groups(self): def list_endpoint_groups(self):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(EndpointGroup) query = session.query(EndpointGroup)
endpoint_group_refs = query.all() endpoint_group_refs = query.all()
return [e.to_dict() for e in endpoint_group_refs] return [e.to_dict() for e in endpoint_group_refs]
def list_endpoint_groups_for_project(self, project_id): def list_endpoint_groups_for_project(self, project_id):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(ProjectEndpointGroupMembership) query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id) query = query.filter_by(project_id=project_id)
endpoint_group_refs = query.all() endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs] return [ref.to_dict() for ref in endpoint_group_refs]
def remove_endpoint_group_from_project(self, endpoint_group_id, def remove_endpoint_group_from_project(self, endpoint_group_id,
project_id): project_id):
session = sql.get_session() with sql.session_for_write() as session:
endpoint_group_project_ref = self._get_endpoint_group_in_project( endpoint_group_project_ref = self._get_endpoint_group_in_project(
session, endpoint_group_id, project_id) session, endpoint_group_id, project_id)
with session.begin():
session.delete(endpoint_group_project_ref) session.delete(endpoint_group_project_ref)
def list_projects_associated_with_endpoint_group(self, endpoint_group_id): def list_projects_associated_with_endpoint_group(self, endpoint_group_id):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(ProjectEndpointGroupMembership) query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(endpoint_group_id=endpoint_group_id) query = query.filter_by(endpoint_group_id=endpoint_group_id)
endpoint_group_refs = query.all() endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs] return [ref.to_dict() for ref in endpoint_group_refs]
def _delete_endpoint_group_association_by_endpoint_group( def _delete_endpoint_group_association_by_endpoint_group(
self, session, endpoint_group_id): self, session, endpoint_group_id):
@ -546,8 +528,7 @@ class Catalog(catalog.CatalogDriverV8):
query.delete() query.delete()
def delete_endpoint_group_association_by_project(self, project_id): def delete_endpoint_group_association_by_project(self, project_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
query = session.query(ProjectEndpointGroupMembership) query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id) query = query.filter_by(project_id=project_id)
query.delete() query.delete()

View File

@ -18,14 +18,14 @@ Before using this module, call initialize(). This has to be done before
CONF() because it sets up configuration options. CONF() because it sets up configuration options.
""" """
import contextlib
import functools import functools
import threading
from oslo_config import cfg from oslo_config import cfg
from oslo_db import exception as db_exception from oslo_db import exception as db_exception
from oslo_db import options as db_options from oslo_db import options as db_options
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import models from oslo_db.sqlalchemy import models
from oslo_db.sqlalchemy import session as db_session
from oslo_log import log from oslo_log import log
from oslo_serialization import jsonutils from oslo_serialization import jsonutils
import six import six
@ -166,38 +166,41 @@ class ModelDictMixin(object):
return {name: getattr(self, name) for name in names} return {name: getattr(self, name) for name in names}
_engine_facade = None _main_context_manager = None
def _get_engine_facade(): def _get_main_context_manager():
global _engine_facade global _main_context_manager
if not _engine_facade: if not _main_context_manager:
_engine_facade = db_session.EngineFacade.from_config(CONF) _main_context_manager = enginefacade.transaction_context()
return _engine_facade return _main_context_manager
def cleanup(): def cleanup():
global _engine_facade global _main_context_manager
_engine_facade = None _main_context_manager = None
def get_engine(): def get_engine():
return _get_engine_facade().get_engine() return _get_main_context_manager().get_legacy_facade().get_engine()
def get_session(expire_on_commit=False): def get_session():
return _get_engine_facade().get_session(expire_on_commit=expire_on_commit) return _get_main_context_manager().get_legacy_facade().get_session()
@contextlib.contextmanager _CONTEXT = threading.local()
def transaction(expire_on_commit=False):
"""Return a SQLAlchemy session in a scoped transaction."""
session = get_session(expire_on_commit=expire_on_commit) def session_for_read():
with session.begin(): return _get_main_context_manager().reader.using(_CONTEXT)
yield session
def session_for_write():
return _get_main_context_manager().writer.using(_CONTEXT)
def truncated(f): def truncated(f):

View File

@ -178,7 +178,7 @@ def _sync_extension_repo(extension, version):
try: try:
abs_path = find_migrate_repo(package) abs_path = find_migrate_repo(package)
try: try:
migration.db_version_control(sql.get_engine(), abs_path) migration.db_version_control(engine, abs_path)
# Register the repo with the version control API # Register the repo with the version control API
# If it already knows about the repo, it will throw # If it already knows about the repo, it will throw
# an exception that we can safely ignore # an exception that we can safely ignore

View File

@ -36,28 +36,27 @@ class Credential(credential.CredentialDriverV8):
@sql.handle_conflicts(conflict_type='credential') @sql.handle_conflicts(conflict_type='credential')
def create_credential(self, credential_id, credential): def create_credential(self, credential_id, credential):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = CredentialModel.from_dict(credential) ref = CredentialModel.from_dict(credential)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
@driver_hints.truncated @driver_hints.truncated
def list_credentials(self, hints): def list_credentials(self, hints):
session = sql.get_session() with sql.session_for_read() as session:
credentials = session.query(CredentialModel) credentials = session.query(CredentialModel)
credentials = sql.filter_limit_query(CredentialModel, credentials = sql.filter_limit_query(CredentialModel,
credentials, hints) credentials, hints)
return [s.to_dict() for s in credentials] return [s.to_dict() for s in credentials]
def list_credentials_for_user(self, user_id, type=None): def list_credentials_for_user(self, user_id, type=None):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(CredentialModel) query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id) query = query.filter_by(user_id=user_id)
if type: if type:
query = query.filter_by(type=type) query = query.filter_by(type=type)
refs = query.all() refs = query.all()
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
def _get_credential(self, session, credential_id): def _get_credential(self, session, credential_id):
ref = session.query(CredentialModel).get(credential_id) ref = session.query(CredentialModel).get(credential_id)
@ -66,13 +65,12 @@ class Credential(credential.CredentialDriverV8):
return ref return ref
def get_credential(self, credential_id): def get_credential(self, credential_id):
session = sql.get_session() with sql.session_for_read() as session:
return self._get_credential(session, credential_id).to_dict() return self._get_credential(session, credential_id).to_dict()
@sql.handle_conflicts(conflict_type='credential') @sql.handle_conflicts(conflict_type='credential')
def update_credential(self, credential_id, credential): def update_credential(self, credential_id, credential):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_credential(session, credential_id) ref = self._get_credential(session, credential_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
for k in credential: for k in credential:
@ -82,27 +80,21 @@ class Credential(credential.CredentialDriverV8):
if attr != 'id': if attr != 'id':
setattr(ref, attr, getattr(new_credential, attr)) setattr(ref, attr, getattr(new_credential, attr))
ref.extra = new_credential.extra ref.extra = new_credential.extra
return ref.to_dict() return ref.to_dict()
def delete_credential(self, credential_id): def delete_credential(self, credential_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_credential(session, credential_id) ref = self._get_credential(session, credential_id)
session.delete(ref) session.delete(ref)
def delete_credentials_for_project(self, project_id): def delete_credentials_for_project(self, project_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
query = session.query(CredentialModel) query = session.query(CredentialModel)
query = query.filter_by(project_id=project_id) query = query.filter_by(project_id=project_id)
query.delete() query.delete()
def delete_credentials_for_user(self, user_id): def delete_credentials_for_user(self, user_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
query = session.query(CredentialModel) query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id) query = query.filter_by(user_id=user_id)
query.delete() query.delete()

View File

@ -51,7 +51,7 @@ class EndpointPolicy(object):
def create_policy_association(self, policy_id, endpoint_id=None, def create_policy_association(self, policy_id, endpoint_id=None,
service_id=None, region_id=None): service_id=None, region_id=None):
with sql.transaction() as session: with sql.session_for_write() as session:
try: try:
# See if there is already a row for this association, and if # See if there is already a row for this association, and if
# so, update it with the new policy_id # so, update it with the new policy_id
@ -79,14 +79,14 @@ class EndpointPolicy(object):
# NOTE(henry-nash): Getting a single value to save object # NOTE(henry-nash): Getting a single value to save object
# management overhead. # management overhead.
with sql.transaction() as session: with sql.session_for_read() as session:
if session.query(PolicyAssociation.id).filter( if session.query(PolicyAssociation.id).filter(
sql_constraints).distinct().count() == 0: sql_constraints).distinct().count() == 0:
raise exception.PolicyAssociationNotFound() raise exception.PolicyAssociationNotFound()
def delete_policy_association(self, policy_id, endpoint_id=None, def delete_policy_association(self, policy_id, endpoint_id=None,
service_id=None, region_id=None): service_id=None, region_id=None):
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(PolicyAssociation) query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id) query = query.filter_by(policy_id=policy_id)
query = query.filter_by(endpoint_id=endpoint_id) query = query.filter_by(endpoint_id=endpoint_id)
@ -102,7 +102,7 @@ class EndpointPolicy(object):
PolicyAssociation.region_id == region_id) PolicyAssociation.region_id == region_id)
try: try:
with sql.transaction() as session: with sql.session_for_read() as session:
policy_id = session.query(PolicyAssociation.policy_id).filter( policy_id = session.query(PolicyAssociation.policy_id).filter(
sql_constraints).distinct().one() sql_constraints).distinct().one()
return {'policy_id': policy_id} return {'policy_id': policy_id}
@ -110,31 +110,31 @@ class EndpointPolicy(object):
raise exception.PolicyAssociationNotFound() raise exception.PolicyAssociationNotFound()
def list_associations_for_policy(self, policy_id): def list_associations_for_policy(self, policy_id):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(PolicyAssociation) query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id) query = query.filter_by(policy_id=policy_id)
return [ref.to_dict() for ref in query.all()] return [ref.to_dict() for ref in query.all()]
def delete_association_by_endpoint(self, endpoint_id): def delete_association_by_endpoint(self, endpoint_id):
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(PolicyAssociation) query = session.query(PolicyAssociation)
query = query.filter_by(endpoint_id=endpoint_id) query = query.filter_by(endpoint_id=endpoint_id)
query.delete() query.delete()
def delete_association_by_service(self, service_id): def delete_association_by_service(self, service_id):
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(PolicyAssociation) query = session.query(PolicyAssociation)
query = query.filter_by(service_id=service_id) query = query.filter_by(service_id=service_id)
query.delete() query.delete()
def delete_association_by_region(self, region_id): def delete_association_by_region(self, region_id):
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(PolicyAssociation) query = session.query(PolicyAssociation)
query = query.filter_by(region_id=region_id) query = query.filter_by(region_id=region_id)
query.delete() query.delete()
def delete_association_by_policy(self, policy_id): def delete_association_by_policy(self, policy_id):
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(PolicyAssociation) query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id) query = query.filter_by(policy_id=policy_id)
query.delete() query.delete()

View File

@ -161,13 +161,13 @@ class Federation(core.FederationDriverV8):
@sql.handle_conflicts(conflict_type='identity_provider') @sql.handle_conflicts(conflict_type='identity_provider')
def create_idp(self, idp_id, idp): def create_idp(self, idp_id, idp):
idp['id'] = idp_id idp['id'] = idp_id
with sql.transaction() as session: with sql.session_for_write() as session:
idp_ref = IdentityProviderModel.from_dict(idp) idp_ref = IdentityProviderModel.from_dict(idp)
session.add(idp_ref) session.add(idp_ref)
return idp_ref.to_dict() return idp_ref.to_dict()
def delete_idp(self, idp_id): def delete_idp(self, idp_id):
with sql.transaction() as session: with sql.session_for_write() as session:
self._delete_assigned_protocols(session, idp_id) self._delete_assigned_protocols(session, idp_id)
idp_ref = self._get_idp(session, idp_id) idp_ref = self._get_idp(session, idp_id)
session.delete(idp_ref) session.delete(idp_ref)
@ -187,30 +187,30 @@ class Federation(core.FederationDriverV8):
raise exception.IdentityProviderNotFound(idp_id=remote_id) raise exception.IdentityProviderNotFound(idp_id=remote_id)
def list_idps(self): def list_idps(self):
with sql.transaction() as session: with sql.session_for_read() as session:
idps = session.query(IdentityProviderModel) idps = session.query(IdentityProviderModel)
idps_list = [idp.to_dict() for idp in idps] idps_list = [idp.to_dict() for idp in idps]
return idps_list return idps_list
def get_idp(self, idp_id): def get_idp(self, idp_id):
with sql.transaction() as session: with sql.session_for_read() as session:
idp_ref = self._get_idp(session, idp_id) idp_ref = self._get_idp(session, idp_id)
return idp_ref.to_dict() return idp_ref.to_dict()
def get_idp_from_remote_id(self, remote_id): def get_idp_from_remote_id(self, remote_id):
with sql.transaction() as session: with sql.session_for_read() as session:
ref = self._get_idp_from_remote_id(session, remote_id) ref = self._get_idp_from_remote_id(session, remote_id)
return ref.to_dict() return ref.to_dict()
def update_idp(self, idp_id, idp): def update_idp(self, idp_id, idp):
with sql.transaction() as session: with sql.session_for_write() as session:
idp_ref = self._get_idp(session, idp_id) idp_ref = self._get_idp(session, idp_id)
old_idp = idp_ref.to_dict() old_idp = idp_ref.to_dict()
old_idp.update(idp) old_idp.update(idp)
new_idp = IdentityProviderModel.from_dict(old_idp) new_idp = IdentityProviderModel.from_dict(old_idp)
for attr in IdentityProviderModel.mutable_attributes: for attr in IdentityProviderModel.mutable_attributes:
setattr(idp_ref, attr, getattr(new_idp, attr)) setattr(idp_ref, attr, getattr(new_idp, attr))
return idp_ref.to_dict() return idp_ref.to_dict()
# Protocol CRUD # Protocol CRUD
def _get_protocol(self, session, idp_id, protocol_id): def _get_protocol(self, session, idp_id, protocol_id):
@ -227,36 +227,36 @@ class Federation(core.FederationDriverV8):
def create_protocol(self, idp_id, protocol_id, protocol): def create_protocol(self, idp_id, protocol_id, protocol):
protocol['id'] = protocol_id protocol['id'] = protocol_id
protocol['idp_id'] = idp_id protocol['idp_id'] = idp_id
with sql.transaction() as session: with sql.session_for_write() as session:
self._get_idp(session, idp_id) self._get_idp(session, idp_id)
protocol_ref = FederationProtocolModel.from_dict(protocol) protocol_ref = FederationProtocolModel.from_dict(protocol)
session.add(protocol_ref) session.add(protocol_ref)
return protocol_ref.to_dict() return protocol_ref.to_dict()
def update_protocol(self, idp_id, protocol_id, protocol): def update_protocol(self, idp_id, protocol_id, protocol):
with sql.transaction() as session: with sql.session_for_write() as session:
proto_ref = self._get_protocol(session, idp_id, protocol_id) proto_ref = self._get_protocol(session, idp_id, protocol_id)
old_proto = proto_ref.to_dict() old_proto = proto_ref.to_dict()
old_proto.update(protocol) old_proto.update(protocol)
new_proto = FederationProtocolModel.from_dict(old_proto) new_proto = FederationProtocolModel.from_dict(old_proto)
for attr in FederationProtocolModel.mutable_attributes: for attr in FederationProtocolModel.mutable_attributes:
setattr(proto_ref, attr, getattr(new_proto, attr)) setattr(proto_ref, attr, getattr(new_proto, attr))
return proto_ref.to_dict() return proto_ref.to_dict()
def get_protocol(self, idp_id, protocol_id): def get_protocol(self, idp_id, protocol_id):
with sql.transaction() as session: with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id) protocol_ref = self._get_protocol(session, idp_id, protocol_id)
return protocol_ref.to_dict() return protocol_ref.to_dict()
def list_protocols(self, idp_id): def list_protocols(self, idp_id):
with sql.transaction() as session: with sql.session_for_read() as session:
q = session.query(FederationProtocolModel) q = session.query(FederationProtocolModel)
q = q.filter_by(idp_id=idp_id) q = q.filter_by(idp_id=idp_id)
protocols = [protocol.to_dict() for protocol in q] protocols = [protocol.to_dict() for protocol in q]
return protocols return protocols
def delete_protocol(self, idp_id, protocol_id): def delete_protocol(self, idp_id, protocol_id):
with sql.transaction() as session: with sql.session_for_write() as session:
key_ref = self._get_protocol(session, idp_id, protocol_id) key_ref = self._get_protocol(session, idp_id, protocol_id)
session.delete(key_ref) session.delete(key_ref)
@ -277,58 +277,58 @@ class Federation(core.FederationDriverV8):
ref = {} ref = {}
ref['id'] = mapping_id ref['id'] = mapping_id
ref['rules'] = mapping.get('rules') ref['rules'] = mapping.get('rules')
with sql.transaction() as session: with sql.session_for_write() as session:
mapping_ref = MappingModel.from_dict(ref) mapping_ref = MappingModel.from_dict(ref)
session.add(mapping_ref) session.add(mapping_ref)
return mapping_ref.to_dict() return mapping_ref.to_dict()
def delete_mapping(self, mapping_id): def delete_mapping(self, mapping_id):
with sql.transaction() as session: with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
session.delete(mapping_ref) session.delete(mapping_ref)
def list_mappings(self): def list_mappings(self):
with sql.transaction() as session: with sql.session_for_read() as session:
mappings = session.query(MappingModel) mappings = session.query(MappingModel)
return [x.to_dict() for x in mappings] return [x.to_dict() for x in mappings]
def get_mapping(self, mapping_id): def get_mapping(self, mapping_id):
with sql.transaction() as session: with sql.session_for_read() as session:
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict() return mapping_ref.to_dict()
@sql.handle_conflicts(conflict_type='mapping') @sql.handle_conflicts(conflict_type='mapping')
def update_mapping(self, mapping_id, mapping): def update_mapping(self, mapping_id, mapping):
ref = {} ref = {}
ref['id'] = mapping_id ref['id'] = mapping_id
ref['rules'] = mapping.get('rules') ref['rules'] = mapping.get('rules')
with sql.transaction() as session: with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
old_mapping = mapping_ref.to_dict() old_mapping = mapping_ref.to_dict()
old_mapping.update(ref) old_mapping.update(ref)
new_mapping = MappingModel.from_dict(old_mapping) new_mapping = MappingModel.from_dict(old_mapping)
for attr in MappingModel.attributes: for attr in MappingModel.attributes:
setattr(mapping_ref, attr, getattr(new_mapping, attr)) setattr(mapping_ref, attr, getattr(new_mapping, attr))
return mapping_ref.to_dict() return mapping_ref.to_dict()
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
with sql.transaction() as session: with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id) protocol_ref = self._get_protocol(session, idp_id, protocol_id)
mapping_id = protocol_ref.mapping_id mapping_id = protocol_ref.mapping_id
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict() return mapping_ref.to_dict()
# Service Provider CRUD # Service Provider CRUD
@sql.handle_conflicts(conflict_type='service_provider') @sql.handle_conflicts(conflict_type='service_provider')
def create_sp(self, sp_id, sp): def create_sp(self, sp_id, sp):
sp['id'] = sp_id sp['id'] = sp_id
with sql.transaction() as session: with sql.session_for_write() as session:
sp_ref = ServiceProviderModel.from_dict(sp) sp_ref = ServiceProviderModel.from_dict(sp)
session.add(sp_ref) session.add(sp_ref)
return sp_ref.to_dict() return sp_ref.to_dict()
def delete_sp(self, sp_id): def delete_sp(self, sp_id):
with sql.transaction() as session: with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id) sp_ref = self._get_sp(session, sp_id)
session.delete(sp_ref) session.delete(sp_ref)
@ -339,28 +339,28 @@ class Federation(core.FederationDriverV8):
return sp_ref return sp_ref
def list_sps(self): def list_sps(self):
with sql.transaction() as session: with sql.session_for_read() as session:
sps = session.query(ServiceProviderModel) sps = session.query(ServiceProviderModel)
sps_list = [sp.to_dict() for sp in sps] sps_list = [sp.to_dict() for sp in sps]
return sps_list return sps_list
def get_sp(self, sp_id): def get_sp(self, sp_id):
with sql.transaction() as session: with sql.session_for_read() as session:
sp_ref = self._get_sp(session, sp_id) sp_ref = self._get_sp(session, sp_id)
return sp_ref.to_dict() return sp_ref.to_dict()
def update_sp(self, sp_id, sp): def update_sp(self, sp_id, sp):
with sql.transaction() as session: with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id) sp_ref = self._get_sp(session, sp_id)
old_sp = sp_ref.to_dict() old_sp = sp_ref.to_dict()
old_sp.update(sp) old_sp.update(sp)
new_sp = ServiceProviderModel.from_dict(old_sp) new_sp = ServiceProviderModel.from_dict(old_sp)
for attr in ServiceProviderModel.mutable_attributes: for attr in ServiceProviderModel.mutable_attributes:
setattr(sp_ref, attr, getattr(new_sp, attr)) setattr(sp_ref, attr, getattr(new_sp, attr))
return sp_ref.to_dict() return sp_ref.to_dict()
def get_enabled_service_providers(self): def get_enabled_service_providers(self):
with sql.transaction() as session: with sql.session_for_read() as session:
service_providers = session.query(ServiceProviderModel) service_providers = session.query(ServiceProviderModel)
service_providers = service_providers.filter_by(enabled=True) service_providers = service_providers.filter_by(enabled=True)
return service_providers return service_providers

View File

@ -169,10 +169,10 @@ class Federation(core.FederationDriverV9):
def create_idp(self, idp_id, idp): def create_idp(self, idp_id, idp):
idp['id'] = idp_id idp['id'] = idp_id
try: try:
with sql.transaction() as session: with sql.session_for_write() as session:
idp_ref = IdentityProviderModel.from_dict(idp) idp_ref = IdentityProviderModel.from_dict(idp)
session.add(idp_ref) session.add(idp_ref)
return idp_ref.to_dict() return idp_ref.to_dict()
except sql.DBDuplicateEntry as e: except sql.DBDuplicateEntry as e:
conflict_type = 'identity_provider' conflict_type = 'identity_provider'
details = six.text_type(e) details = six.text_type(e)
@ -186,7 +186,7 @@ class Federation(core.FederationDriverV9):
raise exception.Conflict(type=conflict_type, details=msg) raise exception.Conflict(type=conflict_type, details=msg)
def delete_idp(self, idp_id): def delete_idp(self, idp_id):
with sql.transaction() as session: with sql.session_for_write() as session:
self._delete_assigned_protocols(session, idp_id) self._delete_assigned_protocols(session, idp_id)
idp_ref = self._get_idp(session, idp_id) idp_ref = self._get_idp(session, idp_id)
session.delete(idp_ref) session.delete(idp_ref)
@ -206,31 +206,31 @@ class Federation(core.FederationDriverV9):
raise exception.IdentityProviderNotFound(idp_id=remote_id) raise exception.IdentityProviderNotFound(idp_id=remote_id)
def list_idps(self, hints=None): def list_idps(self, hints=None):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(IdentityProviderModel) query = session.query(IdentityProviderModel)
idps = sql.filter_limit_query(IdentityProviderModel, query, hints) idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
idps_list = [idp.to_dict() for idp in idps] idps_list = [idp.to_dict() for idp in idps]
return idps_list return idps_list
def get_idp(self, idp_id): def get_idp(self, idp_id):
with sql.transaction() as session: with sql.session_for_read() as session:
idp_ref = self._get_idp(session, idp_id) idp_ref = self._get_idp(session, idp_id)
return idp_ref.to_dict() return idp_ref.to_dict()
def get_idp_from_remote_id(self, remote_id): def get_idp_from_remote_id(self, remote_id):
with sql.transaction() as session: with sql.session_for_read() as session:
ref = self._get_idp_from_remote_id(session, remote_id) ref = self._get_idp_from_remote_id(session, remote_id)
return ref.to_dict() return ref.to_dict()
def update_idp(self, idp_id, idp): def update_idp(self, idp_id, idp):
with sql.transaction() as session: with sql.session_for_write() as session:
idp_ref = self._get_idp(session, idp_id) idp_ref = self._get_idp(session, idp_id)
old_idp = idp_ref.to_dict() old_idp = idp_ref.to_dict()
old_idp.update(idp) old_idp.update(idp)
new_idp = IdentityProviderModel.from_dict(old_idp) new_idp = IdentityProviderModel.from_dict(old_idp)
for attr in IdentityProviderModel.mutable_attributes: for attr in IdentityProviderModel.mutable_attributes:
setattr(idp_ref, attr, getattr(new_idp, attr)) setattr(idp_ref, attr, getattr(new_idp, attr))
return idp_ref.to_dict() return idp_ref.to_dict()
# Protocol CRUD # Protocol CRUD
def _get_protocol(self, session, idp_id, protocol_id): def _get_protocol(self, session, idp_id, protocol_id):
@ -247,36 +247,36 @@ class Federation(core.FederationDriverV9):
def create_protocol(self, idp_id, protocol_id, protocol): def create_protocol(self, idp_id, protocol_id, protocol):
protocol['id'] = protocol_id protocol['id'] = protocol_id
protocol['idp_id'] = idp_id protocol['idp_id'] = idp_id
with sql.transaction() as session: with sql.session_for_write() as session:
self._get_idp(session, idp_id) self._get_idp(session, idp_id)
protocol_ref = FederationProtocolModel.from_dict(protocol) protocol_ref = FederationProtocolModel.from_dict(protocol)
session.add(protocol_ref) session.add(protocol_ref)
return protocol_ref.to_dict() return protocol_ref.to_dict()
def update_protocol(self, idp_id, protocol_id, protocol): def update_protocol(self, idp_id, protocol_id, protocol):
with sql.transaction() as session: with sql.session_for_write() as session:
proto_ref = self._get_protocol(session, idp_id, protocol_id) proto_ref = self._get_protocol(session, idp_id, protocol_id)
old_proto = proto_ref.to_dict() old_proto = proto_ref.to_dict()
old_proto.update(protocol) old_proto.update(protocol)
new_proto = FederationProtocolModel.from_dict(old_proto) new_proto = FederationProtocolModel.from_dict(old_proto)
for attr in FederationProtocolModel.mutable_attributes: for attr in FederationProtocolModel.mutable_attributes:
setattr(proto_ref, attr, getattr(new_proto, attr)) setattr(proto_ref, attr, getattr(new_proto, attr))
return proto_ref.to_dict() return proto_ref.to_dict()
def get_protocol(self, idp_id, protocol_id): def get_protocol(self, idp_id, protocol_id):
with sql.transaction() as session: with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id) protocol_ref = self._get_protocol(session, idp_id, protocol_id)
return protocol_ref.to_dict() return protocol_ref.to_dict()
def list_protocols(self, idp_id): def list_protocols(self, idp_id):
with sql.transaction() as session: with sql.session_for_read() as session:
q = session.query(FederationProtocolModel) q = session.query(FederationProtocolModel)
q = q.filter_by(idp_id=idp_id) q = q.filter_by(idp_id=idp_id)
protocols = [protocol.to_dict() for protocol in q] protocols = [protocol.to_dict() for protocol in q]
return protocols return protocols
def delete_protocol(self, idp_id, protocol_id): def delete_protocol(self, idp_id, protocol_id):
with sql.transaction() as session: with sql.session_for_write() as session:
key_ref = self._get_protocol(session, idp_id, protocol_id) key_ref = self._get_protocol(session, idp_id, protocol_id)
session.delete(key_ref) session.delete(key_ref)
@ -297,58 +297,58 @@ class Federation(core.FederationDriverV9):
ref = {} ref = {}
ref['id'] = mapping_id ref['id'] = mapping_id
ref['rules'] = mapping.get('rules') ref['rules'] = mapping.get('rules')
with sql.transaction() as session: with sql.session_for_write() as session:
mapping_ref = MappingModel.from_dict(ref) mapping_ref = MappingModel.from_dict(ref)
session.add(mapping_ref) session.add(mapping_ref)
return mapping_ref.to_dict() return mapping_ref.to_dict()
def delete_mapping(self, mapping_id): def delete_mapping(self, mapping_id):
with sql.transaction() as session: with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
session.delete(mapping_ref) session.delete(mapping_ref)
def list_mappings(self): def list_mappings(self):
with sql.transaction() as session: with sql.session_for_read() as session:
mappings = session.query(MappingModel) mappings = session.query(MappingModel)
return [x.to_dict() for x in mappings] return [x.to_dict() for x in mappings]
def get_mapping(self, mapping_id): def get_mapping(self, mapping_id):
with sql.transaction() as session: with sql.session_for_read() as session:
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict() return mapping_ref.to_dict()
@sql.handle_conflicts(conflict_type='mapping') @sql.handle_conflicts(conflict_type='mapping')
def update_mapping(self, mapping_id, mapping): def update_mapping(self, mapping_id, mapping):
ref = {} ref = {}
ref['id'] = mapping_id ref['id'] = mapping_id
ref['rules'] = mapping.get('rules') ref['rules'] = mapping.get('rules')
with sql.transaction() as session: with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
old_mapping = mapping_ref.to_dict() old_mapping = mapping_ref.to_dict()
old_mapping.update(ref) old_mapping.update(ref)
new_mapping = MappingModel.from_dict(old_mapping) new_mapping = MappingModel.from_dict(old_mapping)
for attr in MappingModel.attributes: for attr in MappingModel.attributes:
setattr(mapping_ref, attr, getattr(new_mapping, attr)) setattr(mapping_ref, attr, getattr(new_mapping, attr))
return mapping_ref.to_dict() return mapping_ref.to_dict()
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
with sql.transaction() as session: with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id) protocol_ref = self._get_protocol(session, idp_id, protocol_id)
mapping_id = protocol_ref.mapping_id mapping_id = protocol_ref.mapping_id
mapping_ref = self._get_mapping(session, mapping_id) mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict() return mapping_ref.to_dict()
# Service Provider CRUD # Service Provider CRUD
@sql.handle_conflicts(conflict_type='service_provider') @sql.handle_conflicts(conflict_type='service_provider')
def create_sp(self, sp_id, sp): def create_sp(self, sp_id, sp):
sp['id'] = sp_id sp['id'] = sp_id
with sql.transaction() as session: with sql.session_for_write() as session:
sp_ref = ServiceProviderModel.from_dict(sp) sp_ref = ServiceProviderModel.from_dict(sp)
session.add(sp_ref) session.add(sp_ref)
return sp_ref.to_dict() return sp_ref.to_dict()
def delete_sp(self, sp_id): def delete_sp(self, sp_id):
with sql.transaction() as session: with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id) sp_ref = self._get_sp(session, sp_id)
session.delete(sp_ref) session.delete(sp_ref)
@ -359,28 +359,28 @@ class Federation(core.FederationDriverV9):
return sp_ref return sp_ref
def list_sps(self): def list_sps(self):
with sql.transaction() as session: with sql.session_for_read() as session:
sps = session.query(ServiceProviderModel) sps = session.query(ServiceProviderModel)
sps_list = [sp.to_dict() for sp in sps] sps_list = [sp.to_dict() for sp in sps]
return sps_list return sps_list
def get_sp(self, sp_id): def get_sp(self, sp_id):
with sql.transaction() as session: with sql.session_for_read() as session:
sp_ref = self._get_sp(session, sp_id) sp_ref = self._get_sp(session, sp_id)
return sp_ref.to_dict() return sp_ref.to_dict()
def update_sp(self, sp_id, sp): def update_sp(self, sp_id, sp):
with sql.transaction() as session: with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id) sp_ref = self._get_sp(session, sp_id)
old_sp = sp_ref.to_dict() old_sp = sp_ref.to_dict()
old_sp.update(sp) old_sp.update(sp)
new_sp = ServiceProviderModel.from_dict(old_sp) new_sp = ServiceProviderModel.from_dict(old_sp)
for attr in ServiceProviderModel.mutable_attributes: for attr in ServiceProviderModel.mutable_attributes:
setattr(sp_ref, attr, getattr(new_sp, attr)) setattr(sp_ref, attr, getattr(new_sp, attr))
return sp_ref.to_dict() return sp_ref.to_dict()
def get_enabled_service_providers(self): def get_enabled_service_providers(self):
with sql.transaction() as session: with sql.session_for_read() as session:
service_providers = session.query(ServiceProviderModel) service_providers = session.query(ServiceProviderModel)
service_providers = service_providers.filter_by(enabled=True) service_providers = service_providers.filter_by(enabled=True)
return service_providers return service_providers

View File

@ -178,33 +178,32 @@ class Identity(identity.IdentityDriverV8):
# Identity interface # Identity interface
def authenticate(self, user_id, password): def authenticate(self, user_id, password):
session = sql.get_session() with sql.session_for_read() as session:
user_ref = None user_ref = None
try: try:
user_ref = self._get_user(session, user_id) user_ref = self._get_user(session, user_id)
except exception.UserNotFound: except exception.UserNotFound:
raise AssertionError(_('Invalid user / password')) raise AssertionError(_('Invalid user / password'))
if not self._check_password(password, user_ref): if not self._check_password(password, user_ref):
raise AssertionError(_('Invalid user / password')) raise AssertionError(_('Invalid user / password'))
return identity.filter_user(user_ref.to_dict()) return identity.filter_user(user_ref.to_dict())
# user crud # user crud
@sql.handle_conflicts(conflict_type='user') @sql.handle_conflicts(conflict_type='user')
def create_user(self, user_id, user): def create_user(self, user_id, user):
user = utils.hash_user_password(user) user = utils.hash_user_password(user)
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
user_ref = User.from_dict(user) user_ref = User.from_dict(user)
session.add(user_ref) session.add(user_ref)
return identity.filter_user(user_ref.to_dict()) return identity.filter_user(user_ref.to_dict())
@driver_hints.truncated @driver_hints.truncated
def list_users(self, hints): def list_users(self, hints):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(User).outerjoin(LocalUser) query = session.query(User).outerjoin(LocalUser)
user_refs = sql.filter_limit_query(User, query, hints) user_refs = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(x.to_dict()) for x in user_refs] return [identity.filter_user(x.to_dict()) for x in user_refs]
def _get_user(self, session, user_id): def _get_user(self, session, user_id):
user_ref = session.query(User).get(user_id) user_ref = session.query(User).get(user_id)
@ -213,25 +212,24 @@ class Identity(identity.IdentityDriverV8):
return user_ref return user_ref
def get_user(self, user_id): def get_user(self, user_id):
session = sql.get_session() with sql.session_for_read() as session:
return identity.filter_user(self._get_user(session, user_id).to_dict()) return identity.filter_user(
self._get_user(session, user_id).to_dict())
def get_user_by_name(self, user_name, domain_id): def get_user_by_name(self, user_name, domain_id):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(User).join(LocalUser) query = session.query(User).join(LocalUser)
query = query.filter(and_(LocalUser.name == user_name, query = query.filter(and_(LocalUser.name == user_name,
LocalUser.domain_id == domain_id)) LocalUser.domain_id == domain_id))
try: try:
user_ref = query.one() user_ref = query.one()
except sql.NotFound: except sql.NotFound:
raise exception.UserNotFound(user_id=user_name) raise exception.UserNotFound(user_id=user_name)
return identity.filter_user(user_ref.to_dict()) return identity.filter_user(user_ref.to_dict())
@sql.handle_conflicts(conflict_type='user') @sql.handle_conflicts(conflict_type='user')
def update_user(self, user_id, user): def update_user(self, user_id, user):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
user_ref = self._get_user(session, user_id) user_ref = self._get_user(session, user_id)
old_user_dict = user_ref.to_dict() old_user_dict = user_ref.to_dict()
user = utils.hash_user_password(user) user = utils.hash_user_password(user)
@ -242,77 +240,74 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id': if attr != 'id':
setattr(user_ref, attr, getattr(new_user, attr)) setattr(user_ref, attr, getattr(new_user, attr))
user_ref.extra = new_user.extra user_ref.extra = new_user.extra
return identity.filter_user(user_ref.to_dict(include_extra_dict=True)) return identity.filter_user(
user_ref.to_dict(include_extra_dict=True))
def add_user_to_group(self, user_id, group_id): def add_user_to_group(self, user_id, group_id):
session = sql.get_session() with sql.session_for_write() as session:
self.get_group(group_id) self.get_group(group_id)
self.get_user(user_id) self.get_user(user_id)
query = session.query(UserGroupMembership) query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id) query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id) query = query.filter_by(group_id=group_id)
rv = query.first() rv = query.first()
if rv: if rv:
return return
with session.begin():
session.add(UserGroupMembership(user_id=user_id, session.add(UserGroupMembership(user_id=user_id,
group_id=group_id)) group_id=group_id))
def check_user_in_group(self, user_id, group_id): def check_user_in_group(self, user_id, group_id):
session = sql.get_session() with sql.session_for_read() as session:
self.get_group(group_id)
self.get_user(user_id)
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
if not query.first():
raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
def remove_user_from_group(self, user_id, group_id):
session = sql.get_session()
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
membership_ref = query.first()
if membership_ref is None:
# Check if the group and user exist to return descriptive
# exceptions.
self.get_group(group_id) self.get_group(group_id)
self.get_user(user_id) self.get_user(user_id)
raise exception.NotFound(_("User '%(user_id)s' not found in" query = session.query(UserGroupMembership)
" group '%(group_id)s'") % query = query.filter_by(user_id=user_id)
{'user_id': user_id, query = query.filter_by(group_id=group_id)
'group_id': group_id}) if not query.first():
with session.begin(): raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
def remove_user_from_group(self, user_id, group_id):
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
with sql.session_for_write() as session:
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
membership_ref = query.first()
if membership_ref is None:
# Check if the group and user exist to return descriptive
# exceptions.
self.get_group(group_id)
self.get_user(user_id)
raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
session.delete(membership_ref) session.delete(membership_ref)
def list_groups_for_user(self, user_id, hints): def list_groups_for_user(self, user_id, hints):
session = sql.get_session() with sql.session_for_read() as session:
self.get_user(user_id) self.get_user(user_id)
query = session.query(Group).join(UserGroupMembership) query = session.query(Group).join(UserGroupMembership)
query = query.filter(UserGroupMembership.user_id == user_id) query = query.filter(UserGroupMembership.user_id == user_id)
query = sql.filter_limit_query(Group, query, hints) query = sql.filter_limit_query(Group, query, hints)
return [g.to_dict() for g in query] return [g.to_dict() for g in query]
def list_users_in_group(self, group_id, hints): def list_users_in_group(self, group_id, hints):
session = sql.get_session() with sql.session_for_read() as session:
self.get_group(group_id) self.get_group(group_id)
query = session.query(User).outerjoin(LocalUser) query = session.query(User).outerjoin(LocalUser)
query = query.join(UserGroupMembership) query = query.join(UserGroupMembership)
query = query.filter(UserGroupMembership.group_id == group_id) query = query.filter(UserGroupMembership.group_id == group_id)
query = sql.filter_limit_query(User, query, hints) query = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(u.to_dict()) for u in query] return [identity.filter_user(u.to_dict()) for u in query]
def delete_user(self, user_id): def delete_user(self, user_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_user(session, user_id) ref = self._get_user(session, user_id)
q = session.query(UserGroupMembership) q = session.query(UserGroupMembership)
@ -325,18 +320,17 @@ class Identity(identity.IdentityDriverV8):
@sql.handle_conflicts(conflict_type='group') @sql.handle_conflicts(conflict_type='group')
def create_group(self, group_id, group): def create_group(self, group_id, group):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = Group.from_dict(group) ref = Group.from_dict(group)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
@driver_hints.truncated @driver_hints.truncated
def list_groups(self, hints): def list_groups(self, hints):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(Group) query = session.query(Group)
refs = sql.filter_limit_query(Group, query, hints) refs = sql.filter_limit_query(Group, query, hints)
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
def _get_group(self, session, group_id): def _get_group(self, session, group_id):
ref = session.query(Group).get(group_id) ref = session.query(Group).get(group_id)
@ -345,25 +339,23 @@ class Identity(identity.IdentityDriverV8):
return ref return ref
def get_group(self, group_id): def get_group(self, group_id):
session = sql.get_session() with sql.session_for_read() as session:
return self._get_group(session, group_id).to_dict() return self._get_group(session, group_id).to_dict()
def get_group_by_name(self, group_name, domain_id): def get_group_by_name(self, group_name, domain_id):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(Group) query = session.query(Group)
query = query.filter_by(name=group_name) query = query.filter_by(name=group_name)
query = query.filter_by(domain_id=domain_id) query = query.filter_by(domain_id=domain_id)
try: try:
group_ref = query.one() group_ref = query.one()
except sql.NotFound: except sql.NotFound:
raise exception.GroupNotFound(group_id=group_name) raise exception.GroupNotFound(group_id=group_name)
return group_ref.to_dict() return group_ref.to_dict()
@sql.handle_conflicts(conflict_type='group') @sql.handle_conflicts(conflict_type='group')
def update_group(self, group_id, group): def update_group(self, group_id, group):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_group(session, group_id) ref = self._get_group(session, group_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
for k in group: for k in group:
@ -373,12 +365,10 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id': if attr != 'id':
setattr(ref, attr, getattr(new_group, attr)) setattr(ref, attr, getattr(new_group, attr))
ref.extra = new_group.extra ref.extra = new_group.extra
return ref.to_dict() return ref.to_dict()
def delete_group(self, group_id): def delete_group(self, group_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_group(session, group_id) ref = self._get_group(session, group_id)
q = session.query(UserGroupMembership) q = session.query(UserGroupMembership)

View File

@ -45,27 +45,27 @@ class Mapping(identity.MappingDriverV8):
# work if we hashed all the entries, even those that already generate # work if we hashed all the entries, even those that already generate
# UUIDs, like SQL. Further, this would only work if the generation # UUIDs, like SQL. Further, this would only work if the generation
# algorithm was immutable (e.g. it had always been sha256). # algorithm was immutable (e.g. it had always been sha256).
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(IDMapping.public_id) query = session.query(IDMapping.public_id)
query = query.filter_by(domain_id=local_entity['domain_id']) query = query.filter_by(domain_id=local_entity['domain_id'])
query = query.filter_by(local_id=local_entity['local_id']) query = query.filter_by(local_id=local_entity['local_id'])
query = query.filter_by(entity_type=local_entity['entity_type']) query = query.filter_by(entity_type=local_entity['entity_type'])
try: try:
public_ref = query.one() public_ref = query.one()
public_id = public_ref.public_id public_id = public_ref.public_id
return public_id return public_id
except sql.NotFound: except sql.NotFound:
return None return None
def get_id_mapping(self, public_id): def get_id_mapping(self, public_id):
session = sql.get_session() with sql.session_for_read() as session:
mapping_ref = session.query(IDMapping).get(public_id) mapping_ref = session.query(IDMapping).get(public_id)
if mapping_ref: if mapping_ref:
return mapping_ref.to_dict() return mapping_ref.to_dict()
def create_id_mapping(self, local_entity, public_id=None): def create_id_mapping(self, local_entity, public_id=None):
entity = local_entity.copy() entity = local_entity.copy()
with sql.transaction() as session: with sql.session_for_write() as session:
if public_id is None: if public_id is None:
public_id = self.id_generator_api.generate_public_ID(entity) public_id = self.id_generator_api.generate_public_ID(entity)
entity['public_id'] = public_id entity['public_id'] = public_id
@ -74,7 +74,7 @@ class Mapping(identity.MappingDriverV8):
return public_id return public_id
def delete_id_mapping(self, public_id): def delete_id_mapping(self, public_id):
with sql.transaction() as session: with sql.session_for_write() as session:
try: try:
session.query(IDMapping).filter( session.query(IDMapping).filter(
IDMapping.public_id == public_id).delete() IDMapping.public_id == public_id).delete()
@ -84,14 +84,15 @@ class Mapping(identity.MappingDriverV8):
pass pass
def purge_mappings(self, purge_filter): def purge_mappings(self, purge_filter):
session = sql.get_session() with sql.session_for_write() as session:
query = session.query(IDMapping) query = session.query(IDMapping)
if 'domain_id' in purge_filter: if 'domain_id' in purge_filter:
query = query.filter_by(domain_id=purge_filter['domain_id']) query = query.filter_by(domain_id=purge_filter['domain_id'])
if 'public_id' in purge_filter: if 'public_id' in purge_filter:
query = query.filter_by(public_id=purge_filter['public_id']) query = query.filter_by(public_id=purge_filter['public_id'])
if 'local_id' in purge_filter: if 'local_id' in purge_filter:
query = query.filter_by(local_id=purge_filter['local_id']) query = query.filter_by(local_id=purge_filter['local_id'])
if 'entity_type' in purge_filter: if 'entity_type' in purge_filter:
query = query.filter_by(entity_type=purge_filter['entity_type']) query = query.filter_by(
query.delete() entity_type=purge_filter['entity_type'])
query.delete()

View File

@ -92,17 +92,16 @@ class OAuth1(core.Oauth1DriverV8):
return consumer_ref return consumer_ref
def get_consumer_with_secret(self, consumer_id): def get_consumer_with_secret(self, consumer_id):
session = sql.get_session() with sql.session_for_read() as session:
consumer_ref = self._get_consumer(session, consumer_id) consumer_ref = self._get_consumer(session, consumer_id)
return consumer_ref.to_dict() return consumer_ref.to_dict()
def get_consumer(self, consumer_id): def get_consumer(self, consumer_id):
return core.filter_consumer( return core.filter_consumer(
self.get_consumer_with_secret(consumer_id)) self.get_consumer_with_secret(consumer_id))
def create_consumer(self, consumer_ref): def create_consumer(self, consumer_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
consumer = Consumer.from_dict(consumer_ref) consumer = Consumer.from_dict(consumer_ref)
session.add(consumer) session.add(consumer)
return consumer.to_dict() return consumer.to_dict()
@ -128,20 +127,18 @@ class OAuth1(core.Oauth1DriverV8):
session.delete(token_ref) session.delete(token_ref)
def delete_consumer(self, consumer_id): def delete_consumer(self, consumer_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
self._delete_request_tokens(session, consumer_id) self._delete_request_tokens(session, consumer_id)
self._delete_access_tokens(session, consumer_id) self._delete_access_tokens(session, consumer_id)
self._delete_consumer(session, consumer_id) self._delete_consumer(session, consumer_id)
def list_consumers(self): def list_consumers(self):
session = sql.get_session() with sql.session_for_read() as session:
cons = session.query(Consumer) cons = session.query(Consumer)
return [core.filter_consumer(x.to_dict()) for x in cons] return [core.filter_consumer(x.to_dict()) for x in cons]
def update_consumer(self, consumer_id, consumer_ref): def update_consumer(self, consumer_id, consumer_ref):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
consumer = self._get_consumer(session, consumer_id) consumer = self._get_consumer(session, consumer_id)
old_consumer_dict = consumer.to_dict() old_consumer_dict = consumer.to_dict()
old_consumer_dict.update(consumer_ref) old_consumer_dict.update(consumer_ref)
@ -169,11 +166,10 @@ class OAuth1(core.Oauth1DriverV8):
ref['role_ids'] = None ref['role_ids'] = None
ref['consumer_id'] = consumer_id ref['consumer_id'] = consumer_id
ref['expires_at'] = expiry_date ref['expires_at'] = expiry_date
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
token_ref = RequestToken.from_dict(ref) token_ref = RequestToken.from_dict(ref)
session.add(token_ref) session.add(token_ref)
return token_ref.to_dict() return token_ref.to_dict()
def _get_request_token(self, session, request_token_id): def _get_request_token(self, session, request_token_id):
token_ref = session.query(RequestToken).get(request_token_id) token_ref = session.query(RequestToken).get(request_token_id)
@ -182,14 +178,13 @@ class OAuth1(core.Oauth1DriverV8):
return token_ref return token_ref
def get_request_token(self, request_token_id): def get_request_token(self, request_token_id):
session = sql.get_session() with sql.session_for_read() as session:
token_ref = self._get_request_token(session, request_token_id) token_ref = self._get_request_token(session, request_token_id)
return token_ref.to_dict() return token_ref.to_dict()
def authorize_request_token(self, request_token_id, user_id, def authorize_request_token(self, request_token_id, user_id,
role_ids): role_ids):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
token_ref = self._get_request_token(session, request_token_id) token_ref = self._get_request_token(session, request_token_id)
token_dict = token_ref.to_dict() token_dict = token_ref.to_dict()
token_dict['authorizing_user_id'] = user_id token_dict['authorizing_user_id'] = user_id
@ -203,13 +198,12 @@ class OAuth1(core.Oauth1DriverV8):
or attr == 'role_ids'): or attr == 'role_ids'):
setattr(token_ref, attr, getattr(new_token, attr)) setattr(token_ref, attr, getattr(new_token, attr))
return token_ref.to_dict() return token_ref.to_dict()
def create_access_token(self, request_id, access_token_duration): def create_access_token(self, request_id, access_token_duration):
access_token_id = uuid.uuid4().hex access_token_id = uuid.uuid4().hex
access_token_secret = uuid.uuid4().hex access_token_secret = uuid.uuid4().hex
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
req_token_ref = self._get_request_token(session, request_id) req_token_ref = self._get_request_token(session, request_id)
token_dict = req_token_ref.to_dict() token_dict = req_token_ref.to_dict()
@ -235,7 +229,7 @@ class OAuth1(core.Oauth1DriverV8):
# remove request token, it's been used # remove request token, it's been used
session.delete(req_token_ref) session.delete(req_token_ref)
return token_ref.to_dict() return token_ref.to_dict()
def _get_access_token(self, session, access_token_id): def _get_access_token(self, session, access_token_id):
token_ref = session.query(AccessToken).get(access_token_id) token_ref = session.query(AccessToken).get(access_token_id)
@ -244,19 +238,18 @@ class OAuth1(core.Oauth1DriverV8):
return token_ref return token_ref
def get_access_token(self, access_token_id): def get_access_token(self, access_token_id):
session = sql.get_session() with sql.session_for_read() as session:
token_ref = self._get_access_token(session, access_token_id) token_ref = self._get_access_token(session, access_token_id)
return token_ref.to_dict() return token_ref.to_dict()
def list_access_tokens(self, user_id): def list_access_tokens(self, user_id):
session = sql.get_session() with sql.session_for_read() as session:
q = session.query(AccessToken) q = session.query(AccessToken)
user_auths = q.filter_by(authorizing_user_id=user_id) user_auths = q.filter_by(authorizing_user_id=user_id)
return [core.filter_token(x.to_dict()) for x in user_auths] return [core.filter_token(x.to_dict()) for x in user_auths]
def delete_access_token(self, user_id, access_token_id): def delete_access_token(self, user_id, access_token_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
token_ref = self._get_access_token(session, access_token_id) token_ref = self._get_access_token(session, access_token_id)
token_dict = token_ref.to_dict() token_dict = token_ref.to_dict()
if token_dict['authorizing_user_id'] != user_id: if token_dict['authorizing_user_id'] != user_id:

View File

@ -30,19 +30,16 @@ class Policy(rules.Policy):
@sql.handle_conflicts(conflict_type='policy') @sql.handle_conflicts(conflict_type='policy')
def create_policy(self, policy_id, policy): def create_policy(self, policy_id, policy):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = PolicyModel.from_dict(policy) ref = PolicyModel.from_dict(policy)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
def list_policies(self): def list_policies(self):
session = sql.get_session() with sql.session_for_read() as session:
refs = session.query(PolicyModel).all()
refs = session.query(PolicyModel).all() return [ref.to_dict() for ref in refs]
return [ref.to_dict() for ref in refs]
def _get_policy(self, session, policy_id): def _get_policy(self, session, policy_id):
"""Private method to get a policy model object (NOT a dictionary).""" """Private method to get a policy model object (NOT a dictionary)."""
@ -52,15 +49,12 @@ class Policy(rules.Policy):
return ref return ref
def get_policy(self, policy_id): def get_policy(self, policy_id):
session = sql.get_session() with sql.session_for_read() as session:
return self._get_policy(session, policy_id).to_dict()
return self._get_policy(session, policy_id).to_dict()
@sql.handle_conflicts(conflict_type='policy') @sql.handle_conflicts(conflict_type='policy')
def update_policy(self, policy_id, policy): def update_policy(self, policy_id, policy):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_policy(session, policy_id) ref = self._get_policy(session, policy_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
old_dict.update(policy) old_dict.update(policy)
@ -72,8 +66,6 @@ class Policy(rules.Policy):
return ref.to_dict() return ref.to_dict()
def delete_policy(self, policy_id): def delete_policy(self, policy_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
ref = self._get_policy(session, policy_id) ref = self._get_policy(session, policy_id)
session.delete(ref) session.delete(ref)

View File

@ -35,11 +35,11 @@ class Resource(keystone_resource.ResourceDriverV8):
return project_ref return project_ref
def get_project(self, tenant_id): def get_project(self, tenant_id):
with sql.transaction() as session: with sql.session_for_read() as session:
return self._get_project(session, tenant_id).to_dict() return self._get_project(session, tenant_id).to_dict()
def get_project_by_name(self, tenant_name, domain_id): def get_project_by_name(self, tenant_name, domain_id):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project) query = session.query(Project)
query = query.filter_by(name=tenant_name) query = query.filter_by(name=tenant_name)
query = query.filter_by(domain_id=domain_id) query = query.filter_by(domain_id=domain_id)
@ -51,7 +51,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@driver_hints.truncated @driver_hints.truncated
def list_projects(self, hints): def list_projects(self, hints):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project) query = session.query(Project)
project_refs = sql.filter_limit_query(Project, query, hints) project_refs = sql.filter_limit_query(Project, query, hints)
return [project_ref.to_dict() for project_ref in project_refs] return [project_ref.to_dict() for project_ref in project_refs]
@ -60,7 +60,7 @@ class Resource(keystone_resource.ResourceDriverV8):
if not ids: if not ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project) query = session.query(Project)
query = query.filter(Project.id.in_(ids)) query = query.filter(Project.id.in_(ids))
return [project_ref.to_dict() for project_ref in query.all()] return [project_ref.to_dict() for project_ref in query.all()]
@ -69,14 +69,14 @@ class Resource(keystone_resource.ResourceDriverV8):
if not domain_ids: if not domain_ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project.id) query = session.query(Project.id)
query = ( query = (
query.filter(Project.domain_id.in_(domain_ids))) query.filter(Project.domain_id.in_(domain_ids)))
return [x.id for x in query.all()] return [x.id for x in query.all()]
def list_projects_in_domain(self, domain_id): def list_projects_in_domain(self, domain_id):
with sql.transaction() as session: with sql.session_for_read() as session:
self._get_domain(session, domain_id) self._get_domain(session, domain_id)
query = session.query(Project) query = session.query(Project)
project_refs = query.filter_by(domain_id=domain_id) project_refs = query.filter_by(domain_id=domain_id)
@ -89,7 +89,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return [project_ref.to_dict() for project_ref in project_refs] return [project_ref.to_dict() for project_ref in project_refs]
def list_projects_in_subtree(self, project_id): def list_projects_in_subtree(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
children = self._get_children(session, [project_id]) children = self._get_children(session, [project_id])
subtree = [] subtree = []
examined = set([project_id]) examined = set([project_id])
@ -110,7 +110,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return subtree return subtree
def list_project_parents(self, project_id): def list_project_parents(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
project = self._get_project(session, project_id).to_dict() project = self._get_project(session, project_id).to_dict()
parents = [] parents = []
examined = set() examined = set()
@ -130,7 +130,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return parents return parents
def is_leaf_project(self, project_id): def is_leaf_project(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
project_refs = self._get_children(session, [project_id]) project_refs = self._get_children(session, [project_id])
return not project_refs return not project_refs
@ -138,7 +138,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='project') @sql.handle_conflicts(conflict_type='project')
def create_project(self, tenant_id, tenant): def create_project(self, tenant_id, tenant):
tenant['name'] = clean.project_name(tenant['name']) tenant['name'] = clean.project_name(tenant['name'])
with sql.transaction() as session: with sql.session_for_write() as session:
tenant_ref = Project.from_dict(tenant) tenant_ref = Project.from_dict(tenant)
session.add(tenant_ref) session.add(tenant_ref)
return tenant_ref.to_dict() return tenant_ref.to_dict()
@ -148,7 +148,7 @@ class Resource(keystone_resource.ResourceDriverV8):
if 'name' in tenant: if 'name' in tenant:
tenant['name'] = clean.project_name(tenant['name']) tenant['name'] = clean.project_name(tenant['name'])
with sql.transaction() as session: with sql.session_for_write() as session:
tenant_ref = self._get_project(session, tenant_id) tenant_ref = self._get_project(session, tenant_id)
old_project_dict = tenant_ref.to_dict() old_project_dict = tenant_ref.to_dict()
for k in tenant: for k in tenant:
@ -162,7 +162,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='project') @sql.handle_conflicts(conflict_type='project')
def delete_project(self, tenant_id): def delete_project(self, tenant_id):
with sql.transaction() as session: with sql.session_for_write() as session:
tenant_ref = self._get_project(session, tenant_id) tenant_ref = self._get_project(session, tenant_id)
session.delete(tenant_ref) session.delete(tenant_ref)
@ -170,14 +170,14 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='domain') @sql.handle_conflicts(conflict_type='domain')
def create_domain(self, domain_id, domain): def create_domain(self, domain_id, domain):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = Domain.from_dict(domain) ref = Domain.from_dict(domain)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
@driver_hints.truncated @driver_hints.truncated
def list_domains(self, hints): def list_domains(self, hints):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Domain) query = session.query(Domain)
refs = sql.filter_limit_query(Domain, query, hints) refs = sql.filter_limit_query(Domain, query, hints)
return [ref.to_dict() for ref in refs] return [ref.to_dict() for ref in refs]
@ -186,7 +186,7 @@ class Resource(keystone_resource.ResourceDriverV8):
if not ids: if not ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Domain) query = session.query(Domain)
query = query.filter(Domain.id.in_(ids)) query = query.filter(Domain.id.in_(ids))
domain_refs = query.all() domain_refs = query.all()
@ -199,11 +199,11 @@ class Resource(keystone_resource.ResourceDriverV8):
return ref return ref
def get_domain(self, domain_id): def get_domain(self, domain_id):
with sql.transaction() as session: with sql.session_for_read() as session:
return self._get_domain(session, domain_id).to_dict() return self._get_domain(session, domain_id).to_dict()
def get_domain_by_name(self, domain_name): def get_domain_by_name(self, domain_name):
with sql.transaction() as session: with sql.session_for_read() as session:
try: try:
ref = (session.query(Domain). ref = (session.query(Domain).
filter_by(name=domain_name).one()) filter_by(name=domain_name).one())
@ -213,7 +213,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='domain') @sql.handle_conflicts(conflict_type='domain')
def update_domain(self, domain_id, domain): def update_domain(self, domain_id, domain):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id) ref = self._get_domain(session, domain_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
for k in domain: for k in domain:
@ -226,7 +226,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return ref.to_dict() return ref.to_dict()
def delete_domain(self, domain_id): def delete_domain(self, domain_id):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id) ref = self._get_domain(session, domain_id)
session.delete(ref) session.delete(ref)

View File

@ -38,11 +38,11 @@ class Resource(keystone_resource.ResourceDriverV9):
return project_ref return project_ref
def get_project(self, project_id): def get_project(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
return self._get_project(session, project_id).to_dict() return self._get_project(session, project_id).to_dict()
def get_project_by_name(self, project_name, domain_id): def get_project_by_name(self, project_name, domain_id):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project) query = session.query(Project)
query = query.filter_by(name=project_name) query = query.filter_by(name=project_name)
if domain_id is None: if domain_id is None:
@ -70,7 +70,7 @@ class Resource(keystone_resource.ResourceDriverV9):
for f in hints.filters: for f in hints.filters:
if (f['name'] == 'domain_id' and f['value'] is None): if (f['name'] == 'domain_id' and f['value'] is None):
f['value'] = keystone_resource.NULL_DOMAIN_ID f['value'] = keystone_resource.NULL_DOMAIN_ID
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project) query = session.query(Project)
project_refs = sql.filter_limit_query(Project, query, hints) project_refs = sql.filter_limit_query(Project, query, hints)
return [project_ref.to_dict() for project_ref in project_refs return [project_ref.to_dict() for project_ref in project_refs
@ -80,7 +80,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not ids: if not ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project) query = session.query(Project)
query = query.filter(Project.id.in_(ids)) query = query.filter(Project.id.in_(ids))
return [project_ref.to_dict() for project_ref in query.all() return [project_ref.to_dict() for project_ref in query.all()
@ -90,7 +90,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not domain_ids: if not domain_ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Project.id) query = session.query(Project.id)
query = ( query = (
query.filter(Project.domain_id.in_(domain_ids))) query.filter(Project.domain_id.in_(domain_ids)))
@ -98,7 +98,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not self._is_hidden_ref(x)] if not self._is_hidden_ref(x)]
def list_projects_in_domain(self, domain_id): def list_projects_in_domain(self, domain_id):
with sql.transaction() as session: with sql.session_for_read() as session:
self._get_domain(session, domain_id) self._get_domain(session, domain_id)
query = session.query(Project) query = session.query(Project)
project_refs = query.filter_by(domain_id=domain_id) project_refs = query.filter_by(domain_id=domain_id)
@ -111,7 +111,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return [project_ref.to_dict() for project_ref in project_refs] return [project_ref.to_dict() for project_ref in project_refs]
def list_projects_in_subtree(self, project_id): def list_projects_in_subtree(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
children = self._get_children(session, [project_id]) children = self._get_children(session, [project_id])
subtree = [] subtree = []
examined = set([project_id]) examined = set([project_id])
@ -132,7 +132,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return subtree return subtree
def list_project_parents(self, project_id): def list_project_parents(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
project = self._get_project(session, project_id).to_dict() project = self._get_project(session, project_id).to_dict()
parents = [] parents = []
examined = set() examined = set()
@ -152,7 +152,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return parents return parents
def is_leaf_project(self, project_id): def is_leaf_project(self, project_id):
with sql.transaction() as session: with sql.session_for_read() as session:
project_refs = self._get_children(session, [project_id]) project_refs = self._get_children(session, [project_id])
return not project_refs return not project_refs
@ -161,7 +161,7 @@ class Resource(keystone_resource.ResourceDriverV9):
def create_project(self, project_id, project): def create_project(self, project_id, project):
project['name'] = clean.project_name(project['name']) project['name'] = clean.project_name(project['name'])
new_project = self._encode_domain_id(project) new_project = self._encode_domain_id(project)
with sql.transaction() as session: with sql.session_for_write() as session:
project_ref = Project.from_dict(new_project) project_ref = Project.from_dict(new_project)
session.add(project_ref) session.add(project_ref)
return project_ref.to_dict() return project_ref.to_dict()
@ -172,7 +172,7 @@ class Resource(keystone_resource.ResourceDriverV9):
project['name'] = clean.project_name(project['name']) project['name'] = clean.project_name(project['name'])
update_project = self._encode_domain_id(project) update_project = self._encode_domain_id(project)
with sql.transaction() as session: with sql.session_for_write() as session:
project_ref = self._get_project(session, project_id) project_ref = self._get_project(session, project_id)
old_project_dict = project_ref.to_dict() old_project_dict = project_ref.to_dict()
for k in update_project: for k in update_project:
@ -189,7 +189,7 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='project') @sql.handle_conflicts(conflict_type='project')
def delete_project(self, project_id): def delete_project(self, project_id):
with sql.transaction() as session: with sql.session_for_write() as session:
project_ref = self._get_project(session, project_id) project_ref = self._get_project(session, project_id)
session.delete(project_ref) session.delete(project_ref)
@ -197,7 +197,7 @@ class Resource(keystone_resource.ResourceDriverV9):
def delete_projects_from_ids(self, project_ids): def delete_projects_from_ids(self, project_ids):
if not project_ids: if not project_ids:
return return
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(Project).filter(Project.id.in_( query = session.query(Project).filter(Project.id.in_(
project_ids)) project_ids))
project_ids_from_bd = [p['id'] for p in query.all()] project_ids_from_bd = [p['id'] for p in query.all()]
@ -212,14 +212,14 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='domain') @sql.handle_conflicts(conflict_type='domain')
def create_domain(self, domain_id, domain): def create_domain(self, domain_id, domain):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = Domain.from_dict(domain) ref = Domain.from_dict(domain)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
@driver_hints.truncated @driver_hints.truncated
def list_domains(self, hints): def list_domains(self, hints):
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Domain) query = session.query(Domain)
refs = sql.filter_limit_query(Domain, query, hints) refs = sql.filter_limit_query(Domain, query, hints)
return [ref.to_dict() for ref in refs return [ref.to_dict() for ref in refs
@ -229,7 +229,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not ids: if not ids:
return [] return []
else: else:
with sql.transaction() as session: with sql.session_for_read() as session:
query = session.query(Domain) query = session.query(Domain)
query = query.filter(Domain.id.in_(ids)) query = query.filter(Domain.id.in_(ids))
domain_refs = query.all() domain_refs = query.all()
@ -243,11 +243,11 @@ class Resource(keystone_resource.ResourceDriverV9):
return ref return ref
def get_domain(self, domain_id): def get_domain(self, domain_id):
with sql.transaction() as session: with sql.session_for_read() as session:
return self._get_domain(session, domain_id).to_dict() return self._get_domain(session, domain_id).to_dict()
def get_domain_by_name(self, domain_name): def get_domain_by_name(self, domain_name):
with sql.transaction() as session: with sql.session_for_read() as session:
try: try:
ref = (session.query(Domain). ref = (session.query(Domain).
filter_by(name=domain_name).one()) filter_by(name=domain_name).one())
@ -260,7 +260,7 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='domain') @sql.handle_conflicts(conflict_type='domain')
def update_domain(self, domain_id, domain): def update_domain(self, domain_id, domain):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id) ref = self._get_domain(session, domain_id)
old_dict = ref.to_dict() old_dict = ref.to_dict()
for k in domain: for k in domain:
@ -273,7 +273,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return ref.to_dict() return ref.to_dict()
def delete_domain(self, domain_id): def delete_domain(self, domain_id):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id) ref = self._get_domain(session, domain_id)
session.delete(ref) session.delete(ref)

View File

@ -59,12 +59,12 @@ class DomainConfig(resource.DomainConfigDriverV8):
@sql.handle_conflicts(conflict_type='domain_config') @sql.handle_conflicts(conflict_type='domain_config')
def create_config_option(self, domain_id, group, option, value, def create_config_option(self, domain_id, group, option, value,
sensitive=False): sensitive=False):
with sql.transaction() as session: with sql.session_for_write() as session:
config_table = self.choose_table(sensitive) config_table = self.choose_table(sensitive)
ref = config_table(domain_id=domain_id, group=group, ref = config_table(domain_id=domain_id, group=group,
option=option, value=value) option=option, value=value)
session.add(ref) session.add(ref)
return ref.to_dict() return ref.to_dict()
def _get_config_option(self, session, domain_id, group, option, sensitive): def _get_config_option(self, session, domain_id, group, option, sensitive):
try: try:
@ -80,14 +80,14 @@ class DomainConfig(resource.DomainConfigDriverV8):
return ref return ref
def get_config_option(self, domain_id, group, option, sensitive=False): def get_config_option(self, domain_id, group, option, sensitive=False):
with sql.transaction() as session: with sql.session_for_read() as session:
ref = self._get_config_option(session, domain_id, group, option, ref = self._get_config_option(session, domain_id, group, option,
sensitive) sensitive)
return ref.to_dict() return ref.to_dict()
def list_config_options(self, domain_id, group=None, option=None, def list_config_options(self, domain_id, group=None, option=None,
sensitive=False): sensitive=False):
with sql.transaction() as session: with sql.session_for_read() as session:
config_table = self.choose_table(sensitive) config_table = self.choose_table(sensitive)
query = session.query(config_table) query = session.query(config_table)
query = query.filter_by(domain_id=domain_id) query = query.filter_by(domain_id=domain_id)
@ -99,11 +99,11 @@ class DomainConfig(resource.DomainConfigDriverV8):
def update_config_option(self, domain_id, group, option, value, def update_config_option(self, domain_id, group, option, value,
sensitive=False): sensitive=False):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = self._get_config_option(session, domain_id, group, option, ref = self._get_config_option(session, domain_id, group, option,
sensitive) sensitive)
ref.value = value ref.value = value
return ref.to_dict() return ref.to_dict()
def delete_config_options(self, domain_id, group=None, option=None, def delete_config_options(self, domain_id, group=None, option=None,
sensitive=False): sensitive=False):
@ -114,7 +114,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
if there was nothing to delete. if there was nothing to delete.
""" """
with sql.transaction() as session: with sql.session_for_write() as session:
config_table = self.choose_table(sensitive) config_table = self.choose_table(sensitive)
query = session.query(config_table) query = session.query(config_table)
query = query.filter_by(domain_id=domain_id) query = query.filter_by(domain_id=domain_id)
@ -126,7 +126,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
def obtain_registration(self, domain_id, type): def obtain_registration(self, domain_id, type):
try: try:
with sql.transaction() as session: with sql.session_for_write() as session:
ref = ConfigRegister(type=type, domain_id=domain_id) ref = ConfigRegister(type=type, domain_id=domain_id)
session.add(ref) session.add(ref)
return True return True
@ -136,15 +136,15 @@ class DomainConfig(resource.DomainConfigDriverV8):
return False return False
def read_registration(self, type): def read_registration(self, type):
with sql.transaction() as session: with sql.session_for_read() as session:
ref = session.query(ConfigRegister).get(type) ref = session.query(ConfigRegister).get(type)
if not ref: if not ref:
raise exception.ConfigRegistrationNotFound() raise exception.ConfigRegistrationNotFound()
return ref.domain_id return ref.domain_id
def release_registration(self, domain_id, type=None): def release_registration(self, domain_id, type=None):
"""Silently delete anything registered for the domain specified.""" """Silently delete anything registered for the domain specified."""
with sql.transaction() as session: with sql.session_for_write() as session:
query = session.query(ConfigRegister) query = session.query(ConfigRegister)
if type: if type:
query = query.filter_by(type=type) query = query.filter_by(type=type)

View File

@ -60,37 +60,37 @@ class Revoke(revoke.RevokeDriverV8):
def _prune_expired_events(self): def _prune_expired_events(self):
oldest = revoke.revoked_before_cutoff_time() oldest = revoke.revoked_before_cutoff_time()
session = sql.get_session() with sql.session_for_write() as session:
dialect = session.bind.dialect.name dialect = session.bind.dialect.name
batch_size = self._flush_batch_size(dialect) batch_size = self._flush_batch_size(dialect)
if batch_size > 0: if batch_size > 0:
query = session.query(RevocationEvent.id) query = session.query(RevocationEvent.id)
query = query.filter(RevocationEvent.revoked_at < oldest) query = query.filter(RevocationEvent.revoked_at < oldest)
query = query.limit(batch_size).subquery() query = query.limit(batch_size).subquery()
delete_query = (session.query(RevocationEvent). delete_query = (session.query(RevocationEvent).
filter(RevocationEvent.id.in_(query))) filter(RevocationEvent.id.in_(query)))
while True: while True:
rowcount = delete_query.delete(synchronize_session=False) rowcount = delete_query.delete(synchronize_session=False)
if rowcount == 0: if rowcount == 0:
break break
else: else:
query = session.query(RevocationEvent) query = session.query(RevocationEvent)
query = query.filter(RevocationEvent.revoked_at < oldest) query = query.filter(RevocationEvent.revoked_at < oldest)
query.delete(synchronize_session=False) query.delete(synchronize_session=False)
session.flush() session.flush()
def list_events(self, last_fetch=None): def list_events(self, last_fetch=None):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(RevocationEvent).order_by( query = session.query(RevocationEvent).order_by(
RevocationEvent.revoked_at) RevocationEvent.revoked_at)
if last_fetch: if last_fetch:
query = query.filter(RevocationEvent.revoked_at > last_fetch) query = query.filter(RevocationEvent.revoked_at > last_fetch)
events = [model.RevokeEvent(**e.to_dict()) for e in query] events = [model.RevokeEvent(**e.to_dict()) for e in query]
return events return events
def revoke(self, event): def revoke(self, event):
kwargs = dict() kwargs = dict()
@ -98,7 +98,6 @@ class Revoke(revoke.RevokeDriverV8):
kwargs[attr] = getattr(event, attr) kwargs[attr] = getattr(event, attr)
kwargs['id'] = uuid.uuid4().hex kwargs['id'] = uuid.uuid4().hex
record = RevocationEvent(**kwargs) record = RevocationEvent(**kwargs)
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
session.add(record) session.add(record)
self._prune_expired_events() self._prune_expired_events()

View File

@ -17,6 +17,6 @@ from keystone.identity.mapping_backends import sql as mapping_sql
def list_id_mappings(): def list_id_mappings():
"""List all id_mappings for testing purposes.""" """List all id_mappings for testing purposes."""
a_session = sql.get_session() with sql.session_for_read() as session:
refs = a_session.query(mapping_sql.IDMapping).all() refs = session.query(mapping_sql.IDMapping).all()
return [x.to_dict() for x in refs] return [x.to_dict() for x in refs]

View File

@ -611,7 +611,7 @@ class SqlToken(SqlTests, test_backend.TokenTests):
tok = token_sql.Token() tok = token_sql.Token()
tok.list_revoked_tokens() tok.list_revoked_tokens()
mock_query = mock_sql.get_session().query mock_query = mock_sql.session_for_read().__enter__().query
mock_query.assert_called_with(*expected_query_args) mock_query.assert_called_with(*expected_query_args)
def test_flush_expired_tokens_batch(self): def test_flush_expired_tokens_batch(self):
@ -636,8 +636,12 @@ class SqlToken(SqlTests, test_backend.TokenTests):
# other tests below test the differences between how they use the batch # other tests below test the differences between how they use the batch
# strategy # strategy
with mock.patch.object(token_sql, 'sql') as mock_sql: with mock.patch.object(token_sql, 'sql') as mock_sql:
mock_sql.get_session().query().filter().delete.return_value = 0 mock_sql.session_for_write().__enter__(
mock_sql.get_session().bind.dialect.name = 'mysql' ).query().filter().delete.return_value = 0
mock_sql.session_for_write().__enter__(
).bind.dialect.name = 'mysql'
tok = token_sql.Token() tok = token_sql.Token()
expiry_mock = mock.Mock() expiry_mock = mock.Mock()
ITERS = [1, 2, 3] ITERS = [1, 2, 3]
@ -648,7 +652,10 @@ class SqlToken(SqlTests, test_backend.TokenTests):
# The expiry strategy is only invoked once, the other calls are via # The expiry strategy is only invoked once, the other calls are via
# the yield return. # the yield return.
self.assertEqual(1, expiry_mock.call_count) self.assertEqual(1, expiry_mock.call_count)
mock_delete = mock_sql.get_session().query().filter().delete
mock_delete = mock_sql.session_for_write().__enter__(
).query().filter().delete
self.assertThat(mock_delete.call_args_list, self.assertThat(mock_delete.call_args_list,
matchers.HasLength(len(ITERS))) matchers.HasLength(len(ITERS)))

View File

@ -161,7 +161,8 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase):
self.repo_package()) self.repo_package())
self.schema = versioning_api.ControlledSchema.create( self.schema = versioning_api.ControlledSchema.create(
self.engine, self.engine,
self.repo_path, self.initial_db_version) self.repo_path,
self.initial_db_version)
# auto-detect the highest available schema version in the migrate_repo # auto-detect the highest available schema version in the migrate_repo
self.max_version = self.schema.repository.version().version self.max_version = self.schema.repository.version().version

View File

@ -86,11 +86,11 @@ class Token(token.persistence.TokenDriverV8):
def get_token(self, token_id): def get_token(self, token_id):
if token_id is None: if token_id is None:
raise exception.TokenNotFound(token_id=token_id) raise exception.TokenNotFound(token_id=token_id)
session = sql.get_session() with sql.session_for_read() as session:
token_ref = session.query(TokenModel).get(token_id) token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid: if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id) raise exception.TokenNotFound(token_id=token_id)
return token_ref.to_dict() return token_ref.to_dict()
def create_token(self, token_id, data): def create_token(self, token_id, data):
data_copy = copy.deepcopy(data) data_copy = copy.deepcopy(data)
@ -101,14 +101,12 @@ class Token(token.persistence.TokenDriverV8):
token_ref = TokenModel.from_dict(data_copy) token_ref = TokenModel.from_dict(data_copy)
token_ref.valid = True token_ref.valid = True
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
session.add(token_ref) session.add(token_ref)
return token_ref.to_dict() return token_ref.to_dict()
def delete_token(self, token_id): def delete_token(self, token_id):
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
token_ref = session.query(TokenModel).get(token_id) token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid: if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id) raise exception.TokenNotFound(token_id=token_id)
@ -124,9 +122,8 @@ class Token(token.persistence.TokenDriverV8):
or the trustor's user ID, so will use trust_id to query the tokens. or the trustor's user ID, so will use trust_id to query the tokens.
""" """
session = sql.get_session()
token_list = [] token_list = []
with session.begin(): with sql.session_for_write() as session:
now = timeutils.utcnow() now = timeutils.utcnow()
query = session.query(TokenModel) query = session.query(TokenModel)
query = query.filter_by(valid=True) query = query.filter_by(valid=True)
@ -167,38 +164,37 @@ class Token(token.persistence.TokenDriverV8):
return False return False
def _list_tokens_for_trust(self, trust_id): def _list_tokens_for_trust(self, trust_id):
session = sql.get_session() with sql.session_for_read() as session:
tokens = [] tokens = []
now = timeutils.utcnow() now = timeutils.utcnow()
query = session.query(TokenModel) query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now) query = query.filter(TokenModel.expires > now)
query = query.filter(TokenModel.trust_id == trust_id) query = query.filter(TokenModel.trust_id == trust_id)
token_references = query.filter_by(valid=True) token_references = query.filter_by(valid=True)
for token_ref in token_references: for token_ref in token_references:
token_ref_dict = token_ref.to_dict() token_ref_dict = token_ref.to_dict()
tokens.append(token_ref_dict['id']) tokens.append(token_ref_dict['id'])
return tokens return tokens
def _list_tokens_for_user(self, user_id, tenant_id=None): def _list_tokens_for_user(self, user_id, tenant_id=None):
session = sql.get_session() with sql.session_for_read() as session:
tokens = [] tokens = []
now = timeutils.utcnow() now = timeutils.utcnow()
query = session.query(TokenModel) query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now) query = query.filter(TokenModel.expires > now)
query = query.filter(TokenModel.user_id == user_id) query = query.filter(TokenModel.user_id == user_id)
token_references = query.filter_by(valid=True) token_references = query.filter_by(valid=True)
for token_ref in token_references: for token_ref in token_references:
token_ref_dict = token_ref.to_dict() token_ref_dict = token_ref.to_dict()
if self._tenant_matches(tenant_id, token_ref_dict): if self._tenant_matches(tenant_id, token_ref_dict):
tokens.append(token_ref['id']) tokens.append(token_ref['id'])
return tokens return tokens
def _list_tokens_for_consumer(self, user_id, consumer_id): def _list_tokens_for_consumer(self, user_id, consumer_id):
tokens = [] tokens = []
session = sql.get_session() with sql.session_for_write() as session:
with session.begin():
now = timeutils.utcnow() now = timeutils.utcnow()
query = session.query(TokenModel) query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now) query = query.filter(TokenModel.expires > now)
@ -223,29 +219,29 @@ class Token(token.persistence.TokenDriverV8):
return self._list_tokens_for_user(user_id, tenant_id) return self._list_tokens_for_user(user_id, tenant_id)
def list_revoked_tokens(self): def list_revoked_tokens(self):
session = sql.get_session() with sql.session_for_read() as session:
tokens = [] tokens = []
now = timeutils.utcnow() now = timeutils.utcnow()
query = session.query(TokenModel.id, TokenModel.expires, query = session.query(TokenModel.id, TokenModel.expires,
TokenModel.extra) TokenModel.extra)
query = query.filter(TokenModel.expires > now) query = query.filter(TokenModel.expires > now)
token_references = query.filter_by(valid=False) token_references = query.filter_by(valid=False)
for token_ref in token_references: for token_ref in token_references:
token_data = token_ref[2]['token_data'] token_data = token_ref[2]['token_data']
if 'access' in token_data: if 'access' in token_data:
# It's a v2 token. # It's a v2 token.
audit_ids = token_data['access']['token']['audit_ids'] audit_ids = token_data['access']['token']['audit_ids']
else: else:
# It's a v3 token. # It's a v3 token.
audit_ids = token_data['token']['audit_ids'] audit_ids = token_data['token']['audit_ids']
record = { record = {
'id': token_ref[0], 'id': token_ref[0],
'expires': token_ref[1], 'expires': token_ref[1],
'audit_id': audit_ids[0], 'audit_id': audit_ids[0],
} }
tokens.append(record) tokens.append(record)
return tokens return tokens
def _expiry_range_strategy(self, dialect): def _expiry_range_strategy(self, dialect):
"""Choose a token range expiration strategy """Choose a token range expiration strategy
@ -273,18 +269,18 @@ class Token(token.persistence.TokenDriverV8):
return _expiry_range_all return _expiry_range_all
def flush_expired_tokens(self): def flush_expired_tokens(self):
session = sql.get_session() with sql.session_for_write() as session:
dialect = session.bind.dialect.name dialect = session.bind.dialect.name
expiry_range_func = self._expiry_range_strategy(dialect) expiry_range_func = self._expiry_range_strategy(dialect)
query = session.query(TokenModel.expires) query = session.query(TokenModel.expires)
total_removed = 0 total_removed = 0
upper_bound_func = timeutils.utcnow upper_bound_func = timeutils.utcnow
for expiry_time in expiry_range_func(session, upper_bound_func): for expiry_time in expiry_range_func(session, upper_bound_func):
delete_query = query.filter(TokenModel.expires <= delete_query = query.filter(TokenModel.expires <=
expiry_time) expiry_time)
row_count = delete_query.delete(synchronize_session=False) row_count = delete_query.delete(synchronize_session=False)
total_removed += row_count total_removed += row_count
LOG.debug('Removed %d total expired tokens', total_removed) LOG.debug('Removed %d total expired tokens', total_removed)
session.flush() session.flush()
LOG.info(_LI('Total expired tokens removed: %d'), total_removed) LOG.info(_LI('Total expired tokens removed: %d'), total_removed)

View File

@ -59,7 +59,7 @@ class TrustRole(sql.ModelBase):
class Trust(trust.TrustDriverV8): class Trust(trust.TrustDriverV8):
@sql.handle_conflicts(conflict_type='trust') @sql.handle_conflicts(conflict_type='trust')
def create_trust(self, trust_id, trust, roles): def create_trust(self, trust_id, trust, roles):
with sql.transaction() as session: with sql.session_for_write() as session:
ref = TrustModel.from_dict(trust) ref = TrustModel.from_dict(trust)
ref['id'] = trust_id ref['id'] = trust_id
if ref.get('expires_at') and ref['expires_at'].tzinfo is not None: if ref.get('expires_at') and ref['expires_at'].tzinfo is not None:
@ -72,9 +72,9 @@ class Trust(trust.TrustDriverV8):
trust_role.role_id = role['id'] trust_role.role_id = role['id']
added_roles.append({'id': role['id']}) added_roles.append({'id': role['id']})
session.add(trust_role) session.add(trust_role)
trust_dict = ref.to_dict() trust_dict = ref.to_dict()
trust_dict['roles'] = added_roles trust_dict['roles'] = added_roles
return trust_dict return trust_dict
def _add_roles(self, trust_id, session, trust_dict): def _add_roles(self, trust_id, session, trust_dict):
roles = [] roles = []
@ -86,7 +86,7 @@ class Trust(trust.TrustDriverV8):
def consume_use(self, trust_id): def consume_use(self, trust_id):
for attempt in range(MAXIMUM_CONSUME_ATTEMPTS): for attempt in range(MAXIMUM_CONSUME_ATTEMPTS):
with sql.transaction() as session: with sql.session_for_write() as session:
try: try:
query_result = (session.query(TrustModel.remaining_uses). query_result = (session.query(TrustModel.remaining_uses).
filter_by(id=trust_id). filter_by(id=trust_id).
@ -132,51 +132,51 @@ class Trust(trust.TrustDriverV8):
raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id) raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id)
def get_trust(self, trust_id, deleted=False): def get_trust(self, trust_id, deleted=False):
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(TrustModel).filter_by(id=trust_id) query = session.query(TrustModel).filter_by(id=trust_id)
if not deleted: if not deleted:
query = query.filter_by(deleted_at=None) query = query.filter_by(deleted_at=None)
ref = query.first() ref = query.first()
if ref is None: if ref is None:
raise exception.TrustNotFound(trust_id=trust_id)
if ref.expires_at is not None and not deleted:
now = timeutils.utcnow()
if now > ref.expires_at:
raise exception.TrustNotFound(trust_id=trust_id) raise exception.TrustNotFound(trust_id=trust_id)
# Do not return trusts that can't be used anymore if ref.expires_at is not None and not deleted:
if ref.remaining_uses is not None and not deleted: now = timeutils.utcnow()
if ref.remaining_uses <= 0: if now > ref.expires_at:
raise exception.TrustNotFound(trust_id=trust_id) raise exception.TrustNotFound(trust_id=trust_id)
trust_dict = ref.to_dict() # Do not return trusts that can't be used anymore
if ref.remaining_uses is not None and not deleted:
if ref.remaining_uses <= 0:
raise exception.TrustNotFound(trust_id=trust_id)
trust_dict = ref.to_dict()
self._add_roles(trust_id, session, trust_dict) self._add_roles(trust_id, session, trust_dict)
return trust_dict return trust_dict
@sql.handle_conflicts(conflict_type='trust') @sql.handle_conflicts(conflict_type='trust')
def list_trusts(self): def list_trusts(self):
session = sql.get_session() with sql.session_for_read() as session:
trusts = session.query(TrustModel).filter_by(deleted_at=None) trusts = session.query(TrustModel).filter_by(deleted_at=None)
return [trust_ref.to_dict() for trust_ref in trusts] return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust') @sql.handle_conflicts(conflict_type='trust')
def list_trusts_for_trustee(self, trustee_user_id): def list_trusts_for_trustee(self, trustee_user_id):
session = sql.get_session() with sql.session_for_read() as session:
trusts = (session.query(TrustModel). trusts = (session.query(TrustModel).
filter_by(deleted_at=None). filter_by(deleted_at=None).
filter_by(trustee_user_id=trustee_user_id)) filter_by(trustee_user_id=trustee_user_id))
return [trust_ref.to_dict() for trust_ref in trusts] return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust') @sql.handle_conflicts(conflict_type='trust')
def list_trusts_for_trustor(self, trustor_user_id): def list_trusts_for_trustor(self, trustor_user_id):
session = sql.get_session() with sql.session_for_read() as session:
trusts = (session.query(TrustModel). trusts = (session.query(TrustModel).
filter_by(deleted_at=None). filter_by(deleted_at=None).
filter_by(trustor_user_id=trustor_user_id)) filter_by(trustor_user_id=trustor_user_id))
return [trust_ref.to_dict() for trust_ref in trusts] return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust') @sql.handle_conflicts(conflict_type='trust')
def delete_trust(self, trust_id): def delete_trust(self, trust_id):
with sql.transaction() as session: with sql.session_for_write() as session:
trust_ref = session.query(TrustModel).get(trust_id) trust_ref = session.query(TrustModel).get(trust_id)
if not trust_ref: if not trust_ref:
raise exception.TrustNotFound(trust_id=trust_id) raise exception.TrustNotFound(trust_id=trust_id)