diff --git a/keystone/assignment/V8_backends/sql.py b/keystone/assignment/V8_backends/sql.py index 6ae563766e..88c10a6a9a 100644 --- a/keystone/assignment/V8_backends/sql.py +++ b/keystone/assignment/V8_backends/sql.py @@ -56,7 +56,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): return 'sql' def list_user_ids_for_project(self, tenant_id): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleAssignment.actor_id) query = query.filter_by(type=AssignmentType.USER_PROJECT) query = query.filter_by(target_id=tenant_id) @@ -71,7 +71,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): assignment_type = AssignmentType.calculate_type( user_id, group_id, project_id, domain_id) try: - with sql.transaction() as session: + with sql.session_for_write() as session: session.add(RoleAssignment( type=assignment_type, actor_id=user_id or group_id, @@ -85,7 +85,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): def list_grant_role_ids(self, user_id=None, group_id=None, domain_id=None, project_id=None, inherited_to_projects=False): - with sql.transaction() as session: + with sql.session_for_read() as session: q = session.query(RoleAssignment.role_id) q = q.filter(RoleAssignment.actor_id == (user_id or group_id)) q = q.filter(RoleAssignment.target_id == (project_id or domain_id)) @@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): def check_grant_role_id(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None, inherited_to_projects=False): - with sql.transaction() as session: + with sql.session_for_read() as session: try: q = self._build_grant_filter( session, role_id, user_id, group_id, domain_id, project_id, @@ -120,7 +120,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): def delete_grant(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None, inherited_to_projects=False): - with sql.transaction() as session: + with sql.session_for_write() as session: q = self._build_grant_filter( session, role_id, user_id, group_id, domain_id, project_id, inherited_to_projects) @@ -145,11 +145,11 @@ class Assignment(keystone_assignment.AssignmentDriverV8): RoleAssignment.inherited == inherited, RoleAssignment.actor_id.in_(actors)) - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleAssignment.target_id).filter( sql_constraints).distinct() - return [x.target_id for x in query.all()] + return [x.target_id for x in query.all()] def list_project_ids_for_user(self, user_id, group_ids, hints, inherited=False): @@ -161,7 +161,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): def list_domain_ids_for_user(self, user_id, group_ids, hints, inherited=False): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleAssignment.target_id) filters = [] @@ -197,10 +197,10 @@ class Assignment(keystone_assignment.AssignmentDriverV8): RoleAssignment.inherited == false(), RoleAssignment.actor_id.in_(group_ids)) - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleAssignment.role_id).filter( sql_constraints).distinct() - return [role.role_id for role in query.all()] + return [role.role_id for role in query.all()] def list_role_ids_for_groups_on_project( self, group_ids, project_id, project_domain_id, project_parents): @@ -237,13 +237,13 @@ class Assignment(keystone_assignment.AssignmentDriverV8): sql_constraints = sqlalchemy.and_( sql_constraints, RoleAssignment.actor_id.in_(group_ids)) - with sql.transaction() as session: + with sql.session_for_read() as session: # NOTE(morganfainberg): Only select the columns we actually care # about here, in this case role_id. query = session.query(RoleAssignment.role_id).filter( sql_constraints).distinct() - return [result.role_id for result in query.all()] + return [result.role_id for result in query.all()] def list_project_ids_for_groups(self, group_ids, hints, inherited=False): @@ -260,14 +260,14 @@ class Assignment(keystone_assignment.AssignmentDriverV8): RoleAssignment.inherited == inherited, RoleAssignment.actor_id.in_(group_ids)) - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleAssignment.target_id).filter( group_sql_conditions).distinct() - return [x.target_id for x in query.all()] + return [x.target_id for x in query.all()] def add_role_to_user_and_project(self, user_id, tenant_id, role_id): try: - with sql.transaction() as session: + with sql.session_for_write() as session: session.add(RoleAssignment( type=AssignmentType.USER_PROJECT, actor_id=user_id, target_id=tenant_id, @@ -278,7 +278,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): raise exception.Conflict(type='role grant', details=msg) def remove_role_from_user_and_project(self, user_id, tenant_id, role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(actor_id=user_id) q = q.filter_by(target_id=tenant_id) @@ -368,7 +368,7 @@ class Assignment(keystone_assignment.AssignmentDriverV8): assignment['inherited_to_projects'] = 'projects' return assignment - with sql.transaction() as session: + with sql.session_for_read() as session: assignment_types = self._get_assignment_types( user_id, group_ids, project_ids, domain_id) @@ -400,25 +400,25 @@ class Assignment(keystone_assignment.AssignmentDriverV8): return [denormalize_role(ref) for ref in query.all()] def delete_project_assignments(self, project_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(target_id=project_id) q.delete(False) def delete_role_assignments(self, role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(role_id=role_id) q.delete(False) def delete_user_assignments(self, user_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(actor_id=user_id) q.delete(False) def delete_group_assignments(self, group_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(actor_id=group_id) q.delete(False) diff --git a/keystone/assignment/V8_role_backends/sql.py b/keystone/assignment/V8_role_backends/sql.py index b6533bdee5..2e2e119a18 100644 --- a/keystone/assignment/V8_role_backends/sql.py +++ b/keystone/assignment/V8_role_backends/sql.py @@ -19,14 +19,14 @@ class Role(assignment.RoleDriverV8): @sql.handle_conflicts(conflict_type='role') def create_role(self, role_id, role): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = RoleTable.from_dict(role) session.add(ref) return ref.to_dict() @sql.truncated def list_roles(self, hints): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleTable) refs = sql.filter_limit_query(RoleTable, query, hints) return [ref.to_dict() for ref in refs] @@ -35,7 +35,7 @@ class Role(assignment.RoleDriverV8): if not ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleTable) query = query.filter(RoleTable.id.in_(ids)) role_refs = query.all() @@ -48,12 +48,12 @@ class Role(assignment.RoleDriverV8): return ref def get_role(self, role_id): - with sql.transaction() as session: + with sql.session_for_read() as session: return self._get_role(session, role_id).to_dict() @sql.handle_conflicts(conflict_type='role') def update_role(self, role_id, role): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_role(session, role_id) old_dict = ref.to_dict() for k in role: @@ -66,7 +66,7 @@ class Role(assignment.RoleDriverV8): return ref.to_dict() def delete_role(self, role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_role(session, role_id) session.delete(ref) diff --git a/keystone/assignment/backends/sql.py b/keystone/assignment/backends/sql.py index a61ea9e5aa..f69fb10956 100644 --- a/keystone/assignment/backends/sql.py +++ b/keystone/assignment/backends/sql.py @@ -55,7 +55,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): assignment_type = AssignmentType.calculate_type( user_id, group_id, project_id, domain_id) try: - with sql.transaction() as session: + with sql.session_for_write() as session: session.add(RoleAssignment( type=assignment_type, actor_id=user_id or group_id, @@ -69,7 +69,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): def list_grant_role_ids(self, user_id=None, group_id=None, domain_id=None, project_id=None, inherited_to_projects=False): - with sql.transaction() as session: + with sql.session_for_read() as session: q = session.query(RoleAssignment.role_id) q = q.filter(RoleAssignment.actor_id == (user_id or group_id)) q = q.filter(RoleAssignment.target_id == (project_id or domain_id)) @@ -88,7 +88,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): def check_grant_role_id(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None, inherited_to_projects=False): - with sql.transaction() as session: + with sql.session_for_read() as session: try: q = self._build_grant_filter( session, role_id, user_id, group_id, domain_id, project_id, @@ -104,7 +104,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): def delete_grant(self, role_id, user_id=None, group_id=None, domain_id=None, project_id=None, inherited_to_projects=False): - with sql.transaction() as session: + with sql.session_for_write() as session: q = self._build_grant_filter( session, role_id, user_id, group_id, domain_id, project_id, inherited_to_projects) @@ -117,7 +117,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): def add_role_to_user_and_project(self, user_id, tenant_id, role_id): try: - with sql.transaction() as session: + with sql.session_for_write() as session: session.add(RoleAssignment( type=AssignmentType.USER_PROJECT, actor_id=user_id, target_id=tenant_id, @@ -128,7 +128,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): raise exception.Conflict(type='role grant', details=msg) def remove_role_from_user_and_project(self, user_id, tenant_id, role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(actor_id=user_id) q = q.filter_by(target_id=tenant_id) @@ -218,7 +218,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): assignment['inherited_to_projects'] = 'projects' return assignment - with sql.transaction() as session: + with sql.session_for_read() as session: assignment_types = self._get_assignment_types( user_id, group_ids, project_ids, domain_id) @@ -250,7 +250,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): return [denormalize_role(ref) for ref in query.all()] def delete_project_assignments(self, project_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(target_id=project_id).filter( RoleAssignment.type.in_((AssignmentType.USER_PROJECT, @@ -259,13 +259,13 @@ class Assignment(keystone_assignment.AssignmentDriverV9): q.delete(False) def delete_role_assignments(self, role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(role_id=role_id) q.delete(False) def delete_user_assignments(self, user_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(actor_id=user_id).filter( RoleAssignment.type.in_((AssignmentType.USER_PROJECT, @@ -274,7 +274,7 @@ class Assignment(keystone_assignment.AssignmentDriverV9): q.delete(False) def delete_group_assignments(self, group_id): - with sql.transaction() as session: + with sql.session_for_write() as session: q = session.query(RoleAssignment) q = q.filter_by(actor_id=group_id).filter( RoleAssignment.type.in_((AssignmentType.GROUP_PROJECT, diff --git a/keystone/assignment/role_backends/sql.py b/keystone/assignment/role_backends/sql.py index 050bf5bec5..1045f23a51 100644 --- a/keystone/assignment/role_backends/sql.py +++ b/keystone/assignment/role_backends/sql.py @@ -29,7 +29,7 @@ class Role(assignment.RoleDriverV9): @sql.handle_conflicts(conflict_type='role') def create_role(self, role_id, role): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = RoleTable.from_dict(role) session.add(ref) return ref.to_dict() @@ -46,7 +46,7 @@ class Role(assignment.RoleDriverV9): if (f['name'] == 'domain_id' and f['value'] is None): f['value'] = NULL_DOMAIN_ID - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleTable) refs = sql.filter_limit_query(RoleTable, query, hints) return [ref.to_dict() for ref in refs] @@ -55,7 +55,7 @@ class Role(assignment.RoleDriverV9): if not ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(RoleTable) query = query.filter(RoleTable.id.in_(ids)) role_refs = query.all() @@ -68,12 +68,12 @@ class Role(assignment.RoleDriverV9): return ref def get_role(self, role_id): - with sql.transaction() as session: + with sql.session_for_read() as session: return self._get_role(session, role_id).to_dict() @sql.handle_conflicts(conflict_type='role') def update_role(self, role_id, role): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_role(session, role_id) old_dict = ref.to_dict() for k in role: @@ -86,7 +86,7 @@ class Role(assignment.RoleDriverV9): return ref.to_dict() def delete_role(self, role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_role(session, role_id) session.delete(ref) @@ -105,7 +105,7 @@ class Role(assignment.RoleDriverV9): @sql.handle_conflicts(conflict_type='implied_role') def create_implied_role(self, prior_role_id, implied_role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: inference = {'prior_role_id': prior_role_id, 'implied_role_id': implied_role_id} ref = ImpliedRoleTable.from_dict(inference) @@ -119,13 +119,13 @@ class Role(assignment.RoleDriverV9): return ref.to_dict() def delete_implied_role(self, prior_role_id, implied_role_id): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_implied_role(session, prior_role_id, implied_role_id) session.delete(ref) def list_implied_roles(self, prior_role_id): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query( ImpliedRoleTable).filter( ImpliedRoleTable.prior_role_id == prior_role_id) @@ -133,13 +133,13 @@ class Role(assignment.RoleDriverV9): return [ref.to_dict() for ref in refs] def list_role_inference_rules(self): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(ImpliedRoleTable) refs = query.all() return [ref.to_dict() for ref in refs] def get_implied_role(self, prior_role_id, implied_role_id): - with sql.transaction() as session: + with sql.session_for_read() as session: ref = self._get_implied_role(session, prior_role_id, implied_role_id) return ref.to_dict() diff --git a/keystone/catalog/backends/sql.py b/keystone/catalog/backends/sql.py index c923b1302e..bd92f10794 100644 --- a/keystone/catalog/backends/sql.py +++ b/keystone/catalog/backends/sql.py @@ -84,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase): class Catalog(catalog.CatalogDriverV8): # Regions def list_regions(self, hints): - session = sql.get_session() - regions = session.query(Region) - regions = sql.filter_limit_query(Region, regions, hints) - return [s.to_dict() for s in list(regions)] + with sql.session_for_read() as session: + regions = session.query(Region) + regions = sql.filter_limit_query(Region, regions, hints) + return [s.to_dict() for s in list(regions)] def _get_region(self, session, region_id): ref = session.query(Region).get(region_id) @@ -136,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8): return False def get_region(self, region_id): - session = sql.get_session() - return self._get_region(session, region_id).to_dict() + with sql.session_for_read() as session: + return self._get_region(session, region_id).to_dict() def delete_region(self, region_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_region(session, region_id) if self._has_endpoints(session, ref, ref): raise exception.RegionDeletionError(region_id=region_id) @@ -150,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8): @sql.handle_conflicts(conflict_type='region') def create_region(self, region_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: self._check_parent_region(session, region_ref) region = Region.from_dict(region_ref) session.add(region) - return region.to_dict() + return region.to_dict() def update_region(self, region_id, region_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: self._check_parent_region(session, region_ref) ref = self._get_region(session, region_id) old_dict = ref.to_dict() @@ -169,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8): for attr in Region.attributes: if attr != 'id': setattr(ref, attr, getattr(new_region, attr)) - return ref.to_dict() + return ref.to_dict() # Services @driver_hints.truncated def list_services(self, hints): - session = sql.get_session() - services = session.query(Service) - services = sql.filter_limit_query(Service, services, hints) - return [s.to_dict() for s in list(services)] + with sql.session_for_read() as session: + services = session.query(Service) + services = sql.filter_limit_query(Service, services, hints) + return [s.to_dict() for s in list(services)] def _get_service(self, session, service_id): ref = session.query(Service).get(service_id) @@ -186,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8): return ref def get_service(self, service_id): - session = sql.get_session() - return self._get_service(session, service_id).to_dict() + with sql.session_for_read() as session: + return self._get_service(session, service_id).to_dict() def delete_service(self, service_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_service(session, service_id) session.query(Endpoint).filter_by(service_id=service_id).delete() session.delete(ref) def create_service(self, service_id, service_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: service = Service.from_dict(service_ref) session.add(service) - return service.to_dict() + return service.to_dict() def update_service(self, service_id, service_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_service(session, service_id) old_dict = ref.to_dict() old_dict.update(service_ref) @@ -214,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_service, attr)) ref.extra = new_service.extra - return ref.to_dict() + return ref.to_dict() # Endpoints def create_endpoint(self, endpoint_id, endpoint_ref): - session = sql.get_session() new_endpoint = Endpoint.from_dict(endpoint_ref) - - with session.begin(): + with sql.session_for_write() as session: session.add(new_endpoint) return new_endpoint.to_dict() def delete_endpoint(self, endpoint_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_endpoint(session, endpoint_id) session.delete(ref) @@ -238,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8): raise exception.EndpointNotFound(endpoint_id=endpoint_id) def get_endpoint(self, endpoint_id): - session = sql.get_session() - return self._get_endpoint(session, endpoint_id).to_dict() + with sql.session_for_read() as session: + return self._get_endpoint(session, endpoint_id).to_dict() @driver_hints.truncated def list_endpoints(self, hints): - session = sql.get_session() - endpoints = session.query(Endpoint) - endpoints = sql.filter_limit_query(Endpoint, endpoints, hints) - return [e.to_dict() for e in list(endpoints)] + with sql.session_for_read() as session: + endpoints = session.query(Endpoint) + endpoints = sql.filter_limit_query(Endpoint, endpoints, hints) + return [e.to_dict() for e in list(endpoints)] def update_endpoint(self, endpoint_id, endpoint_ref): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_endpoint(session, endpoint_id) old_dict = ref.to_dict() old_dict.update(endpoint_ref) @@ -260,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_endpoint, attr)) ref.extra = new_endpoint.extra - return ref.to_dict() + return ref.to_dict() def get_catalog(self, user_id, tenant_id): """Retrieve and format the V2 service catalog. @@ -289,40 +278,40 @@ class Catalog(catalog.CatalogDriverV8): else: silent_keyerror_failures = ['tenant_id', 'project_id', ] - session = sql.get_session() - endpoints = (session.query(Endpoint). - options(sql.joinedload(Endpoint.service)). - filter(Endpoint.enabled == true()).all()) + with sql.session_for_read() as session: + endpoints = (session.query(Endpoint). + options(sql.joinedload(Endpoint.service)). + filter(Endpoint.enabled == true()).all()) - catalog = {} + catalog = {} - for endpoint in endpoints: - if not endpoint.service['enabled']: - continue - try: - formatted_url = core.format_url( - endpoint['url'], substitutions, - silent_keyerror_failures=silent_keyerror_failures) - if formatted_url is not None: - url = formatted_url - else: + for endpoint in endpoints: + if not endpoint.service['enabled']: continue - except exception.MalformedEndpoint: - continue # this failure is already logged in format_url() + try: + formatted_url = core.format_url( + endpoint['url'], substitutions, + silent_keyerror_failures=silent_keyerror_failures) + if formatted_url is not None: + url = formatted_url + else: + continue + except exception.MalformedEndpoint: + continue # this failure is already logged in format_url() - region = endpoint['region_id'] - service_type = endpoint.service['type'] - default_service = { - 'id': endpoint['id'], - 'name': endpoint.service.extra.get('name', ''), - 'publicURL': '' - } - catalog.setdefault(region, {}) - catalog[region].setdefault(service_type, default_service) - interface_url = '%sURL' % endpoint['interface'] - catalog[region][service_type][interface_url] = url + region = endpoint['region_id'] + service_type = endpoint.service['type'] + default_service = { + 'id': endpoint['id'], + 'name': endpoint.service.extra.get('name', ''), + 'publicURL': '' + } + catalog.setdefault(region, {}) + catalog[region].setdefault(service_type, default_service) + interface_url = '%sURL' % endpoint['interface'] + catalog[region][service_type][interface_url] = url - return catalog + return catalog def get_v3_catalog(self, user_id, tenant_id): """Retrieve and format the current V3 service catalog. @@ -349,44 +338,46 @@ class Catalog(catalog.CatalogDriverV8): else: silent_keyerror_failures = ['tenant_id', 'project_id', ] - session = sql.get_session() - services = (session.query(Service).filter(Service.enabled == true()). - options(sql.joinedload(Service.endpoints)). - all()) + with sql.session_for_read() as session: + services = (session.query(Service).filter( + Service.enabled == true()).options( + sql.joinedload(Service.endpoints)).all()) - def make_v3_endpoints(endpoints): - for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled): - del endpoint['service_id'] - del endpoint['legacy_endpoint_id'] - del endpoint['enabled'] - endpoint['region'] = endpoint['region_id'] - try: - formatted_url = core.format_url( - endpoint['url'], d, - silent_keyerror_failures=silent_keyerror_failures) - if formatted_url: - endpoint['url'] = formatted_url - else: + def make_v3_endpoints(endpoints): + for endpoint in (ep.to_dict() + for ep in endpoints if ep.enabled): + del endpoint['service_id'] + del endpoint['legacy_endpoint_id'] + del endpoint['enabled'] + endpoint['region'] = endpoint['region_id'] + try: + formatted_url = core.format_url( + endpoint['url'], d, + silent_keyerror_failures=silent_keyerror_failures) + if formatted_url: + endpoint['url'] = formatted_url + else: + continue + except exception.MalformedEndpoint: + # this failure is already logged in format_url() continue - except exception.MalformedEndpoint: - continue # this failure is already logged in format_url() - yield endpoint + yield endpoint - # TODO(davechen): If there is service with no endpoints, we should skip - # the service instead of keeping it in the catalog, see bug #1436704. - def make_v3_service(svc): - eps = list(make_v3_endpoints(svc.endpoints)) - service = {'endpoints': eps, 'id': svc.id, 'type': svc.type} - service['name'] = svc.extra.get('name', '') - return service + # TODO(davechen): If there is service with no endpoints, we should + # skip the service instead of keeping it in the catalog, + # see bug #1436704. + def make_v3_service(svc): + eps = list(make_v3_endpoints(svc.endpoints)) + service = {'endpoints': eps, 'id': svc.id, 'type': svc.type} + service['name'] = svc.extra.get('name', '') + return service - return [make_v3_service(svc) for svc in services] + return [make_v3_service(svc) for svc in services] @sql.handle_conflicts(conflict_type='project_endpoint') def add_endpoint_to_project(self, endpoint_id, project_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id, project_id=project_id) session.add(endpoint_filter_ref) @@ -402,50 +393,46 @@ class Catalog(catalog.CatalogDriverV8): return endpoint_filter_ref def check_endpoint_in_project(self, endpoint_id, project_id): - session = sql.get_session() - self._get_project_endpoint_ref(session, endpoint_id, project_id) + with sql.session_for_read() as session: + self._get_project_endpoint_ref(session, endpoint_id, project_id) def remove_endpoint_from_project(self, endpoint_id, project_id): - session = sql.get_session() - endpoint_filter_ref = self._get_project_endpoint_ref( - session, endpoint_id, project_id) - with session.begin(): + with sql.session_for_write() as session: + endpoint_filter_ref = self._get_project_endpoint_ref( + session, endpoint_id, project_id) session.delete(endpoint_filter_ref) def list_endpoints_for_project(self, project_id): - session = sql.get_session() - query = session.query(ProjectEndpoint) - query = query.filter_by(project_id=project_id) - endpoint_filter_refs = query.all() - return [ref.to_dict() for ref in endpoint_filter_refs] + with sql.session_for_read() as session: + query = session.query(ProjectEndpoint) + query = query.filter_by(project_id=project_id) + endpoint_filter_refs = query.all() + return [ref.to_dict() for ref in endpoint_filter_refs] def list_projects_for_endpoint(self, endpoint_id): - session = sql.get_session() - query = session.query(ProjectEndpoint) - query = query.filter_by(endpoint_id=endpoint_id) - endpoint_filter_refs = query.all() - return [ref.to_dict() for ref in endpoint_filter_refs] + with sql.session_for_read() as session: + query = session.query(ProjectEndpoint) + query = query.filter_by(endpoint_id=endpoint_id) + endpoint_filter_refs = query.all() + return [ref.to_dict() for ref in endpoint_filter_refs] def delete_association_by_endpoint(self, endpoint_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: query = session.query(ProjectEndpoint) query = query.filter_by(endpoint_id=endpoint_id) query.delete(synchronize_session=False) def delete_association_by_project(self, project_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: query = session.query(ProjectEndpoint) query = query.filter_by(project_id=project_id) query.delete(synchronize_session=False) def create_endpoint_group(self, endpoint_group_id, endpoint_group): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: endpoint_group_ref = EndpointGroup.from_dict(endpoint_group) session.add(endpoint_group_ref) - return endpoint_group_ref.to_dict() + return endpoint_group_ref.to_dict() def _get_endpoint_group(self, session, endpoint_group_id): endpoint_group_ref = session.query(EndpointGroup).get( @@ -456,14 +443,13 @@ class Catalog(catalog.CatalogDriverV8): return endpoint_group_ref def get_endpoint_group(self, endpoint_group_id): - session = sql.get_session() - endpoint_group_ref = self._get_endpoint_group(session, - endpoint_group_id) - return endpoint_group_ref.to_dict() + with sql.session_for_read() as session: + endpoint_group_ref = self._get_endpoint_group(session, + endpoint_group_id) + return endpoint_group_ref.to_dict() def update_endpoint_group(self, endpoint_group_id, endpoint_group): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: endpoint_group_ref = self._get_endpoint_group(session, endpoint_group_id) old_endpoint_group = endpoint_group_ref.to_dict() @@ -472,29 +458,26 @@ class Catalog(catalog.CatalogDriverV8): for attr in EndpointGroup.mutable_attributes: setattr(endpoint_group_ref, attr, getattr(new_endpoint_group, attr)) - return endpoint_group_ref.to_dict() + return endpoint_group_ref.to_dict() def delete_endpoint_group(self, endpoint_group_id): - session = sql.get_session() - endpoint_group_ref = self._get_endpoint_group(session, - endpoint_group_id) - with session.begin(): + with sql.session_for_write() as session: + endpoint_group_ref = self._get_endpoint_group(session, + endpoint_group_id) self._delete_endpoint_group_association_by_endpoint_group( session, endpoint_group_id) session.delete(endpoint_group_ref) def get_endpoint_group_in_project(self, endpoint_group_id, project_id): - session = sql.get_session() - ref = self._get_endpoint_group_in_project(session, - endpoint_group_id, - project_id) - return ref.to_dict() + with sql.session_for_read() as session: + ref = self._get_endpoint_group_in_project(session, + endpoint_group_id, + project_id) + return ref.to_dict() @sql.handle_conflicts(conflict_type='project_endpoint_group') def add_endpoint_group_to_project(self, endpoint_group_id, project_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: # Create a new Project Endpoint group entity endpoint_group_project_ref = ProjectEndpointGroupMembership( endpoint_group_id=endpoint_group_id, project_id=project_id) @@ -512,32 +495,31 @@ class Catalog(catalog.CatalogDriverV8): return endpoint_group_project_ref def list_endpoint_groups(self): - session = sql.get_session() - query = session.query(EndpointGroup) - endpoint_group_refs = query.all() - return [e.to_dict() for e in endpoint_group_refs] + with sql.session_for_read() as session: + query = session.query(EndpointGroup) + endpoint_group_refs = query.all() + return [e.to_dict() for e in endpoint_group_refs] def list_endpoint_groups_for_project(self, project_id): - session = sql.get_session() - query = session.query(ProjectEndpointGroupMembership) - query = query.filter_by(project_id=project_id) - endpoint_group_refs = query.all() - return [ref.to_dict() for ref in endpoint_group_refs] + with sql.session_for_read() as session: + query = session.query(ProjectEndpointGroupMembership) + query = query.filter_by(project_id=project_id) + endpoint_group_refs = query.all() + return [ref.to_dict() for ref in endpoint_group_refs] def remove_endpoint_group_from_project(self, endpoint_group_id, project_id): - session = sql.get_session() - endpoint_group_project_ref = self._get_endpoint_group_in_project( - session, endpoint_group_id, project_id) - with session.begin(): + with sql.session_for_write() as session: + endpoint_group_project_ref = self._get_endpoint_group_in_project( + session, endpoint_group_id, project_id) session.delete(endpoint_group_project_ref) def list_projects_associated_with_endpoint_group(self, endpoint_group_id): - session = sql.get_session() - query = session.query(ProjectEndpointGroupMembership) - query = query.filter_by(endpoint_group_id=endpoint_group_id) - endpoint_group_refs = query.all() - return [ref.to_dict() for ref in endpoint_group_refs] + with sql.session_for_read() as session: + query = session.query(ProjectEndpointGroupMembership) + query = query.filter_by(endpoint_group_id=endpoint_group_id) + endpoint_group_refs = query.all() + return [ref.to_dict() for ref in endpoint_group_refs] def _delete_endpoint_group_association_by_endpoint_group( self, session, endpoint_group_id): @@ -546,8 +528,7 @@ class Catalog(catalog.CatalogDriverV8): query.delete() def delete_endpoint_group_association_by_project(self, project_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: query = session.query(ProjectEndpointGroupMembership) query = query.filter_by(project_id=project_id) query.delete() diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py index 6a7f6ecf41..e298f747c7 100644 --- a/keystone/common/sql/core.py +++ b/keystone/common/sql/core.py @@ -18,14 +18,14 @@ Before using this module, call initialize(). This has to be done before CONF() because it sets up configuration options. """ -import contextlib import functools +import threading from oslo_config import cfg from oslo_db import exception as db_exception from oslo_db import options as db_options +from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import models -from oslo_db.sqlalchemy import session as db_session from oslo_log import log from oslo_serialization import jsonutils import six @@ -166,38 +166,41 @@ class ModelDictMixin(object): return {name: getattr(self, name) for name in names} -_engine_facade = None +_main_context_manager = None -def _get_engine_facade(): - global _engine_facade +def _get_main_context_manager(): + global _main_context_manager - if not _engine_facade: - _engine_facade = db_session.EngineFacade.from_config(CONF) + if not _main_context_manager: + _main_context_manager = enginefacade.transaction_context() - return _engine_facade + return _main_context_manager def cleanup(): - global _engine_facade + global _main_context_manager - _engine_facade = None + _main_context_manager = None def get_engine(): - return _get_engine_facade().get_engine() + return _get_main_context_manager().get_legacy_facade().get_engine() -def get_session(expire_on_commit=False): - return _get_engine_facade().get_session(expire_on_commit=expire_on_commit) +def get_session(): + return _get_main_context_manager().get_legacy_facade().get_session() -@contextlib.contextmanager -def transaction(expire_on_commit=False): - """Return a SQLAlchemy session in a scoped transaction.""" - session = get_session(expire_on_commit=expire_on_commit) - with session.begin(): - yield session +_CONTEXT = threading.local() + + +def session_for_read(): + return _get_main_context_manager().reader.using(_CONTEXT) + + +def session_for_write(): + return _get_main_context_manager().writer.using(_CONTEXT) def truncated(f): diff --git a/keystone/common/sql/migration_helpers.py b/keystone/common/sql/migration_helpers.py index 6701f44425..7571a6d42b 100644 --- a/keystone/common/sql/migration_helpers.py +++ b/keystone/common/sql/migration_helpers.py @@ -178,7 +178,7 @@ def _sync_extension_repo(extension, version): try: abs_path = find_migrate_repo(package) try: - migration.db_version_control(sql.get_engine(), abs_path) + migration.db_version_control(engine, abs_path) # Register the repo with the version control API # If it already knows about the repo, it will throw # an exception that we can safely ignore diff --git a/keystone/credential/backends/sql.py b/keystone/credential/backends/sql.py index 6527c63acd..dfb9d20ae8 100644 --- a/keystone/credential/backends/sql.py +++ b/keystone/credential/backends/sql.py @@ -36,28 +36,27 @@ class Credential(credential.CredentialDriverV8): @sql.handle_conflicts(conflict_type='credential') def create_credential(self, credential_id, credential): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = CredentialModel.from_dict(credential) session.add(ref) - return ref.to_dict() + return ref.to_dict() @driver_hints.truncated def list_credentials(self, hints): - session = sql.get_session() - credentials = session.query(CredentialModel) - credentials = sql.filter_limit_query(CredentialModel, - credentials, hints) - return [s.to_dict() for s in credentials] + with sql.session_for_read() as session: + credentials = session.query(CredentialModel) + credentials = sql.filter_limit_query(CredentialModel, + credentials, hints) + return [s.to_dict() for s in credentials] def list_credentials_for_user(self, user_id, type=None): - session = sql.get_session() - query = session.query(CredentialModel) - query = query.filter_by(user_id=user_id) - if type: - query = query.filter_by(type=type) - refs = query.all() - return [ref.to_dict() for ref in refs] + with sql.session_for_read() as session: + query = session.query(CredentialModel) + query = query.filter_by(user_id=user_id) + if type: + query = query.filter_by(type=type) + refs = query.all() + return [ref.to_dict() for ref in refs] def _get_credential(self, session, credential_id): ref = session.query(CredentialModel).get(credential_id) @@ -66,13 +65,12 @@ class Credential(credential.CredentialDriverV8): return ref def get_credential(self, credential_id): - session = sql.get_session() - return self._get_credential(session, credential_id).to_dict() + with sql.session_for_read() as session: + return self._get_credential(session, credential_id).to_dict() @sql.handle_conflicts(conflict_type='credential') def update_credential(self, credential_id, credential): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_credential(session, credential_id) old_dict = ref.to_dict() for k in credential: @@ -82,27 +80,21 @@ class Credential(credential.CredentialDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_credential, attr)) ref.extra = new_credential.extra - return ref.to_dict() + return ref.to_dict() def delete_credential(self, credential_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_credential(session, credential_id) session.delete(ref) def delete_credentials_for_project(self, project_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: query = session.query(CredentialModel) query = query.filter_by(project_id=project_id) query.delete() def delete_credentials_for_user(self, user_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: query = session.query(CredentialModel) query = query.filter_by(user_id=user_id) query.delete() diff --git a/keystone/endpoint_policy/backends/sql.py b/keystone/endpoint_policy/backends/sql.py index b2687f1444..aacbb08321 100644 --- a/keystone/endpoint_policy/backends/sql.py +++ b/keystone/endpoint_policy/backends/sql.py @@ -51,7 +51,7 @@ class EndpointPolicy(object): def create_policy_association(self, policy_id, endpoint_id=None, service_id=None, region_id=None): - with sql.transaction() as session: + with sql.session_for_write() as session: try: # See if there is already a row for this association, and if # so, update it with the new policy_id @@ -79,14 +79,14 @@ class EndpointPolicy(object): # NOTE(henry-nash): Getting a single value to save object # management overhead. - with sql.transaction() as session: + with sql.session_for_read() as session: if session.query(PolicyAssociation.id).filter( sql_constraints).distinct().count() == 0: raise exception.PolicyAssociationNotFound() def delete_policy_association(self, policy_id, endpoint_id=None, service_id=None, region_id=None): - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(PolicyAssociation) query = query.filter_by(policy_id=policy_id) query = query.filter_by(endpoint_id=endpoint_id) @@ -102,7 +102,7 @@ class EndpointPolicy(object): PolicyAssociation.region_id == region_id) try: - with sql.transaction() as session: + with sql.session_for_read() as session: policy_id = session.query(PolicyAssociation.policy_id).filter( sql_constraints).distinct().one() return {'policy_id': policy_id} @@ -110,31 +110,31 @@ class EndpointPolicy(object): raise exception.PolicyAssociationNotFound() def list_associations_for_policy(self, policy_id): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(PolicyAssociation) query = query.filter_by(policy_id=policy_id) return [ref.to_dict() for ref in query.all()] def delete_association_by_endpoint(self, endpoint_id): - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(PolicyAssociation) query = query.filter_by(endpoint_id=endpoint_id) query.delete() def delete_association_by_service(self, service_id): - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(PolicyAssociation) query = query.filter_by(service_id=service_id) query.delete() def delete_association_by_region(self, region_id): - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(PolicyAssociation) query = query.filter_by(region_id=region_id) query.delete() def delete_association_by_policy(self, policy_id): - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(PolicyAssociation) query = query.filter_by(policy_id=policy_id) query.delete() diff --git a/keystone/federation/V8_backends/sql.py b/keystone/federation/V8_backends/sql.py index 851a91128d..c95e9dbd2b 100644 --- a/keystone/federation/V8_backends/sql.py +++ b/keystone/federation/V8_backends/sql.py @@ -161,13 +161,13 @@ class Federation(core.FederationDriverV8): @sql.handle_conflicts(conflict_type='identity_provider') def create_idp(self, idp_id, idp): idp['id'] = idp_id - with sql.transaction() as session: + with sql.session_for_write() as session: idp_ref = IdentityProviderModel.from_dict(idp) session.add(idp_ref) - return idp_ref.to_dict() + return idp_ref.to_dict() def delete_idp(self, idp_id): - with sql.transaction() as session: + with sql.session_for_write() as session: self._delete_assigned_protocols(session, idp_id) idp_ref = self._get_idp(session, idp_id) session.delete(idp_ref) @@ -187,30 +187,30 @@ class Federation(core.FederationDriverV8): raise exception.IdentityProviderNotFound(idp_id=remote_id) def list_idps(self): - with sql.transaction() as session: + with sql.session_for_read() as session: idps = session.query(IdentityProviderModel) - idps_list = [idp.to_dict() for idp in idps] - return idps_list + idps_list = [idp.to_dict() for idp in idps] + return idps_list def get_idp(self, idp_id): - with sql.transaction() as session: + with sql.session_for_read() as session: idp_ref = self._get_idp(session, idp_id) - return idp_ref.to_dict() + return idp_ref.to_dict() def get_idp_from_remote_id(self, remote_id): - with sql.transaction() as session: + with sql.session_for_read() as session: ref = self._get_idp_from_remote_id(session, remote_id) - return ref.to_dict() + return ref.to_dict() def update_idp(self, idp_id, idp): - with sql.transaction() as session: + with sql.session_for_write() as session: idp_ref = self._get_idp(session, idp_id) old_idp = idp_ref.to_dict() old_idp.update(idp) new_idp = IdentityProviderModel.from_dict(old_idp) for attr in IdentityProviderModel.mutable_attributes: setattr(idp_ref, attr, getattr(new_idp, attr)) - return idp_ref.to_dict() + return idp_ref.to_dict() # Protocol CRUD def _get_protocol(self, session, idp_id, protocol_id): @@ -227,36 +227,36 @@ class Federation(core.FederationDriverV8): def create_protocol(self, idp_id, protocol_id, protocol): protocol['id'] = protocol_id protocol['idp_id'] = idp_id - with sql.transaction() as session: + with sql.session_for_write() as session: self._get_idp(session, idp_id) protocol_ref = FederationProtocolModel.from_dict(protocol) session.add(protocol_ref) - return protocol_ref.to_dict() + return protocol_ref.to_dict() def update_protocol(self, idp_id, protocol_id, protocol): - with sql.transaction() as session: + with sql.session_for_write() as session: proto_ref = self._get_protocol(session, idp_id, protocol_id) old_proto = proto_ref.to_dict() old_proto.update(protocol) new_proto = FederationProtocolModel.from_dict(old_proto) for attr in FederationProtocolModel.mutable_attributes: setattr(proto_ref, attr, getattr(new_proto, attr)) - return proto_ref.to_dict() + return proto_ref.to_dict() def get_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: + with sql.session_for_read() as session: protocol_ref = self._get_protocol(session, idp_id, protocol_id) - return protocol_ref.to_dict() + return protocol_ref.to_dict() def list_protocols(self, idp_id): - with sql.transaction() as session: + with sql.session_for_read() as session: q = session.query(FederationProtocolModel) q = q.filter_by(idp_id=idp_id) - protocols = [protocol.to_dict() for protocol in q] - return protocols + protocols = [protocol.to_dict() for protocol in q] + return protocols def delete_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: + with sql.session_for_write() as session: key_ref = self._get_protocol(session, idp_id, protocol_id) session.delete(key_ref) @@ -277,58 +277,58 @@ class Federation(core.FederationDriverV8): ref = {} ref['id'] = mapping_id ref['rules'] = mapping.get('rules') - with sql.transaction() as session: + with sql.session_for_write() as session: mapping_ref = MappingModel.from_dict(ref) session.add(mapping_ref) - return mapping_ref.to_dict() + return mapping_ref.to_dict() def delete_mapping(self, mapping_id): - with sql.transaction() as session: + with sql.session_for_write() as session: mapping_ref = self._get_mapping(session, mapping_id) session.delete(mapping_ref) def list_mappings(self): - with sql.transaction() as session: + with sql.session_for_read() as session: mappings = session.query(MappingModel) - return [x.to_dict() for x in mappings] + return [x.to_dict() for x in mappings] def get_mapping(self, mapping_id): - with sql.transaction() as session: + with sql.session_for_read() as session: mapping_ref = self._get_mapping(session, mapping_id) - return mapping_ref.to_dict() + return mapping_ref.to_dict() @sql.handle_conflicts(conflict_type='mapping') def update_mapping(self, mapping_id, mapping): ref = {} ref['id'] = mapping_id ref['rules'] = mapping.get('rules') - with sql.transaction() as session: + with sql.session_for_write() as session: mapping_ref = self._get_mapping(session, mapping_id) old_mapping = mapping_ref.to_dict() old_mapping.update(ref) new_mapping = MappingModel.from_dict(old_mapping) for attr in MappingModel.attributes: setattr(mapping_ref, attr, getattr(new_mapping, attr)) - return mapping_ref.to_dict() + return mapping_ref.to_dict() def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: + with sql.session_for_read() as session: protocol_ref = self._get_protocol(session, idp_id, protocol_id) mapping_id = protocol_ref.mapping_id mapping_ref = self._get_mapping(session, mapping_id) - return mapping_ref.to_dict() + return mapping_ref.to_dict() # Service Provider CRUD @sql.handle_conflicts(conflict_type='service_provider') def create_sp(self, sp_id, sp): sp['id'] = sp_id - with sql.transaction() as session: + with sql.session_for_write() as session: sp_ref = ServiceProviderModel.from_dict(sp) session.add(sp_ref) - return sp_ref.to_dict() + return sp_ref.to_dict() def delete_sp(self, sp_id): - with sql.transaction() as session: + with sql.session_for_write() as session: sp_ref = self._get_sp(session, sp_id) session.delete(sp_ref) @@ -339,28 +339,28 @@ class Federation(core.FederationDriverV8): return sp_ref def list_sps(self): - with sql.transaction() as session: + with sql.session_for_read() as session: sps = session.query(ServiceProviderModel) - sps_list = [sp.to_dict() for sp in sps] - return sps_list + sps_list = [sp.to_dict() for sp in sps] + return sps_list def get_sp(self, sp_id): - with sql.transaction() as session: + with sql.session_for_read() as session: sp_ref = self._get_sp(session, sp_id) - return sp_ref.to_dict() + return sp_ref.to_dict() def update_sp(self, sp_id, sp): - with sql.transaction() as session: + with sql.session_for_write() as session: sp_ref = self._get_sp(session, sp_id) old_sp = sp_ref.to_dict() old_sp.update(sp) new_sp = ServiceProviderModel.from_dict(old_sp) for attr in ServiceProviderModel.mutable_attributes: setattr(sp_ref, attr, getattr(new_sp, attr)) - return sp_ref.to_dict() + return sp_ref.to_dict() def get_enabled_service_providers(self): - with sql.transaction() as session: + with sql.session_for_read() as session: service_providers = session.query(ServiceProviderModel) service_providers = service_providers.filter_by(enabled=True) - return service_providers + return service_providers diff --git a/keystone/federation/backends/sql.py b/keystone/federation/backends/sql.py index 1ac8cb3daa..6bba07daf7 100644 --- a/keystone/federation/backends/sql.py +++ b/keystone/federation/backends/sql.py @@ -169,10 +169,10 @@ class Federation(core.FederationDriverV9): def create_idp(self, idp_id, idp): idp['id'] = idp_id try: - with sql.transaction() as session: + with sql.session_for_write() as session: idp_ref = IdentityProviderModel.from_dict(idp) session.add(idp_ref) - return idp_ref.to_dict() + return idp_ref.to_dict() except sql.DBDuplicateEntry as e: conflict_type = 'identity_provider' details = six.text_type(e) @@ -186,7 +186,7 @@ class Federation(core.FederationDriverV9): raise exception.Conflict(type=conflict_type, details=msg) def delete_idp(self, idp_id): - with sql.transaction() as session: + with sql.session_for_write() as session: self._delete_assigned_protocols(session, idp_id) idp_ref = self._get_idp(session, idp_id) session.delete(idp_ref) @@ -206,31 +206,31 @@ class Federation(core.FederationDriverV9): raise exception.IdentityProviderNotFound(idp_id=remote_id) def list_idps(self, hints=None): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(IdentityProviderModel) idps = sql.filter_limit_query(IdentityProviderModel, query, hints) - idps_list = [idp.to_dict() for idp in idps] - return idps_list + idps_list = [idp.to_dict() for idp in idps] + return idps_list def get_idp(self, idp_id): - with sql.transaction() as session: + with sql.session_for_read() as session: idp_ref = self._get_idp(session, idp_id) - return idp_ref.to_dict() + return idp_ref.to_dict() def get_idp_from_remote_id(self, remote_id): - with sql.transaction() as session: + with sql.session_for_read() as session: ref = self._get_idp_from_remote_id(session, remote_id) - return ref.to_dict() + return ref.to_dict() def update_idp(self, idp_id, idp): - with sql.transaction() as session: + with sql.session_for_write() as session: idp_ref = self._get_idp(session, idp_id) old_idp = idp_ref.to_dict() old_idp.update(idp) new_idp = IdentityProviderModel.from_dict(old_idp) for attr in IdentityProviderModel.mutable_attributes: setattr(idp_ref, attr, getattr(new_idp, attr)) - return idp_ref.to_dict() + return idp_ref.to_dict() # Protocol CRUD def _get_protocol(self, session, idp_id, protocol_id): @@ -247,36 +247,36 @@ class Federation(core.FederationDriverV9): def create_protocol(self, idp_id, protocol_id, protocol): protocol['id'] = protocol_id protocol['idp_id'] = idp_id - with sql.transaction() as session: + with sql.session_for_write() as session: self._get_idp(session, idp_id) protocol_ref = FederationProtocolModel.from_dict(protocol) session.add(protocol_ref) - return protocol_ref.to_dict() + return protocol_ref.to_dict() def update_protocol(self, idp_id, protocol_id, protocol): - with sql.transaction() as session: + with sql.session_for_write() as session: proto_ref = self._get_protocol(session, idp_id, protocol_id) old_proto = proto_ref.to_dict() old_proto.update(protocol) new_proto = FederationProtocolModel.from_dict(old_proto) for attr in FederationProtocolModel.mutable_attributes: setattr(proto_ref, attr, getattr(new_proto, attr)) - return proto_ref.to_dict() + return proto_ref.to_dict() def get_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: + with sql.session_for_read() as session: protocol_ref = self._get_protocol(session, idp_id, protocol_id) - return protocol_ref.to_dict() + return protocol_ref.to_dict() def list_protocols(self, idp_id): - with sql.transaction() as session: + with sql.session_for_read() as session: q = session.query(FederationProtocolModel) q = q.filter_by(idp_id=idp_id) - protocols = [protocol.to_dict() for protocol in q] - return protocols + protocols = [protocol.to_dict() for protocol in q] + return protocols def delete_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: + with sql.session_for_write() as session: key_ref = self._get_protocol(session, idp_id, protocol_id) session.delete(key_ref) @@ -297,58 +297,58 @@ class Federation(core.FederationDriverV9): ref = {} ref['id'] = mapping_id ref['rules'] = mapping.get('rules') - with sql.transaction() as session: + with sql.session_for_write() as session: mapping_ref = MappingModel.from_dict(ref) session.add(mapping_ref) - return mapping_ref.to_dict() + return mapping_ref.to_dict() def delete_mapping(self, mapping_id): - with sql.transaction() as session: + with sql.session_for_write() as session: mapping_ref = self._get_mapping(session, mapping_id) session.delete(mapping_ref) def list_mappings(self): - with sql.transaction() as session: + with sql.session_for_read() as session: mappings = session.query(MappingModel) - return [x.to_dict() for x in mappings] + return [x.to_dict() for x in mappings] def get_mapping(self, mapping_id): - with sql.transaction() as session: + with sql.session_for_read() as session: mapping_ref = self._get_mapping(session, mapping_id) - return mapping_ref.to_dict() + return mapping_ref.to_dict() @sql.handle_conflicts(conflict_type='mapping') def update_mapping(self, mapping_id, mapping): ref = {} ref['id'] = mapping_id ref['rules'] = mapping.get('rules') - with sql.transaction() as session: + with sql.session_for_write() as session: mapping_ref = self._get_mapping(session, mapping_id) old_mapping = mapping_ref.to_dict() old_mapping.update(ref) new_mapping = MappingModel.from_dict(old_mapping) for attr in MappingModel.attributes: setattr(mapping_ref, attr, getattr(new_mapping, attr)) - return mapping_ref.to_dict() + return mapping_ref.to_dict() def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: + with sql.session_for_read() as session: protocol_ref = self._get_protocol(session, idp_id, protocol_id) mapping_id = protocol_ref.mapping_id mapping_ref = self._get_mapping(session, mapping_id) - return mapping_ref.to_dict() + return mapping_ref.to_dict() # Service Provider CRUD @sql.handle_conflicts(conflict_type='service_provider') def create_sp(self, sp_id, sp): sp['id'] = sp_id - with sql.transaction() as session: + with sql.session_for_write() as session: sp_ref = ServiceProviderModel.from_dict(sp) session.add(sp_ref) - return sp_ref.to_dict() + return sp_ref.to_dict() def delete_sp(self, sp_id): - with sql.transaction() as session: + with sql.session_for_write() as session: sp_ref = self._get_sp(session, sp_id) session.delete(sp_ref) @@ -359,28 +359,28 @@ class Federation(core.FederationDriverV9): return sp_ref def list_sps(self): - with sql.transaction() as session: + with sql.session_for_read() as session: sps = session.query(ServiceProviderModel) - sps_list = [sp.to_dict() for sp in sps] - return sps_list + sps_list = [sp.to_dict() for sp in sps] + return sps_list def get_sp(self, sp_id): - with sql.transaction() as session: + with sql.session_for_read() as session: sp_ref = self._get_sp(session, sp_id) - return sp_ref.to_dict() + return sp_ref.to_dict() def update_sp(self, sp_id, sp): - with sql.transaction() as session: + with sql.session_for_write() as session: sp_ref = self._get_sp(session, sp_id) old_sp = sp_ref.to_dict() old_sp.update(sp) new_sp = ServiceProviderModel.from_dict(old_sp) for attr in ServiceProviderModel.mutable_attributes: setattr(sp_ref, attr, getattr(new_sp, attr)) - return sp_ref.to_dict() + return sp_ref.to_dict() def get_enabled_service_providers(self): - with sql.transaction() as session: + with sql.session_for_read() as session: service_providers = session.query(ServiceProviderModel) service_providers = service_providers.filter_by(enabled=True) - return service_providers + return service_providers diff --git a/keystone/identity/backends/sql.py b/keystone/identity/backends/sql.py index b5b51419b9..7273629b91 100644 --- a/keystone/identity/backends/sql.py +++ b/keystone/identity/backends/sql.py @@ -178,33 +178,32 @@ class Identity(identity.IdentityDriverV8): # Identity interface def authenticate(self, user_id, password): - session = sql.get_session() - user_ref = None - try: - user_ref = self._get_user(session, user_id) - except exception.UserNotFound: - raise AssertionError(_('Invalid user / password')) - if not self._check_password(password, user_ref): - raise AssertionError(_('Invalid user / password')) - return identity.filter_user(user_ref.to_dict()) + with sql.session_for_read() as session: + user_ref = None + try: + user_ref = self._get_user(session, user_id) + except exception.UserNotFound: + raise AssertionError(_('Invalid user / password')) + if not self._check_password(password, user_ref): + raise AssertionError(_('Invalid user / password')) + return identity.filter_user(user_ref.to_dict()) # user crud @sql.handle_conflicts(conflict_type='user') def create_user(self, user_id, user): user = utils.hash_user_password(user) - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: user_ref = User.from_dict(user) session.add(user_ref) - return identity.filter_user(user_ref.to_dict()) + return identity.filter_user(user_ref.to_dict()) @driver_hints.truncated def list_users(self, hints): - session = sql.get_session() - query = session.query(User).outerjoin(LocalUser) - user_refs = sql.filter_limit_query(User, query, hints) - return [identity.filter_user(x.to_dict()) for x in user_refs] + with sql.session_for_read() as session: + query = session.query(User).outerjoin(LocalUser) + user_refs = sql.filter_limit_query(User, query, hints) + return [identity.filter_user(x.to_dict()) for x in user_refs] def _get_user(self, session, user_id): user_ref = session.query(User).get(user_id) @@ -213,25 +212,24 @@ class Identity(identity.IdentityDriverV8): return user_ref def get_user(self, user_id): - session = sql.get_session() - return identity.filter_user(self._get_user(session, user_id).to_dict()) + with sql.session_for_read() as session: + return identity.filter_user( + self._get_user(session, user_id).to_dict()) def get_user_by_name(self, user_name, domain_id): - session = sql.get_session() - query = session.query(User).join(LocalUser) - query = query.filter(and_(LocalUser.name == user_name, - LocalUser.domain_id == domain_id)) - try: - user_ref = query.one() - except sql.NotFound: - raise exception.UserNotFound(user_id=user_name) - return identity.filter_user(user_ref.to_dict()) + with sql.session_for_read() as session: + query = session.query(User).join(LocalUser) + query = query.filter(and_(LocalUser.name == user_name, + LocalUser.domain_id == domain_id)) + try: + user_ref = query.one() + except sql.NotFound: + raise exception.UserNotFound(user_id=user_name) + return identity.filter_user(user_ref.to_dict()) @sql.handle_conflicts(conflict_type='user') def update_user(self, user_id, user): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: user_ref = self._get_user(session, user_id) old_user_dict = user_ref.to_dict() user = utils.hash_user_password(user) @@ -242,77 +240,74 @@ class Identity(identity.IdentityDriverV8): if attr != 'id': setattr(user_ref, attr, getattr(new_user, attr)) user_ref.extra = new_user.extra - return identity.filter_user(user_ref.to_dict(include_extra_dict=True)) + return identity.filter_user( + user_ref.to_dict(include_extra_dict=True)) def add_user_to_group(self, user_id, group_id): - session = sql.get_session() - self.get_group(group_id) - self.get_user(user_id) - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - rv = query.first() - if rv: - return + with sql.session_for_write() as session: + self.get_group(group_id) + self.get_user(user_id) + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + rv = query.first() + if rv: + return - with session.begin(): session.add(UserGroupMembership(user_id=user_id, group_id=group_id)) def check_user_in_group(self, user_id, group_id): - session = sql.get_session() - self.get_group(group_id) - self.get_user(user_id) - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - if not query.first(): - raise exception.NotFound(_("User '%(user_id)s' not found in" - " group '%(group_id)s'") % - {'user_id': user_id, - 'group_id': group_id}) - - def remove_user_from_group(self, user_id, group_id): - session = sql.get_session() - # We don't check if user or group are still valid and let the remove - # be tried anyway - in case this is some kind of clean-up operation - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - membership_ref = query.first() - if membership_ref is None: - # Check if the group and user exist to return descriptive - # exceptions. + with sql.session_for_read() as session: self.get_group(group_id) self.get_user(user_id) - raise exception.NotFound(_("User '%(user_id)s' not found in" - " group '%(group_id)s'") % - {'user_id': user_id, - 'group_id': group_id}) - with session.begin(): + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + if not query.first(): + raise exception.NotFound(_("User '%(user_id)s' not found in" + " group '%(group_id)s'") % + {'user_id': user_id, + 'group_id': group_id}) + + def remove_user_from_group(self, user_id, group_id): + # We don't check if user or group are still valid and let the remove + # be tried anyway - in case this is some kind of clean-up operation + with sql.session_for_write() as session: + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + membership_ref = query.first() + if membership_ref is None: + # Check if the group and user exist to return descriptive + # exceptions. + self.get_group(group_id) + self.get_user(user_id) + raise exception.NotFound(_("User '%(user_id)s' not found in" + " group '%(group_id)s'") % + {'user_id': user_id, + 'group_id': group_id}) session.delete(membership_ref) def list_groups_for_user(self, user_id, hints): - session = sql.get_session() - self.get_user(user_id) - query = session.query(Group).join(UserGroupMembership) - query = query.filter(UserGroupMembership.user_id == user_id) - query = sql.filter_limit_query(Group, query, hints) - return [g.to_dict() for g in query] + with sql.session_for_read() as session: + self.get_user(user_id) + query = session.query(Group).join(UserGroupMembership) + query = query.filter(UserGroupMembership.user_id == user_id) + query = sql.filter_limit_query(Group, query, hints) + return [g.to_dict() for g in query] def list_users_in_group(self, group_id, hints): - session = sql.get_session() - self.get_group(group_id) - query = session.query(User).outerjoin(LocalUser) - query = query.join(UserGroupMembership) - query = query.filter(UserGroupMembership.group_id == group_id) - query = sql.filter_limit_query(User, query, hints) - return [identity.filter_user(u.to_dict()) for u in query] + with sql.session_for_read() as session: + self.get_group(group_id) + query = session.query(User).outerjoin(LocalUser) + query = query.join(UserGroupMembership) + query = query.filter(UserGroupMembership.group_id == group_id) + query = sql.filter_limit_query(User, query, hints) + return [identity.filter_user(u.to_dict()) for u in query] def delete_user(self, user_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_user(session, user_id) q = session.query(UserGroupMembership) @@ -325,18 +320,17 @@ class Identity(identity.IdentityDriverV8): @sql.handle_conflicts(conflict_type='group') def create_group(self, group_id, group): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = Group.from_dict(group) session.add(ref) - return ref.to_dict() + return ref.to_dict() @driver_hints.truncated def list_groups(self, hints): - session = sql.get_session() - query = session.query(Group) - refs = sql.filter_limit_query(Group, query, hints) - return [ref.to_dict() for ref in refs] + with sql.session_for_read() as session: + query = session.query(Group) + refs = sql.filter_limit_query(Group, query, hints) + return [ref.to_dict() for ref in refs] def _get_group(self, session, group_id): ref = session.query(Group).get(group_id) @@ -345,25 +339,23 @@ class Identity(identity.IdentityDriverV8): return ref def get_group(self, group_id): - session = sql.get_session() - return self._get_group(session, group_id).to_dict() + with sql.session_for_read() as session: + return self._get_group(session, group_id).to_dict() def get_group_by_name(self, group_name, domain_id): - session = sql.get_session() - query = session.query(Group) - query = query.filter_by(name=group_name) - query = query.filter_by(domain_id=domain_id) - try: - group_ref = query.one() - except sql.NotFound: - raise exception.GroupNotFound(group_id=group_name) - return group_ref.to_dict() + with sql.session_for_read() as session: + query = session.query(Group) + query = query.filter_by(name=group_name) + query = query.filter_by(domain_id=domain_id) + try: + group_ref = query.one() + except sql.NotFound: + raise exception.GroupNotFound(group_id=group_name) + return group_ref.to_dict() @sql.handle_conflicts(conflict_type='group') def update_group(self, group_id, group): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_group(session, group_id) old_dict = ref.to_dict() for k in group: @@ -373,12 +365,10 @@ class Identity(identity.IdentityDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_group, attr)) ref.extra = new_group.extra - return ref.to_dict() + return ref.to_dict() def delete_group(self, group_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_group(session, group_id) q = session.query(UserGroupMembership) diff --git a/keystone/identity/mapping_backends/sql.py b/keystone/identity/mapping_backends/sql.py index 7e5c68f0b3..91b33dd762 100644 --- a/keystone/identity/mapping_backends/sql.py +++ b/keystone/identity/mapping_backends/sql.py @@ -45,27 +45,27 @@ class Mapping(identity.MappingDriverV8): # work if we hashed all the entries, even those that already generate # UUIDs, like SQL. Further, this would only work if the generation # algorithm was immutable (e.g. it had always been sha256). - session = sql.get_session() - query = session.query(IDMapping.public_id) - query = query.filter_by(domain_id=local_entity['domain_id']) - query = query.filter_by(local_id=local_entity['local_id']) - query = query.filter_by(entity_type=local_entity['entity_type']) - try: - public_ref = query.one() - public_id = public_ref.public_id - return public_id - except sql.NotFound: - return None + with sql.session_for_read() as session: + query = session.query(IDMapping.public_id) + query = query.filter_by(domain_id=local_entity['domain_id']) + query = query.filter_by(local_id=local_entity['local_id']) + query = query.filter_by(entity_type=local_entity['entity_type']) + try: + public_ref = query.one() + public_id = public_ref.public_id + return public_id + except sql.NotFound: + return None def get_id_mapping(self, public_id): - session = sql.get_session() - mapping_ref = session.query(IDMapping).get(public_id) - if mapping_ref: - return mapping_ref.to_dict() + with sql.session_for_read() as session: + mapping_ref = session.query(IDMapping).get(public_id) + if mapping_ref: + return mapping_ref.to_dict() def create_id_mapping(self, local_entity, public_id=None): entity = local_entity.copy() - with sql.transaction() as session: + with sql.session_for_write() as session: if public_id is None: public_id = self.id_generator_api.generate_public_ID(entity) entity['public_id'] = public_id @@ -74,7 +74,7 @@ class Mapping(identity.MappingDriverV8): return public_id def delete_id_mapping(self, public_id): - with sql.transaction() as session: + with sql.session_for_write() as session: try: session.query(IDMapping).filter( IDMapping.public_id == public_id).delete() @@ -84,14 +84,15 @@ class Mapping(identity.MappingDriverV8): pass def purge_mappings(self, purge_filter): - session = sql.get_session() - query = session.query(IDMapping) - if 'domain_id' in purge_filter: - query = query.filter_by(domain_id=purge_filter['domain_id']) - if 'public_id' in purge_filter: - query = query.filter_by(public_id=purge_filter['public_id']) - if 'local_id' in purge_filter: - query = query.filter_by(local_id=purge_filter['local_id']) - if 'entity_type' in purge_filter: - query = query.filter_by(entity_type=purge_filter['entity_type']) - query.delete() + with sql.session_for_write() as session: + query = session.query(IDMapping) + if 'domain_id' in purge_filter: + query = query.filter_by(domain_id=purge_filter['domain_id']) + if 'public_id' in purge_filter: + query = query.filter_by(public_id=purge_filter['public_id']) + if 'local_id' in purge_filter: + query = query.filter_by(local_id=purge_filter['local_id']) + if 'entity_type' in purge_filter: + query = query.filter_by( + entity_type=purge_filter['entity_type']) + query.delete() diff --git a/keystone/oauth1/backends/sql.py b/keystone/oauth1/backends/sql.py index aa3c34b77e..c5da7873f2 100644 --- a/keystone/oauth1/backends/sql.py +++ b/keystone/oauth1/backends/sql.py @@ -92,17 +92,16 @@ class OAuth1(core.Oauth1DriverV8): return consumer_ref def get_consumer_with_secret(self, consumer_id): - session = sql.get_session() - consumer_ref = self._get_consumer(session, consumer_id) - return consumer_ref.to_dict() + with sql.session_for_read() as session: + consumer_ref = self._get_consumer(session, consumer_id) + return consumer_ref.to_dict() def get_consumer(self, consumer_id): return core.filter_consumer( self.get_consumer_with_secret(consumer_id)) def create_consumer(self, consumer_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: consumer = Consumer.from_dict(consumer_ref) session.add(consumer) return consumer.to_dict() @@ -128,20 +127,18 @@ class OAuth1(core.Oauth1DriverV8): session.delete(token_ref) def delete_consumer(self, consumer_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: self._delete_request_tokens(session, consumer_id) self._delete_access_tokens(session, consumer_id) self._delete_consumer(session, consumer_id) def list_consumers(self): - session = sql.get_session() - cons = session.query(Consumer) - return [core.filter_consumer(x.to_dict()) for x in cons] + with sql.session_for_read() as session: + cons = session.query(Consumer) + return [core.filter_consumer(x.to_dict()) for x in cons] def update_consumer(self, consumer_id, consumer_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: consumer = self._get_consumer(session, consumer_id) old_consumer_dict = consumer.to_dict() old_consumer_dict.update(consumer_ref) @@ -169,11 +166,10 @@ class OAuth1(core.Oauth1DriverV8): ref['role_ids'] = None ref['consumer_id'] = consumer_id ref['expires_at'] = expiry_date - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: token_ref = RequestToken.from_dict(ref) session.add(token_ref) - return token_ref.to_dict() + return token_ref.to_dict() def _get_request_token(self, session, request_token_id): token_ref = session.query(RequestToken).get(request_token_id) @@ -182,14 +178,13 @@ class OAuth1(core.Oauth1DriverV8): return token_ref def get_request_token(self, request_token_id): - session = sql.get_session() - token_ref = self._get_request_token(session, request_token_id) - return token_ref.to_dict() + with sql.session_for_read() as session: + token_ref = self._get_request_token(session, request_token_id) + return token_ref.to_dict() def authorize_request_token(self, request_token_id, user_id, role_ids): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: token_ref = self._get_request_token(session, request_token_id) token_dict = token_ref.to_dict() token_dict['authorizing_user_id'] = user_id @@ -203,13 +198,12 @@ class OAuth1(core.Oauth1DriverV8): or attr == 'role_ids'): setattr(token_ref, attr, getattr(new_token, attr)) - return token_ref.to_dict() + return token_ref.to_dict() def create_access_token(self, request_id, access_token_duration): access_token_id = uuid.uuid4().hex access_token_secret = uuid.uuid4().hex - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: req_token_ref = self._get_request_token(session, request_id) token_dict = req_token_ref.to_dict() @@ -235,7 +229,7 @@ class OAuth1(core.Oauth1DriverV8): # remove request token, it's been used session.delete(req_token_ref) - return token_ref.to_dict() + return token_ref.to_dict() def _get_access_token(self, session, access_token_id): token_ref = session.query(AccessToken).get(access_token_id) @@ -244,19 +238,18 @@ class OAuth1(core.Oauth1DriverV8): return token_ref def get_access_token(self, access_token_id): - session = sql.get_session() - token_ref = self._get_access_token(session, access_token_id) - return token_ref.to_dict() + with sql.session_for_read() as session: + token_ref = self._get_access_token(session, access_token_id) + return token_ref.to_dict() def list_access_tokens(self, user_id): - session = sql.get_session() - q = session.query(AccessToken) - user_auths = q.filter_by(authorizing_user_id=user_id) - return [core.filter_token(x.to_dict()) for x in user_auths] + with sql.session_for_read() as session: + q = session.query(AccessToken) + user_auths = q.filter_by(authorizing_user_id=user_id) + return [core.filter_token(x.to_dict()) for x in user_auths] def delete_access_token(self, user_id, access_token_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: token_ref = self._get_access_token(session, access_token_id) token_dict = token_ref.to_dict() if token_dict['authorizing_user_id'] != user_id: diff --git a/keystone/policy/backends/sql.py b/keystone/policy/backends/sql.py index b2cccd015d..94763f0d13 100644 --- a/keystone/policy/backends/sql.py +++ b/keystone/policy/backends/sql.py @@ -30,19 +30,16 @@ class Policy(rules.Policy): @sql.handle_conflicts(conflict_type='policy') def create_policy(self, policy_id, policy): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = PolicyModel.from_dict(policy) session.add(ref) - return ref.to_dict() + return ref.to_dict() def list_policies(self): - session = sql.get_session() - - refs = session.query(PolicyModel).all() - return [ref.to_dict() for ref in refs] + with sql.session_for_read() as session: + refs = session.query(PolicyModel).all() + return [ref.to_dict() for ref in refs] def _get_policy(self, session, policy_id): """Private method to get a policy model object (NOT a dictionary).""" @@ -52,15 +49,12 @@ class Policy(rules.Policy): return ref def get_policy(self, policy_id): - session = sql.get_session() - - return self._get_policy(session, policy_id).to_dict() + with sql.session_for_read() as session: + return self._get_policy(session, policy_id).to_dict() @sql.handle_conflicts(conflict_type='policy') def update_policy(self, policy_id, policy): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_policy(session, policy_id) old_dict = ref.to_dict() old_dict.update(policy) @@ -72,8 +66,6 @@ class Policy(rules.Policy): return ref.to_dict() def delete_policy(self, policy_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_policy(session, policy_id) session.delete(ref) diff --git a/keystone/resource/V8_backends/sql.py b/keystone/resource/V8_backends/sql.py index 513ec96777..6c9b7912b9 100644 --- a/keystone/resource/V8_backends/sql.py +++ b/keystone/resource/V8_backends/sql.py @@ -35,11 +35,11 @@ class Resource(keystone_resource.ResourceDriverV8): return project_ref def get_project(self, tenant_id): - with sql.transaction() as session: + with sql.session_for_read() as session: return self._get_project(session, tenant_id).to_dict() def get_project_by_name(self, tenant_name, domain_id): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project) query = query.filter_by(name=tenant_name) query = query.filter_by(domain_id=domain_id) @@ -51,7 +51,7 @@ class Resource(keystone_resource.ResourceDriverV8): @driver_hints.truncated def list_projects(self, hints): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project) project_refs = sql.filter_limit_query(Project, query, hints) return [project_ref.to_dict() for project_ref in project_refs] @@ -60,7 +60,7 @@ class Resource(keystone_resource.ResourceDriverV8): if not ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project) query = query.filter(Project.id.in_(ids)) return [project_ref.to_dict() for project_ref in query.all()] @@ -69,14 +69,14 @@ class Resource(keystone_resource.ResourceDriverV8): if not domain_ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project.id) query = ( query.filter(Project.domain_id.in_(domain_ids))) return [x.id for x in query.all()] def list_projects_in_domain(self, domain_id): - with sql.transaction() as session: + with sql.session_for_read() as session: self._get_domain(session, domain_id) query = session.query(Project) project_refs = query.filter_by(domain_id=domain_id) @@ -89,7 +89,7 @@ class Resource(keystone_resource.ResourceDriverV8): return [project_ref.to_dict() for project_ref in project_refs] def list_projects_in_subtree(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: children = self._get_children(session, [project_id]) subtree = [] examined = set([project_id]) @@ -110,7 +110,7 @@ class Resource(keystone_resource.ResourceDriverV8): return subtree def list_project_parents(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: project = self._get_project(session, project_id).to_dict() parents = [] examined = set() @@ -130,7 +130,7 @@ class Resource(keystone_resource.ResourceDriverV8): return parents def is_leaf_project(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: project_refs = self._get_children(session, [project_id]) return not project_refs @@ -138,7 +138,7 @@ class Resource(keystone_resource.ResourceDriverV8): @sql.handle_conflicts(conflict_type='project') def create_project(self, tenant_id, tenant): tenant['name'] = clean.project_name(tenant['name']) - with sql.transaction() as session: + with sql.session_for_write() as session: tenant_ref = Project.from_dict(tenant) session.add(tenant_ref) return tenant_ref.to_dict() @@ -148,7 +148,7 @@ class Resource(keystone_resource.ResourceDriverV8): if 'name' in tenant: tenant['name'] = clean.project_name(tenant['name']) - with sql.transaction() as session: + with sql.session_for_write() as session: tenant_ref = self._get_project(session, tenant_id) old_project_dict = tenant_ref.to_dict() for k in tenant: @@ -162,7 +162,7 @@ class Resource(keystone_resource.ResourceDriverV8): @sql.handle_conflicts(conflict_type='project') def delete_project(self, tenant_id): - with sql.transaction() as session: + with sql.session_for_write() as session: tenant_ref = self._get_project(session, tenant_id) session.delete(tenant_ref) @@ -170,14 +170,14 @@ class Resource(keystone_resource.ResourceDriverV8): @sql.handle_conflicts(conflict_type='domain') def create_domain(self, domain_id, domain): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = Domain.from_dict(domain) session.add(ref) return ref.to_dict() @driver_hints.truncated def list_domains(self, hints): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Domain) refs = sql.filter_limit_query(Domain, query, hints) return [ref.to_dict() for ref in refs] @@ -186,7 +186,7 @@ class Resource(keystone_resource.ResourceDriverV8): if not ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Domain) query = query.filter(Domain.id.in_(ids)) domain_refs = query.all() @@ -199,11 +199,11 @@ class Resource(keystone_resource.ResourceDriverV8): return ref def get_domain(self, domain_id): - with sql.transaction() as session: + with sql.session_for_read() as session: return self._get_domain(session, domain_id).to_dict() def get_domain_by_name(self, domain_name): - with sql.transaction() as session: + with sql.session_for_read() as session: try: ref = (session.query(Domain). filter_by(name=domain_name).one()) @@ -213,7 +213,7 @@ class Resource(keystone_resource.ResourceDriverV8): @sql.handle_conflicts(conflict_type='domain') def update_domain(self, domain_id, domain): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_domain(session, domain_id) old_dict = ref.to_dict() for k in domain: @@ -226,7 +226,7 @@ class Resource(keystone_resource.ResourceDriverV8): return ref.to_dict() def delete_domain(self, domain_id): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_domain(session, domain_id) session.delete(ref) diff --git a/keystone/resource/backends/sql.py b/keystone/resource/backends/sql.py index 0a1b46d36d..41d16fd13f 100644 --- a/keystone/resource/backends/sql.py +++ b/keystone/resource/backends/sql.py @@ -38,11 +38,11 @@ class Resource(keystone_resource.ResourceDriverV9): return project_ref def get_project(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: return self._get_project(session, project_id).to_dict() def get_project_by_name(self, project_name, domain_id): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project) query = query.filter_by(name=project_name) if domain_id is None: @@ -70,7 +70,7 @@ class Resource(keystone_resource.ResourceDriverV9): for f in hints.filters: if (f['name'] == 'domain_id' and f['value'] is None): f['value'] = keystone_resource.NULL_DOMAIN_ID - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project) project_refs = sql.filter_limit_query(Project, query, hints) return [project_ref.to_dict() for project_ref in project_refs @@ -80,7 +80,7 @@ class Resource(keystone_resource.ResourceDriverV9): if not ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project) query = query.filter(Project.id.in_(ids)) return [project_ref.to_dict() for project_ref in query.all() @@ -90,7 +90,7 @@ class Resource(keystone_resource.ResourceDriverV9): if not domain_ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Project.id) query = ( query.filter(Project.domain_id.in_(domain_ids))) @@ -98,7 +98,7 @@ class Resource(keystone_resource.ResourceDriverV9): if not self._is_hidden_ref(x)] def list_projects_in_domain(self, domain_id): - with sql.transaction() as session: + with sql.session_for_read() as session: self._get_domain(session, domain_id) query = session.query(Project) project_refs = query.filter_by(domain_id=domain_id) @@ -111,7 +111,7 @@ class Resource(keystone_resource.ResourceDriverV9): return [project_ref.to_dict() for project_ref in project_refs] def list_projects_in_subtree(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: children = self._get_children(session, [project_id]) subtree = [] examined = set([project_id]) @@ -132,7 +132,7 @@ class Resource(keystone_resource.ResourceDriverV9): return subtree def list_project_parents(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: project = self._get_project(session, project_id).to_dict() parents = [] examined = set() @@ -152,7 +152,7 @@ class Resource(keystone_resource.ResourceDriverV9): return parents def is_leaf_project(self, project_id): - with sql.transaction() as session: + with sql.session_for_read() as session: project_refs = self._get_children(session, [project_id]) return not project_refs @@ -161,7 +161,7 @@ class Resource(keystone_resource.ResourceDriverV9): def create_project(self, project_id, project): project['name'] = clean.project_name(project['name']) new_project = self._encode_domain_id(project) - with sql.transaction() as session: + with sql.session_for_write() as session: project_ref = Project.from_dict(new_project) session.add(project_ref) return project_ref.to_dict() @@ -172,7 +172,7 @@ class Resource(keystone_resource.ResourceDriverV9): project['name'] = clean.project_name(project['name']) update_project = self._encode_domain_id(project) - with sql.transaction() as session: + with sql.session_for_write() as session: project_ref = self._get_project(session, project_id) old_project_dict = project_ref.to_dict() for k in update_project: @@ -189,7 +189,7 @@ class Resource(keystone_resource.ResourceDriverV9): @sql.handle_conflicts(conflict_type='project') def delete_project(self, project_id): - with sql.transaction() as session: + with sql.session_for_write() as session: project_ref = self._get_project(session, project_id) session.delete(project_ref) @@ -197,7 +197,7 @@ class Resource(keystone_resource.ResourceDriverV9): def delete_projects_from_ids(self, project_ids): if not project_ids: return - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(Project).filter(Project.id.in_( project_ids)) project_ids_from_bd = [p['id'] for p in query.all()] @@ -212,14 +212,14 @@ class Resource(keystone_resource.ResourceDriverV9): @sql.handle_conflicts(conflict_type='domain') def create_domain(self, domain_id, domain): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = Domain.from_dict(domain) session.add(ref) - return ref.to_dict() + return ref.to_dict() @driver_hints.truncated def list_domains(self, hints): - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Domain) refs = sql.filter_limit_query(Domain, query, hints) return [ref.to_dict() for ref in refs @@ -229,7 +229,7 @@ class Resource(keystone_resource.ResourceDriverV9): if not ids: return [] else: - with sql.transaction() as session: + with sql.session_for_read() as session: query = session.query(Domain) query = query.filter(Domain.id.in_(ids)) domain_refs = query.all() @@ -243,11 +243,11 @@ class Resource(keystone_resource.ResourceDriverV9): return ref def get_domain(self, domain_id): - with sql.transaction() as session: + with sql.session_for_read() as session: return self._get_domain(session, domain_id).to_dict() def get_domain_by_name(self, domain_name): - with sql.transaction() as session: + with sql.session_for_read() as session: try: ref = (session.query(Domain). filter_by(name=domain_name).one()) @@ -260,7 +260,7 @@ class Resource(keystone_resource.ResourceDriverV9): @sql.handle_conflicts(conflict_type='domain') def update_domain(self, domain_id, domain): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_domain(session, domain_id) old_dict = ref.to_dict() for k in domain: @@ -273,7 +273,7 @@ class Resource(keystone_resource.ResourceDriverV9): return ref.to_dict() def delete_domain(self, domain_id): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_domain(session, domain_id) session.delete(ref) diff --git a/keystone/resource/config_backends/sql.py b/keystone/resource/config_backends/sql.py index 2f6091472e..6413becc7a 100644 --- a/keystone/resource/config_backends/sql.py +++ b/keystone/resource/config_backends/sql.py @@ -59,12 +59,12 @@ class DomainConfig(resource.DomainConfigDriverV8): @sql.handle_conflicts(conflict_type='domain_config') def create_config_option(self, domain_id, group, option, value, sensitive=False): - with sql.transaction() as session: + with sql.session_for_write() as session: config_table = self.choose_table(sensitive) ref = config_table(domain_id=domain_id, group=group, option=option, value=value) session.add(ref) - return ref.to_dict() + return ref.to_dict() def _get_config_option(self, session, domain_id, group, option, sensitive): try: @@ -80,14 +80,14 @@ class DomainConfig(resource.DomainConfigDriverV8): return ref def get_config_option(self, domain_id, group, option, sensitive=False): - with sql.transaction() as session: + with sql.session_for_read() as session: ref = self._get_config_option(session, domain_id, group, option, sensitive) - return ref.to_dict() + return ref.to_dict() def list_config_options(self, domain_id, group=None, option=None, sensitive=False): - with sql.transaction() as session: + with sql.session_for_read() as session: config_table = self.choose_table(sensitive) query = session.query(config_table) query = query.filter_by(domain_id=domain_id) @@ -99,11 +99,11 @@ class DomainConfig(resource.DomainConfigDriverV8): def update_config_option(self, domain_id, group, option, value, sensitive=False): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = self._get_config_option(session, domain_id, group, option, sensitive) ref.value = value - return ref.to_dict() + return ref.to_dict() def delete_config_options(self, domain_id, group=None, option=None, sensitive=False): @@ -114,7 +114,7 @@ class DomainConfig(resource.DomainConfigDriverV8): if there was nothing to delete. """ - with sql.transaction() as session: + with sql.session_for_write() as session: config_table = self.choose_table(sensitive) query = session.query(config_table) query = query.filter_by(domain_id=domain_id) @@ -126,7 +126,7 @@ class DomainConfig(resource.DomainConfigDriverV8): def obtain_registration(self, domain_id, type): try: - with sql.transaction() as session: + with sql.session_for_write() as session: ref = ConfigRegister(type=type, domain_id=domain_id) session.add(ref) return True @@ -136,15 +136,15 @@ class DomainConfig(resource.DomainConfigDriverV8): return False def read_registration(self, type): - with sql.transaction() as session: + with sql.session_for_read() as session: ref = session.query(ConfigRegister).get(type) if not ref: raise exception.ConfigRegistrationNotFound() - return ref.domain_id + return ref.domain_id def release_registration(self, domain_id, type=None): """Silently delete anything registered for the domain specified.""" - with sql.transaction() as session: + with sql.session_for_write() as session: query = session.query(ConfigRegister) if type: query = query.filter_by(type=type) diff --git a/keystone/revoke/backends/sql.py b/keystone/revoke/backends/sql.py index 5bf7d84ec3..67c7caab8d 100644 --- a/keystone/revoke/backends/sql.py +++ b/keystone/revoke/backends/sql.py @@ -60,37 +60,37 @@ class Revoke(revoke.RevokeDriverV8): def _prune_expired_events(self): oldest = revoke.revoked_before_cutoff_time() - session = sql.get_session() - dialect = session.bind.dialect.name - batch_size = self._flush_batch_size(dialect) - if batch_size > 0: - query = session.query(RevocationEvent.id) - query = query.filter(RevocationEvent.revoked_at < oldest) - query = query.limit(batch_size).subquery() - delete_query = (session.query(RevocationEvent). - filter(RevocationEvent.id.in_(query))) - while True: - rowcount = delete_query.delete(synchronize_session=False) - if rowcount == 0: - break - else: - query = session.query(RevocationEvent) - query = query.filter(RevocationEvent.revoked_at < oldest) - query.delete(synchronize_session=False) + with sql.session_for_write() as session: + dialect = session.bind.dialect.name + batch_size = self._flush_batch_size(dialect) + if batch_size > 0: + query = session.query(RevocationEvent.id) + query = query.filter(RevocationEvent.revoked_at < oldest) + query = query.limit(batch_size).subquery() + delete_query = (session.query(RevocationEvent). + filter(RevocationEvent.id.in_(query))) + while True: + rowcount = delete_query.delete(synchronize_session=False) + if rowcount == 0: + break + else: + query = session.query(RevocationEvent) + query = query.filter(RevocationEvent.revoked_at < oldest) + query.delete(synchronize_session=False) - session.flush() + session.flush() def list_events(self, last_fetch=None): - session = sql.get_session() - query = session.query(RevocationEvent).order_by( - RevocationEvent.revoked_at) + with sql.session_for_read() as session: + query = session.query(RevocationEvent).order_by( + RevocationEvent.revoked_at) - if last_fetch: - query = query.filter(RevocationEvent.revoked_at > last_fetch) + if last_fetch: + query = query.filter(RevocationEvent.revoked_at > last_fetch) - events = [model.RevokeEvent(**e.to_dict()) for e in query] + events = [model.RevokeEvent(**e.to_dict()) for e in query] - return events + return events def revoke(self, event): kwargs = dict() @@ -98,7 +98,6 @@ class Revoke(revoke.RevokeDriverV8): kwargs[attr] = getattr(event, attr) kwargs['id'] = uuid.uuid4().hex record = RevocationEvent(**kwargs) - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: session.add(record) - self._prune_expired_events() + self._prune_expired_events() diff --git a/keystone/tests/unit/identity_mapping.py b/keystone/tests/unit/identity_mapping.py index c8f6fd4696..4ba4f0c2b0 100644 --- a/keystone/tests/unit/identity_mapping.py +++ b/keystone/tests/unit/identity_mapping.py @@ -17,6 +17,6 @@ from keystone.identity.mapping_backends import sql as mapping_sql def list_id_mappings(): """List all id_mappings for testing purposes.""" - a_session = sql.get_session() - refs = a_session.query(mapping_sql.IDMapping).all() - return [x.to_dict() for x in refs] + with sql.session_for_read() as session: + refs = session.query(mapping_sql.IDMapping).all() + return [x.to_dict() for x in refs] diff --git a/keystone/tests/unit/test_backend_sql.py b/keystone/tests/unit/test_backend_sql.py index c5d6ca1500..370dad90d9 100644 --- a/keystone/tests/unit/test_backend_sql.py +++ b/keystone/tests/unit/test_backend_sql.py @@ -611,7 +611,7 @@ class SqlToken(SqlTests, test_backend.TokenTests): tok = token_sql.Token() tok.list_revoked_tokens() - mock_query = mock_sql.get_session().query + mock_query = mock_sql.session_for_read().__enter__().query mock_query.assert_called_with(*expected_query_args) def test_flush_expired_tokens_batch(self): @@ -636,8 +636,12 @@ class SqlToken(SqlTests, test_backend.TokenTests): # other tests below test the differences between how they use the batch # strategy with mock.patch.object(token_sql, 'sql') as mock_sql: - mock_sql.get_session().query().filter().delete.return_value = 0 - mock_sql.get_session().bind.dialect.name = 'mysql' + mock_sql.session_for_write().__enter__( + ).query().filter().delete.return_value = 0 + + mock_sql.session_for_write().__enter__( + ).bind.dialect.name = 'mysql' + tok = token_sql.Token() expiry_mock = mock.Mock() ITERS = [1, 2, 3] @@ -648,7 +652,10 @@ class SqlToken(SqlTests, test_backend.TokenTests): # The expiry strategy is only invoked once, the other calls are via # the yield return. self.assertEqual(1, expiry_mock.call_count) - mock_delete = mock_sql.get_session().query().filter().delete + + mock_delete = mock_sql.session_for_write().__enter__( + ).query().filter().delete + self.assertThat(mock_delete.call_args_list, matchers.HasLength(len(ITERS))) diff --git a/keystone/tests/unit/test_sql_upgrade.py b/keystone/tests/unit/test_sql_upgrade.py index 790c0460bb..8353237161 100644 --- a/keystone/tests/unit/test_sql_upgrade.py +++ b/keystone/tests/unit/test_sql_upgrade.py @@ -161,7 +161,8 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase): self.repo_package()) self.schema = versioning_api.ControlledSchema.create( self.engine, - self.repo_path, self.initial_db_version) + self.repo_path, + self.initial_db_version) # auto-detect the highest available schema version in the migrate_repo self.max_version = self.schema.repository.version().version diff --git a/keystone/token/persistence/backends/sql.py b/keystone/token/persistence/backends/sql.py index ebd324f914..4b3439a1c3 100644 --- a/keystone/token/persistence/backends/sql.py +++ b/keystone/token/persistence/backends/sql.py @@ -86,11 +86,11 @@ class Token(token.persistence.TokenDriverV8): def get_token(self, token_id): if token_id is None: raise exception.TokenNotFound(token_id=token_id) - session = sql.get_session() - token_ref = session.query(TokenModel).get(token_id) - if not token_ref or not token_ref.valid: - raise exception.TokenNotFound(token_id=token_id) - return token_ref.to_dict() + with sql.session_for_read() as session: + token_ref = session.query(TokenModel).get(token_id) + if not token_ref or not token_ref.valid: + raise exception.TokenNotFound(token_id=token_id) + return token_ref.to_dict() def create_token(self, token_id, data): data_copy = copy.deepcopy(data) @@ -101,14 +101,12 @@ class Token(token.persistence.TokenDriverV8): token_ref = TokenModel.from_dict(data_copy) token_ref.valid = True - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: session.add(token_ref) return token_ref.to_dict() def delete_token(self, token_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: token_ref = session.query(TokenModel).get(token_id) if not token_ref or not token_ref.valid: raise exception.TokenNotFound(token_id=token_id) @@ -124,9 +122,8 @@ class Token(token.persistence.TokenDriverV8): or the trustor's user ID, so will use trust_id to query the tokens. """ - session = sql.get_session() token_list = [] - with session.begin(): + with sql.session_for_write() as session: now = timeutils.utcnow() query = session.query(TokenModel) query = query.filter_by(valid=True) @@ -167,38 +164,37 @@ class Token(token.persistence.TokenDriverV8): return False def _list_tokens_for_trust(self, trust_id): - session = sql.get_session() - tokens = [] - now = timeutils.utcnow() - query = session.query(TokenModel) - query = query.filter(TokenModel.expires > now) - query = query.filter(TokenModel.trust_id == trust_id) + with sql.session_for_read() as session: + tokens = [] + now = timeutils.utcnow() + query = session.query(TokenModel) + query = query.filter(TokenModel.expires > now) + query = query.filter(TokenModel.trust_id == trust_id) - token_references = query.filter_by(valid=True) - for token_ref in token_references: - token_ref_dict = token_ref.to_dict() - tokens.append(token_ref_dict['id']) - return tokens + token_references = query.filter_by(valid=True) + for token_ref in token_references: + token_ref_dict = token_ref.to_dict() + tokens.append(token_ref_dict['id']) + return tokens def _list_tokens_for_user(self, user_id, tenant_id=None): - session = sql.get_session() - tokens = [] - now = timeutils.utcnow() - query = session.query(TokenModel) - query = query.filter(TokenModel.expires > now) - query = query.filter(TokenModel.user_id == user_id) + with sql.session_for_read() as session: + tokens = [] + now = timeutils.utcnow() + query = session.query(TokenModel) + query = query.filter(TokenModel.expires > now) + query = query.filter(TokenModel.user_id == user_id) - token_references = query.filter_by(valid=True) - for token_ref in token_references: - token_ref_dict = token_ref.to_dict() - if self._tenant_matches(tenant_id, token_ref_dict): - tokens.append(token_ref['id']) - return tokens + token_references = query.filter_by(valid=True) + for token_ref in token_references: + token_ref_dict = token_ref.to_dict() + if self._tenant_matches(tenant_id, token_ref_dict): + tokens.append(token_ref['id']) + return tokens def _list_tokens_for_consumer(self, user_id, consumer_id): tokens = [] - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: now = timeutils.utcnow() query = session.query(TokenModel) query = query.filter(TokenModel.expires > now) @@ -223,29 +219,29 @@ class Token(token.persistence.TokenDriverV8): return self._list_tokens_for_user(user_id, tenant_id) def list_revoked_tokens(self): - session = sql.get_session() - tokens = [] - now = timeutils.utcnow() - query = session.query(TokenModel.id, TokenModel.expires, - TokenModel.extra) - query = query.filter(TokenModel.expires > now) - token_references = query.filter_by(valid=False) - for token_ref in token_references: - token_data = token_ref[2]['token_data'] - if 'access' in token_data: - # It's a v2 token. - audit_ids = token_data['access']['token']['audit_ids'] - else: - # It's a v3 token. - audit_ids = token_data['token']['audit_ids'] + with sql.session_for_read() as session: + tokens = [] + now = timeutils.utcnow() + query = session.query(TokenModel.id, TokenModel.expires, + TokenModel.extra) + query = query.filter(TokenModel.expires > now) + token_references = query.filter_by(valid=False) + for token_ref in token_references: + token_data = token_ref[2]['token_data'] + if 'access' in token_data: + # It's a v2 token. + audit_ids = token_data['access']['token']['audit_ids'] + else: + # It's a v3 token. + audit_ids = token_data['token']['audit_ids'] - record = { - 'id': token_ref[0], - 'expires': token_ref[1], - 'audit_id': audit_ids[0], - } - tokens.append(record) - return tokens + record = { + 'id': token_ref[0], + 'expires': token_ref[1], + 'audit_id': audit_ids[0], + } + tokens.append(record) + return tokens def _expiry_range_strategy(self, dialect): """Choose a token range expiration strategy @@ -273,18 +269,18 @@ class Token(token.persistence.TokenDriverV8): return _expiry_range_all def flush_expired_tokens(self): - session = sql.get_session() - dialect = session.bind.dialect.name - expiry_range_func = self._expiry_range_strategy(dialect) - query = session.query(TokenModel.expires) - total_removed = 0 - upper_bound_func = timeutils.utcnow - for expiry_time in expiry_range_func(session, upper_bound_func): - delete_query = query.filter(TokenModel.expires <= - expiry_time) - row_count = delete_query.delete(synchronize_session=False) - total_removed += row_count - LOG.debug('Removed %d total expired tokens', total_removed) + with sql.session_for_write() as session: + dialect = session.bind.dialect.name + expiry_range_func = self._expiry_range_strategy(dialect) + query = session.query(TokenModel.expires) + total_removed = 0 + upper_bound_func = timeutils.utcnow + for expiry_time in expiry_range_func(session, upper_bound_func): + delete_query = query.filter(TokenModel.expires <= + expiry_time) + row_count = delete_query.delete(synchronize_session=False) + total_removed += row_count + LOG.debug('Removed %d total expired tokens', total_removed) - session.flush() - LOG.info(_LI('Total expired tokens removed: %d'), total_removed) + session.flush() + LOG.info(_LI('Total expired tokens removed: %d'), total_removed) diff --git a/keystone/trust/backends/sql.py b/keystone/trust/backends/sql.py index e3c8822116..cb8446b3f2 100644 --- a/keystone/trust/backends/sql.py +++ b/keystone/trust/backends/sql.py @@ -59,7 +59,7 @@ class TrustRole(sql.ModelBase): class Trust(trust.TrustDriverV8): @sql.handle_conflicts(conflict_type='trust') def create_trust(self, trust_id, trust, roles): - with sql.transaction() as session: + with sql.session_for_write() as session: ref = TrustModel.from_dict(trust) ref['id'] = trust_id if ref.get('expires_at') and ref['expires_at'].tzinfo is not None: @@ -72,9 +72,9 @@ class Trust(trust.TrustDriverV8): trust_role.role_id = role['id'] added_roles.append({'id': role['id']}) session.add(trust_role) - trust_dict = ref.to_dict() - trust_dict['roles'] = added_roles - return trust_dict + trust_dict = ref.to_dict() + trust_dict['roles'] = added_roles + return trust_dict def _add_roles(self, trust_id, session, trust_dict): roles = [] @@ -86,7 +86,7 @@ class Trust(trust.TrustDriverV8): def consume_use(self, trust_id): for attempt in range(MAXIMUM_CONSUME_ATTEMPTS): - with sql.transaction() as session: + with sql.session_for_write() as session: try: query_result = (session.query(TrustModel.remaining_uses). filter_by(id=trust_id). @@ -132,51 +132,51 @@ class Trust(trust.TrustDriverV8): raise exception.TrustConsumeMaximumAttempt(trust_id=trust_id) def get_trust(self, trust_id, deleted=False): - session = sql.get_session() - query = session.query(TrustModel).filter_by(id=trust_id) - if not deleted: - query = query.filter_by(deleted_at=None) - ref = query.first() - if ref is None: - raise exception.TrustNotFound(trust_id=trust_id) - if ref.expires_at is not None and not deleted: - now = timeutils.utcnow() - if now > ref.expires_at: + with sql.session_for_read() as session: + query = session.query(TrustModel).filter_by(id=trust_id) + if not deleted: + query = query.filter_by(deleted_at=None) + ref = query.first() + if ref is None: raise exception.TrustNotFound(trust_id=trust_id) - # Do not return trusts that can't be used anymore - if ref.remaining_uses is not None and not deleted: - if ref.remaining_uses <= 0: - raise exception.TrustNotFound(trust_id=trust_id) - trust_dict = ref.to_dict() + if ref.expires_at is not None and not deleted: + now = timeutils.utcnow() + if now > ref.expires_at: + raise exception.TrustNotFound(trust_id=trust_id) + # Do not return trusts that can't be used anymore + if ref.remaining_uses is not None and not deleted: + if ref.remaining_uses <= 0: + raise exception.TrustNotFound(trust_id=trust_id) + trust_dict = ref.to_dict() - self._add_roles(trust_id, session, trust_dict) - return trust_dict + self._add_roles(trust_id, session, trust_dict) + return trust_dict @sql.handle_conflicts(conflict_type='trust') def list_trusts(self): - session = sql.get_session() - trusts = session.query(TrustModel).filter_by(deleted_at=None) - return [trust_ref.to_dict() for trust_ref in trusts] + with sql.session_for_read() as session: + trusts = session.query(TrustModel).filter_by(deleted_at=None) + return [trust_ref.to_dict() for trust_ref in trusts] @sql.handle_conflicts(conflict_type='trust') def list_trusts_for_trustee(self, trustee_user_id): - session = sql.get_session() - trusts = (session.query(TrustModel). - filter_by(deleted_at=None). - filter_by(trustee_user_id=trustee_user_id)) - return [trust_ref.to_dict() for trust_ref in trusts] + with sql.session_for_read() as session: + trusts = (session.query(TrustModel). + filter_by(deleted_at=None). + filter_by(trustee_user_id=trustee_user_id)) + return [trust_ref.to_dict() for trust_ref in trusts] @sql.handle_conflicts(conflict_type='trust') def list_trusts_for_trustor(self, trustor_user_id): - session = sql.get_session() - trusts = (session.query(TrustModel). - filter_by(deleted_at=None). - filter_by(trustor_user_id=trustor_user_id)) - return [trust_ref.to_dict() for trust_ref in trusts] + with sql.session_for_read() as session: + trusts = (session.query(TrustModel). + filter_by(deleted_at=None). + filter_by(trustor_user_id=trustor_user_id)) + return [trust_ref.to_dict() for trust_ref in trusts] @sql.handle_conflicts(conflict_type='trust') def delete_trust(self, trust_id): - with sql.transaction() as session: + with sql.session_for_write() as session: trust_ref = session.query(TrustModel).get(trust_id) if not trust_ref: raise exception.TrustNotFound(trust_id=trust_id)