Use the new enginefacade from oslo.db

EngineFacade is deprecated. This partially switches keystone to
use oslo.db.sqlalchemy.enginefacade. 'get_session' and 'get_engine'
methods are still used in sql migrations and related tests.

Change-Id: I221232d50821fe2adb9881f237f06714003ce79d
Partial-Bug: #1490571
changes/58/257458/5
Grzegorz Grasza 7 years ago committed by Morgan Fainberg
parent e943768088
commit 0e156737d0

@ -56,7 +56,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
return 'sql'
def list_user_ids_for_project(self, tenant_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.actor_id)
query = query.filter_by(type=AssignmentType.USER_PROJECT)
query = query.filter_by(target_id=tenant_id)
@ -71,7 +71,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
assignment_type = AssignmentType.calculate_type(
user_id, group_id, project_id, domain_id)
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=assignment_type,
actor_id=user_id or group_id,
@ -85,7 +85,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def list_grant_role_ids(self, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
q = session.query(RoleAssignment.role_id)
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
try:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
@ -120,7 +120,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def delete_grant(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
inherited_to_projects)
@ -145,11 +145,11 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == inherited,
RoleAssignment.actor_id.in_(actors))
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id).filter(
sql_constraints).distinct()
return [x.target_id for x in query.all()]
return [x.target_id for x in query.all()]
def list_project_ids_for_user(self, user_id, group_ids, hints,
inherited=False):
@ -161,7 +161,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
def list_domain_ids_for_user(self, user_id, group_ids, hints,
inherited=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id)
filters = []
@ -197,10 +197,10 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == false(),
RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.role_id).filter(
sql_constraints).distinct()
return [role.role_id for role in query.all()]
return [role.role_id for role in query.all()]
def list_role_ids_for_groups_on_project(
self, group_ids, project_id, project_domain_id, project_parents):
@ -237,13 +237,13 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
sql_constraints = sqlalchemy.and_(
sql_constraints, RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session:
with sql.session_for_read() as session:
# NOTE(morganfainberg): Only select the columns we actually care
# about here, in this case role_id.
query = session.query(RoleAssignment.role_id).filter(
sql_constraints).distinct()
return [result.role_id for result in query.all()]
return [result.role_id for result in query.all()]
def list_project_ids_for_groups(self, group_ids, hints,
inherited=False):
@ -260,14 +260,14 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
RoleAssignment.inherited == inherited,
RoleAssignment.actor_id.in_(group_ids))
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleAssignment.target_id).filter(
group_sql_conditions).distinct()
return [x.target_id for x in query.all()]
return [x.target_id for x in query.all()]
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=AssignmentType.USER_PROJECT,
actor_id=user_id, target_id=tenant_id,
@ -278,7 +278,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
raise exception.Conflict(type='role grant', details=msg)
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id)
q = q.filter_by(target_id=tenant_id)
@ -368,7 +368,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
assignment['inherited_to_projects'] = 'projects'
return assignment
with sql.transaction() as session:
with sql.session_for_read() as session:
assignment_types = self._get_assignment_types(
user_id, group_ids, project_ids, domain_id)
@ -400,25 +400,25 @@ class Assignment(keystone_assignment.AssignmentDriverV8):
return [denormalize_role(ref) for ref in query.all()]
def delete_project_assignments(self, project_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(target_id=project_id)
q.delete(False)
def delete_role_assignments(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(role_id=role_id)
q.delete(False)
def delete_user_assignments(self, user_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id)
q.delete(False)
def delete_group_assignments(self, group_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=group_id)
q.delete(False)

@ -19,14 +19,14 @@ class Role(assignment.RoleDriverV8):
@sql.handle_conflicts(conflict_type='role')
def create_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = RoleTable.from_dict(role)
session.add(ref)
return ref.to_dict()
@sql.truncated
def list_roles(self, hints):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
refs = sql.filter_limit_query(RoleTable, query, hints)
return [ref.to_dict() for ref in refs]
@ -35,7 +35,7 @@ class Role(assignment.RoleDriverV8):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
query = query.filter(RoleTable.id.in_(ids))
role_refs = query.all()
@ -48,12 +48,12 @@ class Role(assignment.RoleDriverV8):
return ref
def get_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_role(session, role_id).to_dict()
@sql.handle_conflicts(conflict_type='role')
def update_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
old_dict = ref.to_dict()
for k in role:
@ -66,7 +66,7 @@ class Role(assignment.RoleDriverV8):
return ref.to_dict()
def delete_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
session.delete(ref)

@ -55,7 +55,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
assignment_type = AssignmentType.calculate_type(
user_id, group_id, project_id, domain_id)
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=assignment_type,
actor_id=user_id or group_id,
@ -69,7 +69,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def list_grant_role_ids(self, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
q = session.query(RoleAssignment.role_id)
q = q.filter(RoleAssignment.actor_id == (user_id or group_id))
q = q.filter(RoleAssignment.target_id == (project_id or domain_id))
@ -88,7 +88,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def check_grant_role_id(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_read() as session:
try:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def delete_grant(self, role_id, user_id=None, group_id=None,
domain_id=None, project_id=None,
inherited_to_projects=False):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = self._build_grant_filter(
session, role_id, user_id, group_id, domain_id, project_id,
inherited_to_projects)
@ -117,7 +117,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
def add_role_to_user_and_project(self, user_id, tenant_id, role_id):
try:
with sql.transaction() as session:
with sql.session_for_write() as session:
session.add(RoleAssignment(
type=AssignmentType.USER_PROJECT,
actor_id=user_id, target_id=tenant_id,
@ -128,7 +128,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
raise exception.Conflict(type='role grant', details=msg)
def remove_role_from_user_and_project(self, user_id, tenant_id, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id)
q = q.filter_by(target_id=tenant_id)
@ -218,7 +218,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
assignment['inherited_to_projects'] = 'projects'
return assignment
with sql.transaction() as session:
with sql.session_for_read() as session:
assignment_types = self._get_assignment_types(
user_id, group_ids, project_ids, domain_id)
@ -250,7 +250,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
return [denormalize_role(ref) for ref in query.all()]
def delete_project_assignments(self, project_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(target_id=project_id).filter(
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
@ -259,13 +259,13 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
q.delete(False)
def delete_role_assignments(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(role_id=role_id)
q.delete(False)
def delete_user_assignments(self, user_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id).filter(
RoleAssignment.type.in_((AssignmentType.USER_PROJECT,
@ -274,7 +274,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9):
q.delete(False)
def delete_group_assignments(self, group_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=group_id).filter(
RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT,

@ -29,7 +29,7 @@ class Role(assignment.RoleDriverV9):
@sql.handle_conflicts(conflict_type='role')
def create_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = RoleTable.from_dict(role)
session.add(ref)
return ref.to_dict()
@ -46,7 +46,7 @@ class Role(assignment.RoleDriverV9):
if (f['name'] == 'domain_id' and f['value'] is None):
f['value'] = NULL_DOMAIN_ID
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
refs = sql.filter_limit_query(RoleTable, query, hints)
return [ref.to_dict() for ref in refs]
@ -55,7 +55,7 @@ class Role(assignment.RoleDriverV9):
if not ids:
return []
else:
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(RoleTable)
query = query.filter(RoleTable.id.in_(ids))
role_refs = query.all()
@ -68,12 +68,12 @@ class Role(assignment.RoleDriverV9):
return ref
def get_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
return self._get_role(session, role_id).to_dict()
@sql.handle_conflicts(conflict_type='role')
def update_role(self, role_id, role):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
old_dict = ref.to_dict()
for k in role:
@ -86,7 +86,7 @@ class Role(assignment.RoleDriverV9):
return ref.to_dict()
def delete_role(self, role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_role(session, role_id)
session.delete(ref)
@ -105,7 +105,7 @@ class Role(assignment.RoleDriverV9):
@sql.handle_conflicts(conflict_type='implied_role')
def create_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
inference = {'prior_role_id': prior_role_id,
'implied_role_id': implied_role_id}
ref = ImpliedRoleTable.from_dict(inference)
@ -119,13 +119,13 @@ class Role(assignment.RoleDriverV9):
return ref.to_dict()
def delete_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
ref = self._get_implied_role(session, prior_role_id,
implied_role_id)
session.delete(ref)
def list_implied_roles(self, prior_role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(
ImpliedRoleTable).filter(
ImpliedRoleTable.prior_role_id == prior_role_id)
@ -133,13 +133,13 @@ class Role(assignment.RoleDriverV9):
return [ref.to_dict() for ref in refs]
def list_role_inference_rules(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(ImpliedRoleTable)
refs = query.all()
return [ref.to_dict() for ref in refs]
def get_implied_role(self, prior_role_id, implied_role_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = self._get_implied_role(session, prior_role_id,
implied_role_id)
return ref.to_dict()

@ -84,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase):
class Catalog(catalog.CatalogDriverV8):
# Regions
def list_regions(self, hints):
session = sql.get_session()
regions = session.query(Region)
regions = sql.filter_limit_query(Region, regions, hints)
return [s.to_dict() for s in list(regions)]
with sql.session_for_read() as session:
regions = session.query(Region)
regions = sql.filter_limit_query(Region, regions, hints)
return [s.to_dict() for s in list(regions)]
def _get_region(self, session, region_id):
ref = session.query(Region).get(region_id)
@ -136,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8):
return False
def get_region(self, region_id):
session = sql.get_session()
return self._get_region(session, region_id).to_dict()
with sql.session_for_read() as session:
return self._get_region(session, region_id).to_dict()
def delete_region(self, region_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_region(session, region_id)
if self._has_endpoints(session, ref, ref):
raise exception.RegionDeletionError(region_id=region_id)
@ -150,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8):
@sql.handle_conflicts(conflict_type='region')
def create_region(self, region_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
self._check_parent_region(session, region_ref)
region = Region.from_dict(region_ref)
session.add(region)
return region.to_dict()
return region.to_dict()
def update_region(self, region_id, region_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
self._check_parent_region(session, region_ref)
ref = self._get_region(session, region_id)
old_dict = ref.to_dict()
@ -169,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8):
for attr in Region.attributes:
if attr != 'id':
setattr(ref, attr, getattr(new_region, attr))
return ref.to_dict()
return ref.to_dict()
# Services
@driver_hints.truncated
def list_services(self, hints):
session = sql.get_session()
services = session.query(Service)
services = sql.filter_limit_query(Service, services, hints)
return [s.to_dict() for s in list(services)]
with sql.session_for_read() as session:
services = session.query(Service)
services = sql.filter_limit_query(Service, services, hints)
return [s.to_dict() for s in list(services)]
def _get_service(self, session, service_id):
ref = session.query(Service).get(service_id)
@ -186,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8):
return ref
def get_service(self, service_id):
session = sql.get_session()
return self._get_service(session, service_id).to_dict()
with sql.session_for_read() as session:
return self._get_service(session, service_id).to_dict()
def delete_service(self, service_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_service(session, service_id)
session.query(Endpoint).filter_by(service_id=service_id).delete()
session.delete(ref)
def create_service(self, service_id, service_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
service = Service.from_dict(service_ref)
session.add(service)
return service.to_dict()
return service.to_dict()
def update_service(self, service_id, service_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_service(session, service_id)
old_dict = ref.to_dict()
old_dict.update(service_ref)
@ -214,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_service, attr))
ref.extra = new_service.extra
return ref.to_dict()
return ref.to_dict()
# Endpoints
def create_endpoint(self, endpoint_id, endpoint_ref):
session = sql.get_session()
new_endpoint = Endpoint.from_dict(endpoint_ref)
with session.begin():
with sql.session_for_write() as session:
session.add(new_endpoint)
return new_endpoint.to_dict()
def delete_endpoint(self, endpoint_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_endpoint(session, endpoint_id)
session.delete(ref)
@ -238,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8):
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
def get_endpoint(self, endpoint_id):
session = sql.get_session()
return self._get_endpoint(session, endpoint_id).to_dict()
with sql.session_for_read() as session:
return self._get_endpoint(session, endpoint_id).to_dict()
@driver_hints.truncated
def list_endpoints(self, hints):
session = sql.get_session()
endpoints = session.query(Endpoint)
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
return [e.to_dict() for e in list(endpoints)]
with sql.session_for_read() as session:
endpoints = session.query(Endpoint)
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
return [e.to_dict() for e in list(endpoints)]
def update_endpoint(self, endpoint_id, endpoint_ref):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_endpoint(session, endpoint_id)
old_dict = ref.to_dict()
old_dict.update(endpoint_ref)
@ -260,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_endpoint, attr))
ref.extra = new_endpoint.extra
return ref.to_dict()
return ref.to_dict()
def get_catalog(self, user_id, tenant_id):
"""Retrieve and format the V2 service catalog.
@ -289,40 +278,40 @@ class Catalog(catalog.CatalogDriverV8):
else:
silent_keyerror_failures = ['tenant_id', 'project_id', ]
session = sql.get_session()
endpoints = (session.query(Endpoint).
options(sql.joinedload(Endpoint.service)).
filter(Endpoint.enabled == true()).all())
catalog = {}
for endpoint in endpoints:
if not endpoint.service['enabled']:
continue
try:
formatted_url = core.format_url(
endpoint['url'], substitutions,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url is not None:
url = formatted_url
else:
with sql.session_for_read() as session:
endpoints = (session.query(Endpoint).
options(sql.joinedload(Endpoint.service)).
filter(Endpoint.enabled == true()).all())
catalog = {}
for endpoint in endpoints:
if not endpoint.service['enabled']:
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
region = endpoint['region_id']
service_type = endpoint.service['type']
default_service = {
'id': endpoint['id'],
'name': endpoint.service.extra.get('name', ''),
'publicURL': ''
}
catalog.setdefault(region, {})
catalog[region].setdefault(service_type, default_service)
interface_url = '%sURL' % endpoint['interface']
catalog[region][service_type][interface_url] = url
return catalog
try:
formatted_url = core.format_url(
endpoint['url'], substitutions,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url is not None:
url = formatted_url
else:
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
region = endpoint['region_id']
service_type = endpoint.service['type']
default_service = {
'id': endpoint['id'],
'name': endpoint.service.extra.get('name', ''),
'publicURL': ''
}
catalog.setdefault(region, {})
catalog[region].setdefault(service_type, default_service)
interface_url = '%sURL' % endpoint['interface']
catalog[region][service_type][interface_url] = url
return catalog
def get_v3_catalog(self, user_id, tenant_id):
"""Retrieve and format the current V3 service catalog.
@ -349,44 +338,46 @@ class Catalog(catalog.CatalogDriverV8):
else:
silent_keyerror_failures = ['tenant_id', 'project_id', ]
session = sql.get_session()
services = (session.query(Service).filter(Service.enabled == true()).
options(sql.joinedload(Service.endpoints)).
all())
def make_v3_endpoints(endpoints):
for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled):
del endpoint['service_id']
del endpoint['legacy_endpoint_id']
del endpoint['enabled']
endpoint['region'] = endpoint['region_id']
try:
formatted_url = core.format_url(
endpoint['url'], d,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url:
endpoint['url'] = formatted_url
else:
with sql.session_for_read() as session:
services = (session.query(Service).filter(
Service.enabled == true()).options(
sql.joinedload(Service.endpoints)).all())
def make_v3_endpoints(endpoints):
for endpoint in (ep.to_dict()
for ep in endpoints if ep.enabled):
del endpoint['service_id']
del endpoint['legacy_endpoint_id']
del endpoint['enabled']
endpoint['region'] = endpoint['region_id']
try:
formatted_url = core.format_url(
endpoint['url'], d,
silent_keyerror_failures=silent_keyerror_failures)
if formatted_url:
endpoint['url'] = formatted_url
else:
continue
except exception.MalformedEndpoint:
# this failure is already logged in format_url()
continue
except exception.MalformedEndpoint:
continue # this failure is already logged in format_url()
yield endpoint
yield endpoint
# TODO(davechen): If there is service with no endpoints, we should skip
# the service instead of keeping it in the catalog, see bug #1436704.
def make_v3_service(svc):
eps = list(make_v3_endpoints(svc.endpoints))
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
service['name'] = svc.extra.get('name', '')
return service
# TODO(davechen): If there is service with no endpoints, we should
# skip the service instead of keeping it in the catalog,
# see bug #1436704.
def make_v3_service(svc):
eps = list(make_v3_endpoints(svc.endpoints))
service = {'endpoints': eps, 'id': svc.id, 'type': svc.type}
service['name'] = svc.extra.get('name', '')
return service
return [make_v3_service(svc) for svc in services]
return [make_v3_service(svc) for svc in services]
@sql.handle_conflicts(conflict_type='project_endpoint')
def add_endpoint_to_project(self, endpoint_id, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
project_id=project_id)
session.add(endpoint_filter_ref)
@ -402,50 +393,46 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_filter_ref
def check_endpoint_in_project(self, endpoint_id, project_id):
session = sql.get_session()
self._get_project_endpoint_ref(session, endpoint_id, project_id)
with sql.session_for_read() as session:
self._get_project_endpoint_ref(session, endpoint_id, project_id)
def remove_endpoint_from_project(self, endpoint_id, project_id):
session = sql.get_session()
endpoint_filter_ref = self._get_project_endpoint_ref(
session, endpoint_id, project_id)
with session.begin():
with sql.session_for_write() as session:
endpoint_filter_ref = self._get_project_endpoint_ref(
session, endpoint_id, project_id)
session.delete(endpoint_filter_ref)
def list_endpoints_for_project(self, project_id):
session = sql.get_session()
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
def list_projects_for_endpoint(self, endpoint_id):
session = sql.get_session()
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
endpoint_filter_refs = query.all()
return [ref.to_dict() for ref in endpoint_filter_refs]
def delete_association_by_endpoint(self, endpoint_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
query.delete(synchronize_session=False)
def delete_association_by_project(self, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
query.delete(synchronize_session=False)
def create_endpoint_group(self, endpoint_group_id, endpoint_group):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
endpoint_group_ref = EndpointGroup.from_dict(endpoint_group)
session.add(endpoint_group_ref)
return endpoint_group_ref.to_dict()
return endpoint_group_ref.to_dict()
def _get_endpoint_group(self, session, endpoint_group_id):
endpoint_group_ref = session.query(EndpointGroup).get(
@ -456,14 +443,13 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_group_ref
def get_endpoint_group(self, endpoint_group_id):
session = sql.get_session()
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
return endpoint_group_ref.to_dict()
with sql.session_for_read() as session:
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
return endpoint_group_ref.to_dict()
def update_endpoint_group(self, endpoint_group_id, endpoint_group):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
old_endpoint_group = endpoint_group_ref.to_dict()
@ -472,29 +458,26 @@ class Catalog(catalog.CatalogDriverV8):
for attr in EndpointGroup.mutable_attributes:
setattr(endpoint_group_ref, attr,
getattr(new_endpoint_group, attr))
return endpoint_group_ref.to_dict()
return endpoint_group_ref.to_dict()
def delete_endpoint_group(self, endpoint_group_id):
session = sql.get_session()
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
with session.begin():
with sql.session_for_write() as session:
endpoint_group_ref = self._get_endpoint_group(session,
endpoint_group_id)
self._delete_endpoint_group_association_by_endpoint_group(
session, endpoint_group_id)
session.delete(endpoint_group_ref)
def get_endpoint_group_in_project(self, endpoint_group_id, project_id):
session = sql.get_session()
ref = self._get_endpoint_group_in_project(session,
endpoint_group_id,
project_id)
return ref.to_dict()
with sql.session_for_read() as session:
ref = self._get_endpoint_group_in_project(session,
endpoint_group_id,
project_id)
return ref.to_dict()
@sql.handle_conflicts(conflict_type='project_endpoint_group')
def add_endpoint_group_to_project(self, endpoint_group_id, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
# Create a new Project Endpoint group entity
endpoint_group_project_ref = ProjectEndpointGroupMembership(
endpoint_group_id=endpoint_group_id, project_id=project_id)
@ -512,32 +495,31 @@ class Catalog(catalog.CatalogDriverV8):
return endpoint_group_project_ref
def list_endpoint_groups(self):
session = sql.get_session()
query = session.query(EndpointGroup)
endpoint_group_refs = query.all()
return [e.to_dict() for e in endpoint_group_refs]
with sql.session_for_read() as session:
query = session.query(EndpointGroup)
endpoint_group_refs = query.all()
return [e.to_dict() for e in endpoint_group_refs]
def list_endpoint_groups_for_project(self, project_id):
session = sql.get_session()
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
def remove_endpoint_group_from_project(self, endpoint_group_id,
project_id):
session = sql.get_session()
endpoint_group_project_ref = self._get_endpoint_group_in_project(
session, endpoint_group_id, project_id)
with session.begin():
with sql.session_for_write() as session:
endpoint_group_project_ref = self._get_endpoint_group_in_project(
session, endpoint_group_id, project_id)
session.delete(endpoint_group_project_ref)
def list_projects_associated_with_endpoint_group(self, endpoint_group_id):
session = sql.get_session()
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(endpoint_group_id=endpoint_group_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
with sql.session_for_read() as session:
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(endpoint_group_id=endpoint_group_id)
endpoint_group_refs = query.all()
return [ref.to_dict() for ref in endpoint_group_refs]
def _delete_endpoint_group_association_by_endpoint_group(
self, session, endpoint_group_id):
@ -546,8 +528,7 @@ class Catalog(catalog.CatalogDriverV8):
query.delete()
def delete_endpoint_group_association_by_project(self, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(ProjectEndpointGroupMembership)
query = query.filter_by(project_id=project_id)
query.delete()

@ -18,14 +18,14 @@ Before using this module, call initialize(). This has to be done before
CONF() because it sets up configuration options.
"""
import contextlib
import functools
import threading
from oslo_config import cfg
from oslo_db import exception as db_exception
from oslo_db import options as db_options
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import models
from oslo_db.sqlalchemy import session as db_session
from oslo_log import log
from oslo_serialization import jsonutils
import six
@ -166,38 +166,41 @@ class ModelDictMixin(object):
return {name: getattr(self, name) for name in names}
_engine_facade = None
_main_context_manager = None
def _get_engine_facade():
global _engine_facade
def _get_main_context_manager():
global _main_context_manager
if not _engine_facade:
_engine_facade = db_session.EngineFacade.from_config(CONF)
if not _main_context_manager:
_main_context_manager = enginefacade.transaction_context()
return _engine_facade
return _main_context_manager
def cleanup():
global _engine_facade
global _main_context_manager
_engine_facade = None
_main_context_manager = None
def get_engine():
return _get_engine_facade().get_engine()
return _get_main_context_manager().get_legacy_facade().get_engine()
def get_session(expire_on_commit=False):
return _get_engine_facade().get_session(expire_on_commit=expire_on_commit)
def get_session():
return _get_main_context_manager().get_legacy_facade().get_session()
@contextlib.contextmanager
def transaction(expire_on_commit=False):
"""Return a SQLAlchemy session in a scoped transaction."""
session = get_session(expire_on_commit=expire_on_commit)
with session.begin():
yield session
_CONTEXT = threading.local()
def session_for_read():
return _get_main_context_manager().reader.using(_CONTEXT)
def session_for_write():
return _get_main_context_manager().writer.using(_CONTEXT)
def truncated(f):

@ -178,7 +178,7 @@ def _sync_extension_repo(extension, version):
try:
abs_path = find_migrate_repo(package)
try:
migration.db_version_control(sql.get_engine(), abs_path)
migration.db_version_control(engine, abs_path)
# Register the repo with the version control API
# If it already knows about the repo, it will throw
# an exception that we can safely ignore

@ -36,28 +36,27 @@ class Credential(credential.CredentialDriverV8):
@sql.handle_conflicts(conflict_type='credential')
def create_credential(self, credential_id, credential):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = CredentialModel.from_dict(credential)
session.add(ref)
return ref.to_dict()
return ref.to_dict()
@driver_hints.truncated
def list_credentials(self, hints):
session = sql.get_session()
credentials = session.query(CredentialModel)
credentials = sql.filter_limit_query(CredentialModel,
credentials, hints)
return [s.to_dict() for s in credentials]
with sql.session_for_read() as session:
credentials = session.query(CredentialModel)
credentials = sql.filter_limit_query(CredentialModel,
credentials, hints)
return [s.to_dict() for s in credentials]
def list_credentials_for_user(self, user_id, type=None):
session = sql.get_session()
query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id)
if type:
query = query.filter_by(type=type)
refs = query.all()
return [ref.to_dict() for ref in refs]
with sql.session_for_read() as session:
query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id)
if type:
query = query.filter_by(type=type)
refs = query.all()
return [ref.to_dict() for ref in refs]
def _get_credential(self, session, credential_id):
ref = session.query(CredentialModel).get(credential_id)
@ -66,13 +65,12 @@ class Credential(credential.CredentialDriverV8):
return ref
def get_credential(self, credential_id):
session = sql.get_session()
return self._get_credential(session, credential_id).to_dict()
with sql.session_for_read() as session:
return self._get_credential(session, credential_id).to_dict()
@sql.handle_conflicts(conflict_type='credential')
def update_credential(self, credential_id, credential):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_credential(session, credential_id)
old_dict = ref.to_dict()
for k in credential:
@ -82,27 +80,21 @@ class Credential(credential.CredentialDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_credential, attr))
ref.extra = new_credential.extra
return ref.to_dict()
return ref.to_dict()
def delete_credential(self, credential_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
ref = self._get_credential(session, credential_id)
session.delete(ref)
def delete_credentials_for_project(self, project_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(CredentialModel)
query = query.filter_by(project_id=project_id)
query.delete()
def delete_credentials_for_user(self, user_id):
session = sql.get_session()
with session.begin():
with sql.session_for_write() as session:
query = session.query(CredentialModel)
query = query.filter_by(user_id=user_id)
query.delete()

@ -51,7 +51,7 @@ class EndpointPolicy(object):
def create_policy_association(self, policy_id, endpoint_id=None,
service_id=None, region_id=None):
with sql.transaction() as session:
with sql.session_for_write() as session:
try:
# See if there is already a row for this association, and if
# so, update it with the new policy_id
@ -79,14 +79,14 @@ class EndpointPolicy(object):
# NOTE(henry-nash): Getting a single value to save object
# management overhead.
with sql.transaction() as session:
with sql.session_for_read() as session:
if session.query(PolicyAssociation.id).filter(
sql_constraints).distinct().count() == 0:
raise exception.PolicyAssociationNotFound()
def delete_policy_association(self, policy_id, endpoint_id=None,
service_id=None, region_id=None):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id)
query = query.filter_by(endpoint_id=endpoint_id)
@ -102,7 +102,7 @@ class EndpointPolicy(object):
PolicyAssociation.region_id == region_id)
try:
with sql.transaction() as session:
with sql.session_for_read() as session:
policy_id = session.query(PolicyAssociation.policy_id).filter(
sql_constraints).distinct().one()
return {'policy_id': policy_id}
@ -110,31 +110,31 @@ class EndpointPolicy(object):
raise exception.PolicyAssociationNotFound()
def list_associations_for_policy(self, policy_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id)
return [ref.to_dict() for ref in query.all()]
def delete_association_by_endpoint(self, endpoint_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(endpoint_id=endpoint_id)
query.delete()
def delete_association_by_service(self, service_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(service_id=service_id)
query.delete()
def delete_association_by_region(self, region_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(region_id=region_id)
query.delete()
def delete_association_by_policy(self, policy_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
query = session.query(PolicyAssociation)
query = query.filter_by(policy_id=policy_id)
query.delete()

@ -161,13 +161,13 @@ class Federation(core.FederationDriverV8):
@sql.handle_conflicts(conflict_type='identity_provider')
def create_idp(self, idp_id, idp):
idp['id'] = idp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
idp_ref = IdentityProviderModel.from_dict(idp)
session.add(idp_ref)
return idp_ref.to_dict()
return idp_ref.to_dict()
def delete_idp(self, idp_id):
with sql.transaction() as session:
with sql.session_for_write() as session:
self._delete_assigned_protocols(session, idp_id)
idp_ref = self._get_idp(session, idp_id)
session.delete(idp_ref)
@ -187,30 +187,30 @@ class Federation(core.FederationDriverV8):
raise exception.IdentityProviderNotFound(idp_id=remote_id)
def list_idps(self):
with sql.transaction() as session:
with sql.session_for_read() as session:
idps = session.query(IdentityProviderModel)
idps_list = [idp.to_dict() for idp in idps]
return idps_list
idps_list = [idp.to_dict() for idp in idps]
return idps_list
def get_idp(self, idp_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
idp_ref = self._get_idp(session, idp_id)
return idp_ref.to_dict()
return idp_ref.to_dict()
def get_idp_from_remote_id(self, remote_id):
with sql.transaction() as session:
with sql.session_for_read() as session:
ref = self._get_idp_from_remote_id(session, remote_id)
return ref.to_dict()
return ref.to_dict()
def update_idp(self, idp_id, idp):
with sql.transaction() as session:
with sql.session_for_write() as session:
idp_ref = self._get_idp(session, idp_id)
old_idp = idp_ref.to_dict()
old_idp.update(idp)
new_idp = IdentityProviderModel.from_dict(old_idp)
for attr in IdentityProviderModel.mutable_attributes:
setattr(idp_ref, attr, getattr(new_idp, attr))
return idp_ref.to_dict()
return idp_ref.to_dict()
# Protocol CRUD
def _get_protocol(self, session, idp_id, protocol_id):
@ -227,36 +227,36 @@ class Federation(core.FederationDriverV8):
def create_protocol(self, idp_id, protocol_id, protocol):
protocol['id'] = protocol_id
protocol['idp_id'] = idp_id
with sql.transaction() as session:
with sql.session_for_write() as session:
self._get_idp(session, idp_id)
protocol_ref = FederationProtocolModel.from_dict(protocol)
session.add(protocol_ref)
return protocol_ref.to_dict()
return protocol_ref.to_dict()
def update_protocol(self, idp_id, protocol_id, protocol):
with sql.transaction() as session:
with sql.session_for_write() as session:
proto_ref = self._get_protocol(session, idp_id, protocol_id)
old_proto = proto_ref.to_dict()
old_proto.update(protocol)
new_proto = FederationProtocolModel.from_dict(old_proto)
for attr in FederationProtocolModel.mutable_attributes:
setattr(proto_ref, attr, getattr(new_proto, attr))
return proto_ref.to_dict()
return proto_ref.to_dict()
def get_protocol(self, idp_id, protocol_id):