Use the new enginefacade from oslo.db

EngineFacade is deprecated. This partially switches keystone to
use oslo.db.sqlalchemy.enginefacade. 'get_session' and 'get_engine'
methods are still used in sql migrations and related tests.

Change-Id: I221232d50821fe2adb9881f237f06714003ce79d
Partial-Bug: #1490571
This commit is contained in:
Grzegorz Grasza 2015-12-14 17:07:46 +01:00 committed by Morgan Fainberg
parent e943768088
commit 0e156737d0
24 changed files with 700 additions and 745 deletions

View File

@ -56,7 +56,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
return 'sql'
def list_user_ids_for_project(self, tenant_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.actor_id)
query = query.filter_by(type=AssignmentType.USER_PROJECT)
query = query.filter_by(target_id=tenant_id)
@ -71,7 +71,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
assignment_type = AssignmentType.calculate_type(
user_id, group_id, project_id, domain_id)
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=assignment_type,
actor_id=user_id or group_id,
@ -85,7 +85,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def list_grant_role_ids(self, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
q = session.query(RoleAssignment.role_id)
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
try:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
@ -120,7 +120,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def delete_grant(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
inherited_to_projects)
@ -145,11 +145,11 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == inherited,
RoleAssignment.actor_id.in_(actors))
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id).filter(
sql_constraints).distinct()
return [x.target_id for x in query.all()]
return [x.target_id for x in query.all()]
def list_project_ids_for_user(self, user_id, group_ids, hints,
inherited=False):
@ -161,7 +161,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def list_domain_ids_for_user(self, user_id, group_ids, hints,
inherited=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id)
filters = []
@ -197,10 +197,10 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == false(),
RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.role_id).filter(
sql_constraints).distinct()
return [role.role_id for role in query.all()]
return [role.role_id for role in query.all()]
def list_role_ids_for_groups_on_project(
self, group_ids, project_id, project_domain_id, project_parents):
@ -237,13 +237,13 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
sql_constraints = sqlalchemy.and_(
sql_constraints, RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session:
with sql.session_for_read() as session:
# NOTE(morganfainberg): Only select the columns we actually care
# about here, in this case role_id.
query = session.query(RoleAssignment.role_id).filter(
sql_constraints).distinct()
return [result.role_id for result in query.all()]
return [result.role_id for result in query.all()]
def list_project_ids_for_groups(self, group_ids, hints,
inherited=False):
@ -260,14 +260,14 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == inherited,
RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id).filter(
group_sql_conditions).distinct()
return [x.target_id for x in query.all()]
return [x.target_id for x in query.all()]
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=AssignmentType.USER_PROJECT,
actor_id=user_id, target_id=tenant_id,
@ -278,7 +278,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
raise exception.Conflict(type='role grant', details=msg)
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id)
q = q.filter_by(target_id=tenant_id)
@ -368,7 +368,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
assignment['inherited_to_projects'] = 'projects'
return assignment
with sql.transaction() as session:
with sql.session_for_read() as session:
assignment_types = self._get_assignment_types(
user_id, group_ids, project_ids, domain_id)
@ -400,25 +400,25 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
return [denormalize_role(ref) for ref in query.all()]
def delete_project_assignments(self, project_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(target_id=project_id)
q.delete(False)
def delete_role_assignments(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(role_id=role_id)
q.delete(False)
def delete_user_assignments(self, user_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id)
q.delete(False)
def delete_group_assignments(self, group_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=group_id)
q.delete(False)

View File

@ -19,14 +19,14 @@ class Role(assignment.RoleDriverV8):
@sql.handle_conflicts(conflict_type='role')
def create_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = RoleTable.from_dict(role)
session.add(ref)
return ref.to_dict()
@sql.truncated
def list_roles(self, hints):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
refs = sql.filter_limit_query(RoleTable, query, hints)
return [ref.to_dict() for ref in refs]
@ -35,7 +35,7 @@ class Role(assignment.RoleDriverV8):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
query = query.filter(RoleTable.id.in_(ids))
role_refs = query.all()
@ -48,12 +48,12 @@ class Role(assignment.RoleDriverV8):
return ref
def get_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_role(session, role_id).to_dict()
@sql.handle_conflicts(conflict_type='role')
def update_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
old_dict = ref.to_dict()
for k in role:
@ -66,7 +66,7 @@ class Role(assignment.RoleDriverV8):
return ref.to_dict()
def delete_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
session.delete(ref)

View File

@ -55,7 +55,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
assignment_type = AssignmentType.calculate_type(
user_id, group_id, project_id, domain_id)
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=assignment_type,
actor_id=user_id or group_id,
@ -69,7 +69,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def list_grant_role_ids(self, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
q = session.query(RoleAssignment.role_id)
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
@ -88,7 +88,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
try:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def delete_grant(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
inherited_to_projects)
@ -117,7 +117,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=AssignmentType.USER_PROJECT,
actor_id=user_id, target_id=tenant_id,
@ -128,7 +128,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
raise exception.Conflict(type='role grant', details=msg)
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id)
q = q.filter_by(target_id=tenant_id)
@ -218,7 +218,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
assignment['inherited_to_projects'] = 'projects'
return assignment
with sql.transaction() as session:
with sql.session_for_read() as session:
assignment_types = self._get_assignment_types(
user_id, group_ids, project_ids, domain_id)
@ -250,7 +250,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
return [denormalize_role(ref) for ref in query.all()]
def delete_project_assignments(self, project_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(target_id=project_id).filter(
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
@ -259,13 +259,13 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
q.delete(False)
def delete_role_assignments(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(role_id=role_id)
q.delete(False)
def delete_user_assignments(self, user_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id).filter(
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
@ -274,7 +274,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
q.delete(False)
def delete_group_assignments(self, group_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=group_id).filter(
RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT,

View File

@ -29,7 +29,7 @@ class Role(assignment.RoleDriverV9):
@sql.handle_conflicts(conflict_type='role')
def create_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = RoleTable.from_dict(role)
session.add(ref)
return ref.to_dict()
@ -46,7 +46,7 @@ class Role(assignment.RoleDriverV9):
if (f['name'] == 'domain_id' and f['value'] is None):
f['value'] = NULL_DOMAIN_ID
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
refs = sql.filter_limit_query(RoleTable, query, hints)
return [ref.to_dict() for ref in refs]
@ -55,7 +55,7 @@ class Role(assignment.RoleDriverV9):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
query = query.filter(RoleTable.id.in_(ids))
role_refs = query.all()
@ -68,12 +68,12 @@ class Role(assignment.RoleDriverV9):
return ref
def get_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_role(session, role_id).to_dict()
@sql.handle_conflicts(conflict_type='role')
def update_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
old_dict = ref.to_dict()
for k in role:
@ -86,7 +86,7 @@ class Role(assignment.RoleDriverV9):
return ref.to_dict()
def delete_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
session.delete(ref)
@ -105,7 +105,7 @@ class Role(assignment.RoleDriverV9):
@sql.handle_conflicts(conflict_type='implied_role')
def create_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
inference = {'prior_role_id': prior_role_id,
'implied_role_id': implied_role_id}
ref = ImpliedRoleTable.from_dict(inference)
@ -119,13 +119,13 @@ class Role(assignment.RoleDriverV9):
return ref.to_dict()
def delete_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_implied_role(session, prior_role_id,
implied_role_id)
session.delete(ref)
def list_implied_roles(self, prior_role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(
ImpliedRoleTable).filter(
ImpliedRoleTable.prior_role_id == prior_role_id)
@ -133,13 +133,13 @@ class Role(assignment.RoleDriverV9):
return [ref.to_dict() for ref in refs]
def list_role_inference_rules(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(ImpliedRoleTable)
refs = query.all()
return [ref.to_dict() for ref in refs]
def get_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = self._get_implied_role(session, prior_role_id,
implied_role_id)
return ref.to_dict()

View File

@ -84,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase):
class Catalog(catalog.CatalogDriverV8):
# Regions
def list_regions(self, hints):
session = sql.get_session()
regions = session.query(Region)
regions = sql.filter_limit_query(Region, regions, hints)
return [s.to_dict() for s in list(regions)]
with sql.session_for_read() as session:
regions = session.query(Region)
regions = sql.filter_limit_query(Region, regions, hints)
return [s.to_dict() for s in list(regions)]
def _get_region(self, session, region_id):
ref = session.query(Region).get(region_id)
@ -136,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8):
return False
def get_region(self, region_id):
session = sql.get_session()
return self._get_region(session, region_id).to_dict()
with sql.session_for_read() as session:
return self._get_region(session, region_id).to_dict()
def delete_region(self, region_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_region(session, region_id)
if self._has_endpoints(session, ref, ref):
raise exception.RegionDeletionError(region_id=region_id)
@ -150,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8):
@sql.handle_conflicts(conflict_type='region')
def create_region(self, region_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
self._check_parent_region(session, region_ref)
region = Region.from_dict(region_ref)
session.add(region)
return region.to_dict()
return region.to_dict()
def update_region(self, region_id, region_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
self._check_parent_region(session, region_ref)
ref = self._get_region(session, region_id)
old_dict = ref.to_dict()
@ -169,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8):
for attr in Region.attributes:
if attr != 'id':
setattr(ref, attr, getattr(new_region, attr))
return ref.to_dict()
return ref.to_dict()
# Services
@driver_hints.truncated
def list_services(self, hints):
session = sql.get_session()
services = session.query(Service)
services = sql.filter_limit_query(Service, services, hints)
return [s.to_dict() for s in list(services)]
with sql.session_for_read() as session:
services = session.query(Service)
services = sql.filter_limit_query(Service, services, hints)
return [s.to_dict() for s in list(services)]
def _get_service(self, session, service_id):
ref = session.query(Service).get(service_id)
@ -186,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8):
return ref
def get_service(self, service_id):
session = sql.get_session()
return self._get_service(session, service_id).to_dict()
with sql.session_for_read() as session:
return self._get_service(session, service_id).to_dict()
def delete_service(self, service_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_service(session, service_id)
session.query(Endpoint).filter_by(service_id=service_id).delete()
session.delete(ref)
def create_service(self, service_id, service_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
service = Service.from_dict(service_ref)
session.add(service)
return service.to_dict()
return service.to_dict()
def update_service(self, service_id, service_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_service(session, service_id)
old_dict = ref.to_dict()
old_dict.update(service_ref)
@ -214,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_service, attr))
ref.extra = new_service.extra
return ref.to_dict()
return ref.to_dict()
# Endpoints
def create_endpoint(self, endpoint_id, endpoint_ref):
session = sql.get_session()
new_endpoint = Endpoint.from_dict(endpoint_ref)
with session.begin():
with sql.session_for_write() as session:
session.add(new_endpoint)
return new_endpoint.to_dict()
def delete_endpoint(self, endpoint_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_endpoint(session, endpoint_id)
session.delete(ref)
@ -238,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8):
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
def get_endpoint(self, endpoint_id):
session = sql.get_session()
return self._get_endpoint(session, endpoint_id).to_dict()
with sql.session_for_read() as session:
return self._get_endpoint(session, endpoint_id).to_dict()
@driver_hints.truncated
def list_endpoints(self, hints):
session = sql.get_session()
endpoints = session.query(Endpoint)
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
return [e.to_dict() for e in list(endpoints)]
with sql.session_for_read() as session:
endpoints = session.query(Endpoint)
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
return [e.to_dict() for e in list(endpoints)]
def update_endpoint(self, endpoint_id, endpoint_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_endpoint(session, endpoint_id)
old_dict = ref.to_dict()
old_dict.update(endpoint_ref)
@ -260,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_endpoint, attr))
ref.extra = new_endpoint.extra
return ref.to_dict()
return ref.to_dict()
def get_catalog(self, user_id, tenant_id):
"""Retrieve and format the V2 service catalog.
@ -289,40 +278,40 @@ class Catalog(catalog.CatalogDriverV8):
else:
silent_keyerror_failures = ['tenant_id', 'project_id', ]
session = sql.get_session()
endpoints = (session.query(Endpoint).
options(sql.joinedload(Endpoint.service)).
filter(Endpoint.enabled == true()).all())
with sql.session_for_read() as session:
endpoints = (session.query(Endpoint).
options(sql.joinedload(Endpoint.service)).
filter(Endpoint.enabled == true()).all())
catalog = {}
catalog = {}
for endpoint in endpoints:
if not endpoint.service['enabled']:
continue
try:
formatted_url = core.format_url(
endpoint['url'], substitutions,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url is not None:
url = formatted_url
else:
for endpoint in endpoints:
if not endpoint.service['enabled']:
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
try:
formatted_url = core.format_url(
endpoint['url'], substitutions,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url is not None:
url = formatted_url
else:
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
region = endpoint['region_id']
service_type = endpoint.service['type']
default_service = {
'id': endpoint['id'],
'name': endpoint.service.extra.get('name', ''),
'publicURL': ''
}
catalog.setdefault(region, {})
catalog[region].setdefault(service_type, default_service)
interface_url = '%sURL' % endpoint['interface']
catalog[region][service_type][interface_url] = url
region = endpoint['region_id']
service_type = endpoint.service['type']
default_service = {
'id': endpoint['id'],
'name': endpoint.service.extra.get('name', ''),
'publicURL': ''
}
catalog.setdefault(region, {})
catalog[region].setdefault(service_type, default_service)
interface_url = '%sURL' % endpoint['interface']
catalog[region][service_type][interface_url] = url
return catalog
return catalog
def get_v3_catalog(self, user_id, tenant_id):
"""Retrieve and format the current V3 service catalog.
@ -349,44 +338,46 @@ class Catalog(catalog.CatalogDriverV8):
else:
silent_keyerror_failures = ['tenant_id', 'project_id', ]
session = sql.get_session()
services = (session.query(Service).filter(Service.enabled == true()).
options(sql.joinedload(Service.endpoints)).
all())
with sql.session_for_read() as session:
services = (session.query(Service).filter(
Service.enabled == true()).options(
sql.joinedload(Service.endpoints)).all())
def make_v3_endpoints(endpoints):
for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled):
del endpoint['service_id']
del endpoint['legacy_endpoint_id']
del endpoint['enabled']
endpoint['region'] = endpoint['region_id']
try:
formatted_url = core.format_url(
endpoint['url'], d,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url:
endpoint['url'] = formatted_url
else:
def make_v3_endpoints(endpoints):
for endpoint in (ep.to_dict()
for ep in endpoints if ep.enabled):
del endpoint['service_id']
del endpoint['legacy_endpoint_id']
del endpoint['enabled']
endpoint['region'] = endpoint['region_id']
try:
formatted_url = core.format_url(
endpoint['url'], d,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url:
endpoint['url'] = formatted_url
else:
continue
except exception.MalformedEndpoint:
# this failure is already logged in format_url()
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
yield endpoint
yield endpoint
# TODO(davechen): If there is service with no endpoints, we should skip
# the service instead of keeping it in the catalog, see bug #1436704.
def make_v3_service(svc):
eps = list(make_v3_endpoints(svc.endpoints))
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
service['name'] = svc.extra.get('name', '')
return service
# TODO(davechen): If there is service with no endpoints, we should
# skip the service instead of keeping it in the catalog,
# see bug #1436704.
def make_v3_service(svc):
eps = list(make_v3_endpoints(svc.endpoints))
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
service['name'] = svc.extra.get('name', '')
return service
return [make_v3_service(svc) for svc in services]
return [make_v3_service(svc) for svc in services]
@sql.handle_conflicts(conflict_type='project_endpoint')
def add_endpoint_to_project(self, endpoint_id, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
project_id=project_id)
session.add(endpoint_filter_ref)
@ -402,50 +393,46 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_filter_ref
def check_endpoint_in_project(self, endpoint_id, project_id):
session = sql.get_session()
self._get_project_endpoint_ref(session, endpoint_id, project_id)
with sql.session_for_read() as session:
self._get_project_endpoint_ref(session, endpoint_id, project_id)
def remove_endpoint_from_project(self, endpoint_id, project_id):
session = sql.get_session()
endpoint_filter_ref = self._get_project_endpoint_ref(
session, endpoint_id, project_id)
with session.begin():
with sql.session_for_write() as session:
endpoint_filter_ref = self._get_project_endpoint_ref(
session, endpoint_id, project_id)
session.delete(endpoint_filter_ref)
def list_endpoints_for_project(self, project_id):
session = sql.get_session()
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
def list_projects_for_endpoint(self, endpoint_id):
session = sql.get_session()
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
def delete_association_by_endpoint(self, endpoint_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
query.delete(synchronize_session=False)
def delete_association_by_project(self, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
query.delete(synchronize_session=False)
def create_endpoint_group(self, endpoint_group_id, endpoint_group):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
endpoint_group_ref = EndpointGroup.from_dict(endpoint_group)
session.add(endpoint_group_ref)
return endpoint_group_ref.to_dict()
return endpoint_group_ref.to_dict()
def _get_endpoint_group(self, session, endpoint_group_id):
endpoint_group_ref = session.query(EndpointGroup).get(
@ -456,14 +443,13 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_group_ref
def get_endpoint_group(self, endpoint_group_id):
session = sql.get_session()
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
return endpoint_group_ref.to_dict()
with sql.session_for_read() as session:
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
return endpoint_group_ref.to_dict()
def update_endpoint_group(self, endpoint_group_id, endpoint_group):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
old_endpoint_group = endpoint_group_ref.to_dict()
@ -472,29 +458,26 @@ class Catalog(catalog.CatalogDriverV8):
for attr in EndpointGroup.mutable_attributes:
setattr(endpoint_group_ref, attr,
getattr(new_endpoint_group, attr))
return endpoint_group_ref.to_dict()
return endpoint_group_ref.to_dict()
def delete_endpoint_group(self, endpoint_group_id):
session = sql.get_session()
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
with session.begin():
with sql.session_for_write() as session:
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
self._delete_endpoint_group_association_by_endpoint_group(
session, endpoint_group_id)
session.delete(endpoint_group_ref)
def get_endpoint_group_in_project(self, endpoint_group_id, project_id):
session = sql.get_session()
ref = self._get_endpoint_group_in_project(session,
endpoint_group_id,
project_id)
return ref.to_dict()
with sql.session_for_read() as session:
ref = self._get_endpoint_group_in_project(session,
endpoint_group_id,
project_id)
return ref.to_dict()
@sql.handle_conflicts(conflict_type='project_endpoint_group')
def add_endpoint_group_to_project(self, endpoint_group_id, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
# Create a new Project Endpoint group entity
endpoint_group_project_ref = ProjectEndpointGroupMembership(
endpoint_group_id=endpoint_group_id, project_id=project_id)
@ -512,32 +495,31 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_group_project_ref
def list_endpoint_groups(self):
session = sql.get_session()
query = session.query(EndpointGroup)
endpoint_group_refs = query.all()
return [e.to_dict() for e in endpoint_group_refs]
with sql.session_for_read() as session:
query = session.query(EndpointGroup)
endpoint_group_refs = query.all()
return [e.to_dict() for e in endpoint_group_refs]
def list_endpoint_groups_for_project(self, project_id):
session = sql.get_session()
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
def remove_endpoint_group_from_project(self, endpoint_group_id,
project_id):
session = sql.get_session()
endpoint_group_project_ref = self._get_endpoint_group_in_project(
session, endpoint_group_id, project_id)
with session.begin():
with sql.session_for_write() as session:
endpoint_group_project_ref = self._get_endpoint_group_in_project(
session, endpoint_group_id, project_id)
session.delete(endpoint_group_project_ref)
def list_projects_associated_with_endpoint_group(self, endpoint_group_id):
session = sql.get_session()
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(endpoint_group_id=endpoint_group_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(endpoint_group_id=endpoint_group_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
def _delete_endpoint_group_association_by_endpoint_group(
self, session, endpoint_group_id):
@ -546,8 +528,7 @@ class Catalog(catalog.CatalogDriverV8):
query.delete()
def delete_endpoint_group_association_by_project(self, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id)
query.delete()

View File

@ -18,14 +18,14 @@ Before using this module, call initialize(). This has to be done before
CONF() because it sets up configuration options.
"""
import contextlib
import functools
import threading
from oslo_config import cfg
from oslo_db import exception as db_exception
from oslo_db import options as db_options
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import models
from oslo_db.sqlalchemy import session as db_session
from oslo_log import log
from oslo_serialization import jsonutils
import six
@ -166,38 +166,41 @@ class ModelDictMixin(object):
return {name: getattr(self, name) for name in names}
_engine_facade = None
_main_context_manager = None
def _get_engine_facade():
global _engine_facade
def _get_main_context_manager():
global _main_context_manager
if not _engine_facade:
_engine_facade = db_session.EngineFacade.from_config(CONF)
if not _main_context_manager:
_main_context_manager = enginefacade.transaction_context()
return _engine_facade
return _main_context_manager
def cleanup():
global _engine_facade
global _main_context_manager
_engine_facade = None
_main_context_manager = None
def get_engine():
return _get_engine_facade().get_engine()
return _get_main_context_manager().get_legacy_facade().get_engine()
def get_session(expire_on_commit=False):
return _get_engine_facade().get_session(expire_on_commit=expire_on_commit)
def get_session():
return _get_main_context_manager().get_legacy_facade().get_session()
@contextlib.contextmanager
def transaction(expire_on_commit=False):
"""Return a SQLAlchemy session in a scoped transaction."""
session = get_session(expire_on_commit=expire_on_commit)
with session.begin():
yield session
_CONTEXT = threading.local()
def session_for_read():
return _get_main_context_manager().reader.using(_CONTEXT)
def session_for_write():
return _get_main_context_manager().writer.using(_CONTEXT)
def truncated(f):

View File

@ -178,7 +178,7 @@ def _sync_extension_repo(extension, version):
try:
abs_path = find_migrate_repo(package)
try:
migration.db_version_control(sql.get_engine(), abs_path)
migration.db_version_control(engine, abs_path)
# Register the repo with the version control API
# If it already knows about the repo, it will throw
# an exception that we can safely ignore

View File

@ -36,28 +36,27 @@ class Credential(credential.CredentialDriverV8):
@sql.handle_conflicts(conflict_type='credential')
def create_credential(self, credential_id, credential):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = CredentialModel.from_dict(credential)
session.add(ref)
return ref.to_dict()
return ref.to_dict()
@driver_hints.truncated
def list_credentials(self, hints):
session = sql.get_session()
credentials = session.query(CredentialModel)
credentials = sql.filter_limit_query(CredentialModel,
credentials, hints)
return [s.to_dict() for s in credentials]
with sql.session_for_read() as session:
credentials = session.query(CredentialModel)
credentials = sql.filter_limit_query(CredentialModel,
credentials, hints)
return [s.to_dict() for s in credentials]
def list_credentials_for_user(self, user_id, type=None):
session = sql.get_session()
query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id)
if type:
query = query.filter_by(type=type)
refs = query.all()
return [ref.to_dict() for ref in refs]
with sql.session_for_read() as session:
query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id)
if type:
query = query.filter_by(type=type)
refs = query.all()
return [ref.to_dict() for ref in refs]
def _get_credential(self, session, credential_id):
ref = session.query(CredentialModel).get(credential_id)
@ -66,13 +65,12 @@ class Credential(credential.CredentialDriverV8):
return ref
def get_credential(self, credential_id):
session = sql.get_session()
return self._get_credential(session, credential_id).to_dict()
with sql.session_for_read() as session:
return self._get_credential(session, credential_id).to_dict()
@sql.handle_conflicts(conflict_type='credential')
def update_credential(self, credential_id, credential):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_credential(session, credential_id)
old_dict = ref.to_dict()
for k in credential:
@ -82,27 +80,21 @@ class Credential(credential.CredentialDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_credential, attr))
ref.extra = new_credential.extra
return ref.to_dict()
return ref.to_dict()
def delete_credential(self, credential_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_credential(session, credential_id)
session.delete(ref)
def delete_credentials_for_project(self, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(CredentialModel)
query = query.filter_by(project_id=project_id)
query.delete()
def delete_credentials_for_user(self, user_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id)
query.delete()

View File

@ -51,7 +51,7 @@ class EndpointPolicy(object):
def create_policy_association(self, policy_id, endpoint_id=None,
service_id=None, region_id=None):
with sql.transaction() as session:
with sql.session_for_write() as session:
try:
# See if there is already a row for this association, and if
# so, update it with the new policy_id
@ -79,14 +79,14 @@ class EndpointPolicy(object):
# NOTE(henry-nash): Getting a single value to save object
# management overhead.
with sql.transaction() as session:
with sql.session_for_read() as session:
if session.query(PolicyAssociation.id).filter(
sql_constraints).distinct().count() == 0:
raise exception.PolicyAssociationNotFound()
def delete_policy_association(self, policy_id, endpoint_id=None,
service_id=None, region_id=None):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id)
query = query.filter_by(endpoint_id=endpoint_id)
@ -102,7 +102,7 @@ class EndpointPolicy(object):
PolicyAssociation.region_id == region_id)
try:
with sql.transaction() as session:
with sql.session_for_read() as session:
policy_id = session.query(PolicyAssociation.policy_id).filter(
sql_constraints).distinct().one()
return {'policy_id': policy_id}
@ -110,31 +110,31 @@ class EndpointPolicy(object):
raise exception.PolicyAssociationNotFound()
def list_associations_for_policy(self, policy_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id)
return [ref.to_dict() for ref in query.all()]
def delete_association_by_endpoint(self, endpoint_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(endpoint_id=endpoint_id)
query.delete()
def delete_association_by_service(self, service_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(service_id=service_id)
query.delete()
def delete_association_by_region(self, region_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(region_id=region_id)
query.delete()
def delete_association_by_policy(self, policy_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id)
query.delete()

View File

@ -161,13 +161,13 @@ class Federation(core.FederationDriverV8):
@sql.handle_conflicts(conflict_type='identity_provider')
def create_idp(self, idp_id, idp):
idp['id'] = idp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
idp_ref = IdentityProviderModel.from_dict(idp)
session.add(idp_ref)
return idp_ref.to_dict()
return idp_ref.to_dict()
def delete_idp(self, idp_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
self._delete_assigned_protocols(session, idp_id)
idp_ref = self._get_idp(session, idp_id)
session.delete(idp_ref)
@ -187,30 +187,30 @@ class Federation(core.FederationDriverV8):
raise exception.IdentityProviderNotFound(idp_id=remote_id)
def list_idps(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
idps = session.query(IdentityProviderModel)
idps_list = [idp.to_dict() for idp in idps]
return idps_list
idps_list = [idp.to_dict() for idp in idps]
return idps_list
def get_idp(self, idp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
idp_ref = self._get_idp(session, idp_id)
return idp_ref.to_dict()
return idp_ref.to_dict()
def get_idp_from_remote_id(self, remote_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = self._get_idp_from_remote_id(session, remote_id)
return ref.to_dict()
return ref.to_dict()
def update_idp(self, idp_id, idp):
with sql.transaction() as session:
with sql.session_for_write() as session:
idp_ref = self._get_idp(session, idp_id)
old_idp = idp_ref.to_dict()
old_idp.update(idp)
new_idp = IdentityProviderModel.from_dict(old_idp)
for attr in IdentityProviderModel.mutable_attributes:
setattr(idp_ref, attr, getattr(new_idp, attr))
return idp_ref.to_dict()
return idp_ref.to_dict()
# Protocol CRUD
def _get_protocol(self, session, idp_id, protocol_id):
@ -227,36 +227,36 @@ class Federation(core.FederationDriverV8):
def create_protocol(self, idp_id, protocol_id, protocol):
protocol['id'] = protocol_id
protocol['idp_id'] = idp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
self._get_idp(session, idp_id)
protocol_ref = FederationProtocolModel.from_dict(protocol)
session.add(protocol_ref)
return protocol_ref.to_dict()
return protocol_ref.to_dict()
def update_protocol(self, idp_id, protocol_id, protocol):
with sql.transaction() as session:
with sql.session_for_write() as session:
proto_ref = self._get_protocol(session, idp_id, protocol_id)
old_proto = proto_ref.to_dict()
old_proto.update(protocol)
new_proto = FederationProtocolModel.from_dict(old_proto)
for attr in FederationProtocolModel.mutable_attributes:
setattr(proto_ref, attr, getattr(new_proto, attr))
return proto_ref.to_dict()
return proto_ref.to_dict()
def get_protocol(self, idp_id, protocol_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
return protocol_ref.to_dict()
return protocol_ref.to_dict()
def list_protocols(self, idp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
q = session.query(FederationProtocolModel)
q = q.filter_by(idp_id=idp_id)
protocols = [protocol.to_dict() for protocol in q]
return protocols
protocols = [protocol.to_dict() for protocol in q]
return protocols
def delete_protocol(self, idp_id, protocol_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
key_ref = self._get_protocol(session, idp_id, protocol_id)
session.delete(key_ref)
@ -277,58 +277,58 @@ class Federation(core.FederationDriverV8):
ref = {}
ref['id'] = mapping_id
ref['rules'] = mapping.get('rules')
with sql.transaction() as session:
with sql.session_for_write() as session:
mapping_ref = MappingModel.from_dict(ref)
session.add(mapping_ref)
return mapping_ref.to_dict()
return mapping_ref.to_dict()
def delete_mapping(self, mapping_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id)
session.delete(mapping_ref)
def list_mappings(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
mappings = session.query(MappingModel)
return [x.to_dict() for x in mappings]
return [x.to_dict() for x in mappings]
def get_mapping(self, mapping_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict()
return mapping_ref.to_dict()
@sql.handle_conflicts(conflict_type='mapping')
def update_mapping(self, mapping_id, mapping):
ref = {}
ref['id'] = mapping_id
ref['rules'] = mapping.get('rules')
with sql.transaction() as session:
with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id)
old_mapping = mapping_ref.to_dict()
old_mapping.update(ref)
new_mapping = MappingModel.from_dict(old_mapping)
for attr in MappingModel.attributes:
setattr(mapping_ref, attr, getattr(new_mapping, attr))
return mapping_ref.to_dict()
return mapping_ref.to_dict()
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
mapping_id = protocol_ref.mapping_id
mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict()
return mapping_ref.to_dict()
# Service Provider CRUD
@sql.handle_conflicts(conflict_type='service_provider')
def create_sp(self, sp_id, sp):
sp['id'] = sp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
sp_ref = ServiceProviderModel.from_dict(sp)
session.add(sp_ref)
return sp_ref.to_dict()
return sp_ref.to_dict()
def delete_sp(self, sp_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id)
session.delete(sp_ref)
@ -339,28 +339,28 @@ class Federation(core.FederationDriverV8):
return sp_ref
def list_sps(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
sps = session.query(ServiceProviderModel)
sps_list = [sp.to_dict() for sp in sps]
return sps_list
sps_list = [sp.to_dict() for sp in sps]
return sps_list
def get_sp(self, sp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
sp_ref = self._get_sp(session, sp_id)
return sp_ref.to_dict()
return sp_ref.to_dict()
def update_sp(self, sp_id, sp):
with sql.transaction() as session:
with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id)
old_sp = sp_ref.to_dict()
old_sp.update(sp)
new_sp = ServiceProviderModel.from_dict(old_sp)
for attr in ServiceProviderModel.mutable_attributes:
setattr(sp_ref, attr, getattr(new_sp, attr))
return sp_ref.to_dict()
return sp_ref.to_dict()
def get_enabled_service_providers(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
service_providers = session.query(ServiceProviderModel)
service_providers = service_providers.filter_by(enabled=True)
return service_providers
return service_providers

View File

@ -169,10 +169,10 @@ class Federation(core.FederationDriverV9):
def create_idp(self, idp_id, idp):
idp['id'] = idp_id
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
idp_ref = IdentityProviderModel.from_dict(idp)
session.add(idp_ref)
return idp_ref.to_dict()
return idp_ref.to_dict()
except sql.DBDuplicateEntry as e:
conflict_type = 'identity_provider'
details = six.text_type(e)
@ -186,7 +186,7 @@ class Federation(core.FederationDriverV9):
raise exception.Conflict(type=conflict_type, details=msg)
def delete_idp(self, idp_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
self._delete_assigned_protocols(session, idp_id)
idp_ref = self._get_idp(session, idp_id)
session.delete(idp_ref)
@ -206,31 +206,31 @@ class Federation(core.FederationDriverV9):
raise exception.IdentityProviderNotFound(idp_id=remote_id)
def list_idps(self, hints=None):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(IdentityProviderModel)
idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
idps_list = [idp.to_dict() for idp in idps]
return idps_list
idps_list = [idp.to_dict() for idp in idps]
return idps_list
def get_idp(self, idp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
idp_ref = self._get_idp(session, idp_id)
return idp_ref.to_dict()
return idp_ref.to_dict()
def get_idp_from_remote_id(self, remote_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = self._get_idp_from_remote_id(session, remote_id)
return ref.to_dict()
return ref.to_dict()
def update_idp(self, idp_id, idp):
with sql.transaction() as session:
with sql.session_for_write() as session:
idp_ref = self._get_idp(session, idp_id)
old_idp = idp_ref.to_dict()
old_idp.update(idp)
new_idp = IdentityProviderModel.from_dict(old_idp)
for attr in IdentityProviderModel.mutable_attributes:
setattr(idp_ref, attr, getattr(new_idp, attr))
return idp_ref.to_dict()
return idp_ref.to_dict()
# Protocol CRUD
def _get_protocol(self, session, idp_id, protocol_id):
@ -247,36 +247,36 @@ class Federation(core.FederationDriverV9):
def create_protocol(self, idp_id, protocol_id, protocol):
protocol['id'] = protocol_id
protocol['idp_id'] = idp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
self._get_idp(session, idp_id)
protocol_ref = FederationProtocolModel.from_dict(protocol)
session.add(protocol_ref)
return protocol_ref.to_dict()
return protocol_ref.to_dict()
def update_protocol(self, idp_id, protocol_id, protocol):
with sql.transaction() as session:
with sql.session_for_write() as session:
proto_ref = self._get_protocol(session, idp_id, protocol_id)
old_proto = proto_ref.to_dict()
old_proto.update(protocol)
new_proto = FederationProtocolModel.from_dict(old_proto)
for attr in FederationProtocolModel.mutable_attributes:
setattr(proto_ref, attr, getattr(new_proto, attr))
return proto_ref.to_dict()
return proto_ref.to_dict()
def get_protocol(self, idp_id, protocol_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
return protocol_ref.to_dict()
return protocol_ref.to_dict()
def list_protocols(self, idp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
q = session.query(FederationProtocolModel)
q = q.filter_by(idp_id=idp_id)
protocols = [protocol.to_dict() for protocol in q]
return protocols
protocols = [protocol.to_dict() for protocol in q]
return protocols
def delete_protocol(self, idp_id, protocol_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
key_ref = self._get_protocol(session, idp_id, protocol_id)
session.delete(key_ref)
@ -297,58 +297,58 @@ class Federation(core.FederationDriverV9):
ref = {}
ref['id'] = mapping_id
ref['rules'] = mapping.get('rules')
with sql.transaction() as session:
with sql.session_for_write() as session:
mapping_ref = MappingModel.from_dict(ref)
session.add(mapping_ref)
return mapping_ref.to_dict()
return mapping_ref.to_dict()
def delete_mapping(self, mapping_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id)
session.delete(mapping_ref)
def list_mappings(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
mappings = session.query(MappingModel)
return [x.to_dict() for x in mappings]
return [x.to_dict() for x in mappings]
def get_mapping(self, mapping_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict()
return mapping_ref.to_dict()
@sql.handle_conflicts(conflict_type='mapping')
def update_mapping(self, mapping_id, mapping):
ref = {}
ref['id'] = mapping_id
ref['rules'] = mapping.get('rules')
with sql.transaction() as session:
with sql.session_for_write() as session:
mapping_ref = self._get_mapping(session, mapping_id)
old_mapping = mapping_ref.to_dict()
old_mapping.update(ref)
new_mapping = MappingModel.from_dict(old_mapping)
for attr in MappingModel.attributes:
setattr(mapping_ref, attr, getattr(new_mapping, attr))
return mapping_ref.to_dict()
return mapping_ref.to_dict()
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
mapping_id = protocol_ref.mapping_id
mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict()
return mapping_ref.to_dict()
# Service Provider CRUD
@sql.handle_conflicts(conflict_type='service_provider')
def create_sp(self, sp_id, sp):
sp['id'] = sp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
sp_ref = ServiceProviderModel.from_dict(sp)
session.add(sp_ref)
return sp_ref.to_dict()
return sp_ref.to_dict()
def delete_sp(self, sp_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id)
session.delete(sp_ref)
@ -359,28 +359,28 @@ class Federation(core.FederationDriverV9):
return sp_ref
def list_sps(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
sps = session.query(ServiceProviderModel)
sps_list = [sp.to_dict() for sp in sps]
return sps_list
sps_list = [sp.to_dict() for sp in sps]
return sps_list
def get_sp(self, sp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
sp_ref = self._get_sp(session, sp_id)
return sp_ref.to_dict()
return sp_ref.to_dict()
def update_sp(self, sp_id, sp):
with sql.transaction() as session:
with sql.session_for_write() as session:
sp_ref = self._get_sp(session, sp_id)
old_sp = sp_ref.to_dict()
old_sp.update(sp)
new_sp = ServiceProviderModel.from_dict(old_sp)
for attr in ServiceProviderModel.mutable_attributes:
setattr(sp_ref, attr, getattr(new_sp, attr))
return sp_ref.to_dict()
return sp_ref.to_dict()
def get_enabled_service_providers(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
service_providers = session.query(ServiceProviderModel)
service_providers = service_providers.filter_by(enabled=True)
return service_providers
return service_providers

View File

@ -178,33 +178,32 @@ class Identity(identity.IdentityDriverV8):
# Identity interface
def authenticate(self, user_id, password):
session = sql.get_session()
user_ref = None
try:
user_ref = self._get_user(session, user_id)
except exception.UserNotFound:
raise AssertionError(_('Invalid user / password'))
if not self._check_password(password, user_ref):
raise AssertionError(_('Invalid user / password'))
return identity.filter_user(user_ref.to_dict())
with sql.session_for_read() as session:
user_ref = None
try:
user_ref = self._get_user(session, user_id)
except exception.UserNotFound:
raise AssertionError(_('Invalid user / password'))
if not self._check_password(password, user_ref):
raise AssertionError(_('Invalid user / password'))
return identity.filter_user(user_ref.to_dict())
# user crud
@sql.handle_conflicts(conflict_type='user')
def create_user(self, user_id, user):
user = utils.hash_user_password(user)
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
user_ref = User.from_dict(user)
session.add(user_ref)
return identity.filter_user(user_ref.to_dict())
return identity.filter_user(user_ref.to_dict())
@driver_hints.truncated
def list_users(self, hints):
session = sql.get_session()
query = session.query(User).outerjoin(LocalUser)
user_refs = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(x.to_dict()) for x in user_refs]
with sql.session_for_read() as session:
query = session.query(User).outerjoin(LocalUser)
user_refs = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(x.to_dict()) for x in user_refs]
def _get_user(self, session, user_id):
user_ref = session.query(User).get(user_id)
@ -213,25 +212,24 @@ class Identity(identity.IdentityDriverV8):
return user_ref
def get_user(self, user_id):
session = sql.get_session()
return identity.filter_user(self._get_user(session, user_id).to_dict())
with sql.session_for_read() as session:
return identity.filter_user(
self._get_user(session, user_id).to_dict())
def get_user_by_name(self, user_name, domain_id):
session = sql.get_session()
query = session.query(User).join(LocalUser)
query = query.filter(and_(LocalUser.name == user_name,
LocalUser.domain_id == domain_id))
try:
user_ref = query.one()
except sql.NotFound:
raise exception.UserNotFound(user_id=user_name)
return identity.filter_user(user_ref.to_dict())
with sql.session_for_read() as session:
query = session.query(User).join(LocalUser)
query = query.filter(and_(LocalUser.name == user_name,
LocalUser.domain_id == domain_id))
try:
user_ref = query.one()
except sql.NotFound:
raise exception.UserNotFound(user_id=user_name)
return identity.filter_user(user_ref.to_dict())
@sql.handle_conflicts(conflict_type='user')
def update_user(self, user_id, user):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
user_ref = self._get_user(session, user_id)
old_user_dict = user_ref.to_dict()
user = utils.hash_user_password(user)
@ -242,77 +240,74 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id':
setattr(user_ref, attr, getattr(new_user, attr))
user_ref.extra = new_user.extra
return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
return identity.filter_user(
user_ref.to_dict(include_extra_dict=True))
def add_user_to_group(self, user_id, group_id):
session = sql.get_session()
self.get_group(group_id)
self.get_user(user_id)
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
rv = query.first()
if rv:
return
with sql.session_for_write() as session:
self.get_group(group_id)
self.get_user(user_id)
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
rv = query.first()
if rv:
return
with session.begin():
session.add(UserGroupMembership(user_id=user_id,
group_id=group_id))
def check_user_in_group(self, user_id, group_id):
session = sql.get_session()
self.get_group(group_id)
self.get_user(user_id)
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
if not query.first():
raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
def remove_user_from_group(self, user_id, group_id):
session = sql.get_session()
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
membership_ref = query.first()
if membership_ref is None:
# Check if the group and user exist to return descriptive
# exceptions.
with sql.session_for_read() as session:
self.get_group(group_id)
self.get_user(user_id)
raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
with session.begin():
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
if not query.first():
raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
def remove_user_from_group(self, user_id, group_id):
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
with sql.session_for_write() as session:
query = session.query(UserGroupMembership)
query = query.filter_by(user_id=user_id)
query = query.filter_by(group_id=group_id)
membership_ref = query.first()
if membership_ref is None:
# Check if the group and user exist to return descriptive
# exceptions.
self.get_group(group_id)
self.get_user(user_id)
raise exception.NotFound(_("User '%(user_id)s' not found in"
" group '%(group_id)s'") %
{'user_id': user_id,
'group_id': group_id})
session.delete(membership_ref)
def list_groups_for_user(self, user_id, hints):
session = sql.get_session()
self.get_user(user_id)
query = session.query(Group).join(UserGroupMembership)
query = query.filter(UserGroupMembership.user_id == user_id)
query = sql.filter_limit_query(Group, query, hints)
return [g.to_dict() for g in query]
with sql.session_for_read() as session:
self.get_user(user_id)
query = session.query(Group).join(UserGroupMembership)
query = query.filter(UserGroupMembership.user_id == user_id)
query = sql.filter_limit_query(Group, query, hints)
return [g.to_dict() for g in query]
def list_users_in_group(self, group_id, hints):
session = sql.get_session()
self.get_group(group_id)
query = session.query(User).outerjoin(LocalUser)
query = query.join(UserGroupMembership)
query = query.filter(UserGroupMembership.group_id == group_id)
query = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(u.to_dict()) for u in query]
with sql.session_for_read() as session:
self.get_group(group_id)
query = session.query(User).outerjoin(LocalUser)
query = query.join(UserGroupMembership)
query = query.filter(UserGroupMembership.group_id == group_id)
query = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(u.to_dict()) for u in query]
def delete_user(self, user_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_user(session, user_id)
q = session.query(UserGroupMembership)
@ -325,18 +320,17 @@ class Identity(identity.IdentityDriverV8):
@sql.handle_conflicts(conflict_type='group')
def create_group(self, group_id, group):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = Group.from_dict(group)
session.add(ref)
return ref.to_dict()
return ref.to_dict()
@driver_hints.truncated
def list_groups(self, hints):
session = sql.get_session()
query = session.query(Group)
refs = sql.filter_limit_query(Group, query, hints)
return [ref.to_dict() for ref in refs]
with sql.session_for_read() as session:
query = session.query(Group)
refs = sql.filter_limit_query(Group, query, hints)
return [ref.to_dict() for ref in refs]
def _get_group(self, session, group_id):
ref = session.query(Group).get(group_id)
@ -345,25 +339,23 @@ class Identity(identity.IdentityDriverV8):
return ref
def get_group(self, group_id):
session = sql.get_session()
return self._get_group(session, group_id).to_dict()
with sql.session_for_read() as session:
return self._get_group(session, group_id).to_dict()
def get_group_by_name(self, group_name, domain_id):
session = sql.get_session()
query = session.query(Group)
query = query.filter_by(name=group_name)
query = query.filter_by(domain_id=domain_id)
try:
group_ref = query.one()
except sql.NotFound:
raise exception.GroupNotFound(group_id=group_name)
return group_ref.to_dict()
with sql.session_for_read() as session:
query = session.query(Group)
query = query.filter_by(name=group_name)
query = query.filter_by(domain_id=domain_id)
try:
group_ref = query.one()
except sql.NotFound:
raise exception.GroupNotFound(group_id=group_name)
return group_ref.to_dict()
@sql.handle_conflicts(conflict_type='group')
def update_group(self, group_id, group):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_group(session, group_id)
old_dict = ref.to_dict()
for k in group:
@ -373,12 +365,10 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_group, attr))
ref.extra = new_group.extra
return ref.to_dict()
return ref.to_dict()
def delete_group(self, group_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_group(session, group_id)
q = session.query(UserGroupMembership)

View File

@ -45,27 +45,27 @@ class Mapping(identity.MappingDriverV8):
# work if we hashed all the entries, even those that already generate
# UUIDs, like SQL. Further, this would only work if the generation
# algorithm was immutable (e.g. it had always been sha256).
session = sql.get_session()
query = session.query(IDMapping.public_id)
query = query.filter_by(domain_id=local_entity['domain_id'])
query = query.filter_by(local_id=local_entity['local_id'])
query = query.filter_by(entity_type=local_entity['entity_type'])
try:
public_ref = query.one()
public_id = public_ref.public_id
return public_id
except sql.NotFound:
return None
with sql.session_for_read() as session:
query = session.query(IDMapping.public_id)
query = query.filter_by(domain_id=local_entity['domain_id'])
query = query.filter_by(local_id=local_entity['local_id'])
query = query.filter_by(entity_type=local_entity['entity_type'])
try:
public_ref = query.one()
public_id = public_ref.public_id
return public_id
except sql.NotFound:
return None
def get_id_mapping(self, public_id):
session = sql.get_session()
mapping_ref = session.query(IDMapping).get(public_id)
if mapping_ref:
return mapping_ref.to_dict()
with sql.session_for_read() as session:
mapping_ref = session.query(IDMapping).get(public_id)
if mapping_ref:
return mapping_ref.to_dict()
def create_id_mapping(self, local_entity, public_id=None):
entity = local_entity.copy()
with sql.transaction() as session:
with sql.session_for_write() as session:
if public_id is None:
public_id = self.id_generator_api.generate_public_ID(entity)
entity['public_id'] = public_id
@ -74,7 +74,7 @@ class Mapping(identity.MappingDriverV8):
return public_id
def delete_id_mapping(self, public_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
try:
session.query(IDMapping).filter(
IDMapping.public_id == public_id).delete()
@ -84,14 +84,15 @@ class Mapping(identity.MappingDriverV8):
pass
def purge_mappings(self, purge_filter):
session = sql.get_session()
query = session.query(IDMapping)
if 'domain_id' in purge_filter:
query = query.filter_by(domain_id=purge_filter['domain_id'])
if 'public_id' in purge_filter:
query = query.filter_by(public_id=purge_filter['public_id'])
if 'local_id' in purge_filter:
query = query.filter_by(local_id=purge_filter['local_id'])
if 'entity_type' in purge_filter:
query = query.filter_by(entity_type=purge_filter['entity_type'])
query.delete()
with sql.session_for_write() as session:
query = session.query(IDMapping)
if 'domain_id' in purge_filter:
query = query.filter_by(domain_id=purge_filter['domain_id'])
if 'public_id' in purge_filter:
query = query.filter_by(public_id=purge_filter['public_id'])
if 'local_id' in purge_filter:
query = query.filter_by(local_id=purge_filter['local_id'])
if 'entity_type' in purge_filter:
query = query.filter_by(
entity_type=purge_filter['entity_type'])
query.delete()

View File

@ -92,17 +92,16 @@ class OAuth1(core.Oauth1DriverV8):
return consumer_ref
def get_consumer_with_secret(self, consumer_id):
session = sql.get_session()
consumer_ref = self._get_consumer(session, consumer_id)
return consumer_ref.to_dict()
with sql.session_for_read() as session:
consumer_ref = self._get_consumer(session, consumer_id)
return consumer_ref.to_dict()
def get_consumer(self, consumer_id):
return core.filter_consumer(
self.get_consumer_with_secret(consumer_id))
def create_consumer(self, consumer_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
consumer = Consumer.from_dict(consumer_ref)
session.add(consumer)
return consumer.to_dict()
@ -128,20 +127,18 @@ class OAuth1(core.Oauth1DriverV8):
session.delete(token_ref)
def delete_consumer(self, consumer_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
self._delete_request_tokens(session, consumer_id)
self._delete_access_tokens(session, consumer_id)
self._delete_consumer(session, consumer_id)
def list_consumers(self):
session = sql.get_session()
cons = session.query(Consumer)
return [core.filter_consumer(x.to_dict()) for x in cons]
with sql.session_for_read() as session:
cons = session.query(Consumer)
return [core.filter_consumer(x.to_dict()) for x in cons]
def update_consumer(self, consumer_id, consumer_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
consumer = self._get_consumer(session, consumer_id)
old_consumer_dict = consumer.to_dict()
old_consumer_dict.update(consumer_ref)
@ -169,11 +166,10 @@ class OAuth1(core.Oauth1DriverV8):
ref['role_ids'] = None
ref['consumer_id'] = consumer_id
ref['expires_at'] = expiry_date
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
token_ref = RequestToken.from_dict(ref)
session.add(token_ref)
return token_ref.to_dict()
return token_ref.to_dict()
def _get_request_token(self, session, request_token_id):
token_ref = session.query(RequestToken).get(request_token_id)
@ -182,14 +178,13 @@ class OAuth1(core.Oauth1DriverV8):
return token_ref
def get_request_token(self, request_token_id):
session = sql.get_session()
token_ref = self._get_request_token(session, request_token_id)
return token_ref.to_dict()
with sql.session_for_read() as session:
token_ref = self._get_request_token(session, request_token_id)
return token_ref.to_dict()
def authorize_request_token(self, request_token_id, user_id,
role_ids):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
token_ref = self._get_request_token(session, request_token_id)
token_dict = token_ref.to_dict()
token_dict['authorizing_user_id'] = user_id
@ -203,13 +198,12 @@ class OAuth1(core.Oauth1DriverV8):
or attr == 'role_ids'):
setattr(token_ref, attr, getattr(new_token, attr))
return token_ref.to_dict()
return token_ref.to_dict()
def create_access_token(self, request_id, access_token_duration):
access_token_id = uuid.uuid4().hex
access_token_secret = uuid.uuid4().hex
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
req_token_ref = self._get_request_token(session, request_id)
token_dict = req_token_ref.to_dict()
@ -235,7 +229,7 @@ class OAuth1(core.Oauth1DriverV8):
# remove request token, it's been used
session.delete(req_token_ref)
return token_ref.to_dict()
return token_ref.to_dict()
def _get_access_token(self, session, access_token_id):
token_ref = session.query(AccessToken).get(access_token_id)
@ -244,19 +238,18 @@ class OAuth1(core.Oauth1DriverV8):
return token_ref
def get_access_token(self, access_token_id):
session = sql.get_session()
token_ref = self._get_access_token(session, access_token_id)
return token_ref.to_dict()
with sql.session_for_read() as session:
token_ref = self._get_access_token(session, access_token_id)
return token_ref.to_dict()
def list_access_tokens(self, user_id):
session = sql.get_session()
q = session.query(AccessToken)
user_auths = q.filter_by(authorizing_user_id=user_id)
return [core.filter_token(x.to_dict()) for x in user_auths]
with sql.session_for_read() as session:
q = session.query(AccessToken)
user_auths = q.filter_by(authorizing_user_id=user_id)
return [core.filter_token(x.to_dict()) for x in user_auths]
def delete_access_token(self, user_id, access_token_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
token_ref = self._get_access_token(session, access_token_id)
token_dict = token_ref.to_dict()
if token_dict['authorizing_user_id'] != user_id:

View File

@ -30,19 +30,16 @@ class Policy(rules.Policy):
@sql.handle_conflicts(conflict_type='policy')
def create_policy(self, policy_id, policy):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = PolicyModel.from_dict(policy)
session.add(ref)
return ref.to_dict()
return ref.to_dict()
def list_policies(self):
session = sql.get_session()
refs = session.query(PolicyModel).all()
return [ref.to_dict() for ref in refs]
with sql.session_for_read() as session:
refs = session.query(PolicyModel).all()
return [ref.to_dict() for ref in refs]
def _get_policy(self, session, policy_id):
"""Private method to get a policy model object (NOT a dictionary)."""
@ -52,15 +49,12 @@ class Policy(rules.Policy):
return ref
def get_policy(self, policy_id):
session = sql.get_session()
return self._get_policy(session, policy_id).to_dict()
with sql.session_for_read() as session:
return self._get_policy(session, policy_id).to_dict()
@sql.handle_conflicts(conflict_type='policy')
def update_policy(self, policy_id, policy):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_policy(session, policy_id)
old_dict = ref.to_dict()
old_dict.update(policy)
@ -72,8 +66,6 @@ class Policy(rules.Policy):
return ref.to_dict()
def delete_policy(self, policy_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_policy(session, policy_id)
session.delete(ref)

View File

@ -35,11 +35,11 @@ class Resource(keystone_resource.ResourceDriverV8):
return project_ref
def get_project(self, tenant_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_project(session, tenant_id).to_dict()
def get_project_by_name(self, tenant_name, domain_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project)
query = query.filter_by(name=tenant_name)
query = query.filter_by(domain_id=domain_id)
@ -51,7 +51,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@driver_hints.truncated
def list_projects(self, hints):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project)
project_refs = sql.filter_limit_query(Project, query, hints)
return [project_ref.to_dict() for project_ref in project_refs]
@ -60,7 +60,7 @@ class Resource(keystone_resource.ResourceDriverV8):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project)
query = query.filter(Project.id.in_(ids))
return [project_ref.to_dict() for project_ref in query.all()]
@ -69,14 +69,14 @@ class Resource(keystone_resource.ResourceDriverV8):
if not domain_ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project.id)
query = (
query.filter(Project.domain_id.in_(domain_ids)))
return [x.id for x in query.all()]
def list_projects_in_domain(self, domain_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
self._get_domain(session, domain_id)
query = session.query(Project)
project_refs = query.filter_by(domain_id=domain_id)
@ -89,7 +89,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return [project_ref.to_dict() for project_ref in project_refs]
def list_projects_in_subtree(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
children = self._get_children(session, [project_id])
subtree = []
examined = set([project_id])
@ -110,7 +110,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return subtree
def list_project_parents(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
project = self._get_project(session, project_id).to_dict()
parents = []
examined = set()
@ -130,7 +130,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return parents
def is_leaf_project(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
project_refs = self._get_children(session, [project_id])
return not project_refs
@ -138,7 +138,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='project')
def create_project(self, tenant_id, tenant):
tenant['name'] = clean.project_name(tenant['name'])
with sql.transaction() as session:
with sql.session_for_write() as session:
tenant_ref = Project.from_dict(tenant)
session.add(tenant_ref)
return tenant_ref.to_dict()
@ -148,7 +148,7 @@ class Resource(keystone_resource.ResourceDriverV8):
if 'name' in tenant:
tenant['name'] = clean.project_name(tenant['name'])
with sql.transaction() as session:
with sql.session_for_write() as session:
tenant_ref = self._get_project(session, tenant_id)
old_project_dict = tenant_ref.to_dict()
for k in tenant:
@ -162,7 +162,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='project')
def delete_project(self, tenant_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
tenant_ref = self._get_project(session, tenant_id)
session.delete(tenant_ref)
@ -170,14 +170,14 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='domain')
def create_domain(self, domain_id, domain):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = Domain.from_dict(domain)
session.add(ref)
return ref.to_dict()
@driver_hints.truncated
def list_domains(self, hints):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Domain)
refs = sql.filter_limit_query(Domain, query, hints)
return [ref.to_dict() for ref in refs]
@ -186,7 +186,7 @@ class Resource(keystone_resource.ResourceDriverV8):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Domain)
query = query.filter(Domain.id.in_(ids))
domain_refs = query.all()
@ -199,11 +199,11 @@ class Resource(keystone_resource.ResourceDriverV8):
return ref
def get_domain(self, domain_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_domain(session, domain_id).to_dict()
def get_domain_by_name(self, domain_name):
with sql.transaction() as session:
with sql.session_for_read() as session:
try:
ref = (session.query(Domain).
filter_by(name=domain_name).one())
@ -213,7 +213,7 @@ class Resource(keystone_resource.ResourceDriverV8):
@sql.handle_conflicts(conflict_type='domain')
def update_domain(self, domain_id, domain):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id)
old_dict = ref.to_dict()
for k in domain:
@ -226,7 +226,7 @@ class Resource(keystone_resource.ResourceDriverV8):
return ref.to_dict()
def delete_domain(self, domain_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id)
session.delete(ref)

View File

@ -38,11 +38,11 @@ class Resource(keystone_resource.ResourceDriverV9):
return project_ref
def get_project(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_project(session, project_id).to_dict()
def get_project_by_name(self, project_name, domain_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project)
query = query.filter_by(name=project_name)
if domain_id is None:
@ -70,7 +70,7 @@ class Resource(keystone_resource.ResourceDriverV9):
for f in hints.filters:
if (f['name'] == 'domain_id' and f['value'] is None):
f['value'] = keystone_resource.NULL_DOMAIN_ID
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project)
project_refs = sql.filter_limit_query(Project, query, hints)
return [project_ref.to_dict() for project_ref in project_refs
@ -80,7 +80,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project)
query = query.filter(Project.id.in_(ids))
return [project_ref.to_dict() for project_ref in query.all()
@ -90,7 +90,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not domain_ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Project.id)
query = (
query.filter(Project.domain_id.in_(domain_ids)))
@ -98,7 +98,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not self._is_hidden_ref(x)]
def list_projects_in_domain(self, domain_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
self._get_domain(session, domain_id)
query = session.query(Project)
project_refs = query.filter_by(domain_id=domain_id)
@ -111,7 +111,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return [project_ref.to_dict() for project_ref in project_refs]
def list_projects_in_subtree(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
children = self._get_children(session, [project_id])
subtree = []
examined = set([project_id])
@ -132,7 +132,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return subtree
def list_project_parents(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
project = self._get_project(session, project_id).to_dict()
parents = []
examined = set()
@ -152,7 +152,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return parents
def is_leaf_project(self, project_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
project_refs = self._get_children(session, [project_id])
return not project_refs
@ -161,7 +161,7 @@ class Resource(keystone_resource.ResourceDriverV9):
def create_project(self, project_id, project):
project['name'] = clean.project_name(project['name'])
new_project = self._encode_domain_id(project)
with sql.transaction() as session:
with sql.session_for_write() as session:
project_ref = Project.from_dict(new_project)
session.add(project_ref)
return project_ref.to_dict()
@ -172,7 +172,7 @@ class Resource(keystone_resource.ResourceDriverV9):
project['name'] = clean.project_name(project['name'])
update_project = self._encode_domain_id(project)
with sql.transaction() as session:
with sql.session_for_write() as session:
project_ref = self._get_project(session, project_id)
old_project_dict = project_ref.to_dict()
for k in update_project:
@ -189,7 +189,7 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='project')
def delete_project(self, project_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
project_ref = self._get_project(session, project_id)
session.delete(project_ref)
@ -197,7 +197,7 @@ class Resource(keystone_resource.ResourceDriverV9):
def delete_projects_from_ids(self, project_ids):
if not project_ids:
return
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(Project).filter(Project.id.in_(
project_ids))
project_ids_from_bd = [p['id'] for p in query.all()]
@ -212,14 +212,14 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='domain')
def create_domain(self, domain_id, domain):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = Domain.from_dict(domain)
session.add(ref)
return ref.to_dict()
return ref.to_dict()
@driver_hints.truncated
def list_domains(self, hints):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Domain)
refs = sql.filter_limit_query(Domain, query, hints)
return [ref.to_dict() for ref in refs
@ -229,7 +229,7 @@ class Resource(keystone_resource.ResourceDriverV9):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(Domain)
query = query.filter(Domain.id.in_(ids))
domain_refs = query.all()
@ -243,11 +243,11 @@ class Resource(keystone_resource.ResourceDriverV9):
return ref
def get_domain(self, domain_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_domain(session, domain_id).to_dict()
def get_domain_by_name(self, domain_name):
with sql.transaction() as session:
with sql.session_for_read() as session:
try:
ref = (session.query(Domain).
filter_by(name=domain_name).one())
@ -260,7 +260,7 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='domain')
def update_domain(self, domain_id, domain):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id)
old_dict = ref.to_dict()
for k in domain:
@ -273,7 +273,7 @@ class Resource(keystone_resource.ResourceDriverV9):
return ref.to_dict()
def delete_domain(self, domain_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_domain(session, domain_id)
session.delete(ref)

View File

@ -59,12 +59,12 @@ class DomainConfig(resource.DomainConfigDriverV8):
@sql.handle_conflicts(conflict_type='domain_config')
def create_config_option(self, domain_id, group, option, value,
sensitive=False):
with sql.transaction() as session:
with sql.session_for_write() as session:
config_table = self.choose_table(sensitive)
ref = config_table(domain_id=domain_id, group=group,
option=option, value=value)
session.add(ref)
return ref.to_dict()
return ref.to_dict()
def _get_config_option(self, session, domain_id, group, option, sensitive):
try:
@ -80,14 +80,14 @@ class DomainConfig(resource.DomainConfigDriverV8):
return ref
def get_config_option(self, domain_id, group, option, sensitive=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = self._get_config_option(session, domain_id, group, option,
sensitive)
return ref.to_dict()
return ref.to_dict()
def list_config_options(self, domain_id, group=None, option=None,
sensitive=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
config_table = self.choose_table(sensitive)
query = session.query(config_table)
query = query.filter_by(domain_id=domain_id)
@ -99,11 +99,11 @@ class DomainConfig(resource.DomainConfigDriverV8):
def update_config_option(self, domain_id, group, option, value,
sensitive=False):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_config_option(session, domain_id, group, option,
sensitive)
ref.value = value
return ref.to_dict()
return ref.to_dict()
def delete_config_options(self, domain_id, group=None, option=None,
sensitive=False):
@ -114,7 +114,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
if there was nothing to delete.
"""
with sql.transaction() as session:
with sql.session_for_write() as session:
config_table = self.choose_table(sensitive)
query = session.query(config_table)
query = query.filter_by(domain_id=domain_id)
@ -126,7 +126,7 @@ class DomainConfig(resource.DomainConfigDriverV8):
def obtain_registration(self, domain_id, type):
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = ConfigRegister(type=type, domain_id=domain_id)
session.add(ref)
return True
@ -136,15 +136,15 @@ class DomainConfig(resource.DomainConfigDriverV8):
return False
def read_registration(self, type):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = session.query(ConfigRegister).get(type)
if not ref:
raise exception.ConfigRegistrationNotFound()
return ref.domain_id
return ref.domain_id
def release_registration(self, domain_id, type=None):
"""Silently delete anything registered for the domain specified."""
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(ConfigRegister)
if type:
query = query.filter_by(type=type)

View File

@ -60,37 +60,37 @@ class Revoke(revoke.RevokeDriverV8):
def _prune_expired_events(self):
oldest = revoke.revoked_before_cutoff_time()
session = sql.get_session()
dialect = session.bind.dialect.name
batch_size = self._flush_batch_size(dialect)
if batch_size > 0:
query = session.query(RevocationEvent.id)
query = query.filter(RevocationEvent.revoked_at < oldest)
query = query.limit(batch_size).subquery()
delete_query = (session.query(RevocationEvent).
filter(RevocationEvent.id.in_(query)))
while True:
rowcount = delete_query.delete(synchronize_session=False)
if rowcount == 0:
break
else:
query = session.query(RevocationEvent)
query = query.filter(RevocationEvent.revoked_at < oldest)
query.delete(synchronize_session=False)
with sql.session_for_write() as session:
dialect = session.bind.dialect.name
batch_size = self._flush_batch_size(dialect)
if batch_size > 0:
query = session.query(RevocationEvent.id)
query = query.filter(RevocationEvent.revoked_at < oldest)
query = query.limit(batch_size).subquery()
delete_query = (session.query(RevocationEvent).
filter(RevocationEvent.id.in_(query)))
while True:
rowcount = delete_query.delete(synchronize_session=False)
if rowcount == 0:
break
else:
query = session.query(RevocationEvent)
query = query.filter(RevocationEvent.revoked_at < oldest)
query.delete(synchronize_session=False)
session.flush()
session.flush()
def list_events(self, last_fetch=None):
session = sql.get_session()
query = session.query(RevocationEvent).order_by(
RevocationEvent.revoked_at)
with sql.session_for_read() as session:
query = session.query(RevocationEvent).order_by(
RevocationEvent.revoked_at)
if last_fetch:
query = query.filter(RevocationEvent.revoked_at > last_fetch)
if last_fetch:
query = query.filter(RevocationEvent.revoked_at > last_fetch)
events = [model.RevokeEvent(**e.to_dict()) for e in query]
events = [model.RevokeEvent(**e.to_dict()) for e in query]
return events
return events
def revoke(self, event):
kwargs = dict()
@ -98,7 +98,6 @@ class Revoke(revoke.RevokeDriverV8):
kwargs[attr] = getattr(event, attr)
kwargs['id'] = uuid.uuid4().hex
record = RevocationEvent(**kwargs)
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
session.add(record)
self._prune_expired_events()
self._prune_expired_events()

View File

@ -17,6 +17,6 @@ from keystone.identity.mapping_backends import sql as mapping_sql
def list_id_mappings():
"""List all id_mappings for testing purposes."""
a_session = sql.get_session()
refs = a_session.query(mapping_sql.IDMapping).all()
return [x.to_dict() for x in refs]
with sql.session_for_read() as session:
refs = session.query(mapping_sql.IDMapping).all()
return [x.to_dict() for x in refs]

View File

@ -611,7 +611,7 @@ class SqlToken(SqlTests, test_backend.TokenTests):
tok = token_sql.Token()
tok.list_revoked_tokens()
mock_query = mock_sql.get_session().query
mock_query = mock_sql.session_for_read().__enter__().query
mock_query.assert_called_with(*expected_query_args)
def test_flush_expired_tokens_batch(self):
@ -636,8 +636,12 @@ class SqlToken(SqlTests, test_backend.TokenTests):
# other tests below test the differences between how they use the batch
# strategy
with mock.patch.object(token_sql, 'sql') as mock_sql:
mock_sql.get_session().query().filter().delete.return_value = 0
mock_sql.get_session().bind.dialect.name = 'mysql'
mock_sql.session_for_write().__enter__(
).query().filter().delete.return_value = 0
mock_sql.session_for_write().__enter__(
).bind.dialect.name = 'mysql'
tok = token_sql.Token()
expiry_mock = mock.Mock()
ITERS = [1, 2, 3]
@ -648,7 +652,10 @@ class SqlToken(SqlTests, test_backend.TokenTests):
# The expiry strategy is only invoked once, the other calls are via
# the yield return.
self.assertEqual(1, expiry_mock.call_count)
mock_delete = mock_sql.get_session().query().filter().delete
mock_delete = mock_sql.session_for_write().__enter__(
).query().filter().delete
self.assertThat(mock_delete.call_args_list,
matchers.HasLength(len(ITERS)))

View File

@ -161,7 +161,8 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase):
self.repo_package())
self.schema = versioning_api.ControlledSchema.create(
self.engine,
self.repo_path, self.initial_db_version)
self.repo_path,
self.initial_db_version)
# auto-detect the highest available schema version in the migrate_repo
self.max_version = self.schema.repository.version().version

View File

@ -86,11 +86,11 @@ class Token(token.persistence.TokenDriverV8):
def get_token(self, token_id):
if token_id is None:
raise exception.TokenNotFound(token_id=token_id)
session = sql.get_session()
token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id)
return token_ref.to_dict()
with sql.session_for_read() as session:
token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id)
return token_ref.to_dict()
def create_token(self, token_id, data):
data_copy = copy.deepcopy(data)
@ -101,14 +101,12 @@ class Token(token.persistence.TokenDriverV8):
token_ref = TokenModel.from_dict(data_copy)
token_ref.valid = True
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
session.add(token_ref)
return token_ref.to_dict()
def delete_token(self, token_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id)
@ -124,9 +122,8 @@ class Token(token.persistence.TokenDriverV8):
or the trustor's user ID, so will use trust_id to query the tokens.
"""
session = sql.get_session()
token_list = []
with session.begin():
with sql.session_for_write() as session:
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter_by(valid=True)
@ -167,38 +164,37 @@ class Token(token.persistence.TokenDriverV8):
return False
def _list_tokens_for_trust(self, trust_id):
session = sql.get_session()
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now)
query = query.filter(TokenModel.trust_id == trust_id)
with sql.session_for_read() as session:
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now)
query = query.filter(TokenModel.trust_id == trust_id)
token_references = query.filter_by(valid=True)
for token_ref in token_references:
token_ref_dict = token_ref.to_dict()
tokens.append(token_ref_dict['id'])
return tokens
token_references = query.filter_by(valid=True)
for token_ref in token_references:
token_ref_dict = token_ref.to_dict()
tokens.append(token_ref_dict['id'])
return tokens
def _list_tokens_for_user(self, user_id, tenant_id=None):
session = sql.get_session()
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now)
query = query.filter(TokenModel.user_id == user_id)
with sql.session_for_read() as session:
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now)
query = query.filter(TokenModel.user_id == user_id)
token_references = query.filter_by(valid=True)
for token_ref in token_references:
token_ref_dict = token_ref.to_dict()
if self._tenant_matches(tenant_id, token_ref_dict):
tokens.append(token_ref['id'])
return tokens
token_references = query.filter_by(valid=True)
for token_ref in token_references:
token_ref_dict = token_ref.to_dict()
if self._tenant_matches(tenant_id, token_ref_dict):
tokens.append(token_ref['id'])
return tokens
def _list_tokens_for_consumer(self, user_id, consumer_id):
tokens = []
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now)
@ -223,29 +219,29 @@ class Token(token.persistence.TokenDriverV8):
return self._list_tokens_for_user(user_id, tenant_id)
def list_revoked_tokens(self):
session = sql.get_session()
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel.id, TokenModel.expires,
TokenModel.extra)
query = query.filter(TokenModel.expires > now)
token_references = query.filter_by(valid=False)
for token_ref in token_references:
token_data = token_ref[2]['token_data']
if 'access' in token_data:
# It's a v2 token.
audit_ids = token_data['access']['token']['audit_ids']
else:
# It's a v3 token.
audit_ids = token_data['token']['audit_ids']
with sql.session_for_read() as session:
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel.id, TokenModel.expires,
TokenModel.extra)
query = query.filter(TokenModel.expires > now)
token_references = query.filter_by(valid=False)
for token_ref in token_references:
token_data = token_ref[2]['token_data']
if 'access' in token_data:
# It's a v2 token.
audit_ids = token_data['access']['token']['audit_ids']
else:
# It's a v3 token.
audit_ids = token_data['token']['audit_ids']
record = {
'id': token_ref[0],
'expires': token_ref[1],
'audit_id': audit_ids[0],
}
tokens.append(record)
return tokens
record = {
'id': token_ref[0],
'expires': token_ref[1],
'audit_id': audit_ids[0],
}
tokens.append(record)
return tokens
def _expiry_range_strategy(self, dialect):
"""Choose a token range expiration strategy
@ -273,18 +269,18 @@ class Token(token.persistence.TokenDriverV8):
return _expiry_range_all
def flush_expired_tokens(self):
session = sql.get_session()
dialect = session.bind.dialect.name
expiry_range_func = self._expiry_range_strategy(dialect)
query = session.query(TokenModel.expires)
total_removed = 0
upper_bound_func = timeutils.utcnow
for expiry_time in expiry_range_func(session, upper_bound_func):
delete_query = query.filter(TokenModel.expires <=
expiry_time)
row_count = delete_query.delete(synchronize_session=False)
total_removed += row_count
LOG.debug('Removed %d total expired tokens', total_removed)
with sql.session_for_write() as session:
dialect = session.bind.dialect.name
expiry_range_func = self._expiry_range_strategy(dialect)
query = session.query(TokenModel.expires)
total_removed = 0
upper_bound_func = timeutils.utcnow
for expiry_time in expiry_range_func(session, upper_bound_func):
delete_query = query.filter(TokenModel.expires <=
expiry_time)
row_count = delete_query.delete(synchronize_session=False)
total_removed += row_count
LOG.debug('Removed %d total expired tokens', total_removed)
session.flush()
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
session.flush()
LOG.info(_LI('Total expired tokens removed: %d'), total_removed)

View File

@ -59,7 +59,7 @@ class TrustRole(sql.ModelBase):
class Trust(trust.TrustDriverV8):
@sql.handle_conflicts(conflict_type='trust')
def create_trust(self, trust_id, trust, roles):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = TrustModel.from_dict(trust)
ref['id'] = trust_id
if ref.get('expires_at') and ref['expires_at'].tzinfo is not None:
@ -72,9 +72,9 @@ class Trust(trust.TrustDriverV8):
trust_role.role_id = role['id']
added_roles.append({'id': role['id']})
session.add(trust_role)
trust_dict = ref.to_dict()
trust_dict['roles'] = added_roles
return trust_dict
trust_dict = ref.to_dict()
trust_dict['roles'] = added_roles
return trust_dict
def _add_roles(self, trust_id, session, trust_dict):
roles = []
@ -86,7 +86,7 @@ class Trust(trust.TrustDriverV8):
def consume_use(self, trust_id):
for attempt in range(MAXIMUM_CONSUME_ATTEMPTS):
with sql.transaction() as session:
with sql.session_for_write() as session:
try:
query_result = (session.query(TrustModel.remaining_uses).
filter_by(id=trust_id).
@ -132,51 +132,51 @@ class Trust(trust.TrustDriverV8):
raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id)
def get_trust(self, trust_id, deleted=False):
session = sql.get_session()
query = session.query(TrustModel).filter_by(id=trust_id)
if not deleted:
query = query.filter_by(deleted_at=None)
ref = query.first()
if ref is None:
raise exception.TrustNotFound(trust_id=trust_id)
if ref.expires_at is not None and not deleted:
now = timeutils.utcnow()
if now > ref.expires_at:
with sql.session_for_read() as session:
query = session.query(TrustModel).filter_by(id=trust_id)
if not deleted:
query = query.filter_by(deleted_at=None)
ref = query.first()
if ref is None:
raise exception.TrustNotFound(trust_id=trust_id)
# Do not return trusts that can't be used anymore
if ref.remaining_uses is not None and not deleted:
if ref.remaining_uses <= 0:
raise exception.TrustNotFound(trust_id=trust_id)
trust_dict = ref.to_dict()
if ref.expires_at is not None and not deleted:
now = timeutils.utcnow()
if now > ref.expires_at:
raise exception.TrustNotFound(trust_id=trust_id)
# Do not return trusts that can't be used anymore
if ref.remaining_uses is not None and not deleted:
if ref.remaining_uses <= 0:
raise exception.TrustNotFound(trust_id=trust_id)
trust_dict = ref.to_dict()
self._add_roles(trust_id, session, trust_dict)
return trust_dict
self._add_roles(trust_id, session, trust_dict)
return trust_dict
@sql.handle_conflicts(conflict_type='trust')
def list_trusts(self):
session = sql.get_session()
trusts = session.query(TrustModel).filter_by(deleted_at=None)
return [trust_ref.to_dict() for trust_ref in trusts]
with sql.session_for_read() as session:
trusts = session.query(TrustModel).filter_by(deleted_at=None)
return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust')
def list_trusts_for_trustee(self, trustee_user_id):
session = sql.get_session()
trusts = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(trustee_user_id=trustee_user_id))
return [trust_ref.to_dict() for trust_ref in trusts]
with sql.session_for_read() as session:
trusts = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(trustee_user_id=trustee_user_id))
return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust')
def list_trusts_for_trustor(self, trustor_user_id):
session = sql.get_session()
trusts = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(trustor_user_id=trustor_user_id))
return [trust_ref.to_dict() for trust_ref in trusts]
with sql.session_for_read() as session:
trusts = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(trustor_user_id=trustor_user_id))
return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust')
def delete_trust(self, trust_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
trust_ref = session.query(TrustModel).get(trust_id)
if not trust_ref:
raise exception.TrustNotFound(trust_id=trust_id)