diff --git a/manila/db/sqlalchemy/api.py b/manila/db/sqlalchemy/api.py index 4ab2438951..745e2576b2 100644 --- a/manila/db/sqlalchemy/api.py +++ b/manila/db/sqlalchemy/api.py @@ -4339,12 +4339,9 @@ def export_location_metadata_update(context, export_location_uuid, metadata, ################################### -# TODO(stephenfin): Remove the 'session' argument once all callers have been -# converted -def _security_service_get_query(context, project_only=False, session=None): +def _security_service_get_query(context, project_only=False): return model_query( context, models.SecurityService, project_only=project_only, - session=session, ) @@ -4383,13 +4380,10 @@ def security_service_get(context, id, **kwargs): return _security_service_get(context, id, **kwargs) -# TODO(stephenfin): Remove the 'session' argument once all callers have been -# converted @require_context def _security_service_get(context, id, session=None, **kwargs): result = _security_service_get_query( context, - session=session, **kwargs, ).filter_by(id=id).first() if result is None: @@ -4428,211 +4422,212 @@ def security_service_get_all_by_share_network(context, share_network_id): ################### -def _network_get_query(context, session=None): - if session is None: - session = get_session() - return (model_query(context, models.ShareNetwork, session=session, - project_only=True). - options(joinedload('share_instances'), - joinedload('security_services'), - subqueryload('share_network_subnets'))) +def _share_network_get_query(context): + return model_query( + context, models.ShareNetwork, project_only=True, + ).options( + joinedload('share_instances'), + joinedload('security_services'), + subqueryload('share_network_subnets'), + ) @require_context +@context_manager.writer def share_network_create(context, values): values = ensure_model_dict_has_id(values) network_ref = models.ShareNetwork() network_ref.update(values) - session = get_session() - with session.begin(): - network_ref.save(session=session) - return share_network_get(context, values['id'], session) + network_ref.save(session=context.session) + return _share_network_get(context, values['id']) @require_context +@context_manager.writer def share_network_delete(context, id): - session = get_session() - with session.begin(): - network_ref = share_network_get(context, id, session=session) - network_ref.soft_delete(session) + network_ref = _share_network_get(context, id) + network_ref.soft_delete(session=context.session) @require_context +@context_manager.writer def share_network_update(context, id, values): - session = get_session() - with session.begin(): - network_ref = share_network_get(context, id, session=session) - network_ref.update(values) - network_ref.save(session=session) - return network_ref + network_ref = _share_network_get(context, id) + network_ref.update(values) + network_ref.save(session=context.session) + return network_ref @require_context -def share_network_get(context, id, session=None): - result = _network_get_query(context, session).filter_by(id=id).first() +@context_manager.reader +def share_network_get(context, id): + return _share_network_get(context, id) + + +@require_context +def _share_network_get(context, id): + result = _share_network_get_query(context).filter_by(id=id).first() if result is None: raise exception.ShareNetworkNotFound(share_network_id=id) return result @require_context +@context_manager.reader def share_network_get_all_by_filter(context, filters=None): - model_sn = models.ShareNetwork - session = get_session() - with session.begin(): - query = _network_get_query(context, - session=session) + query = _share_network_get_query(context) - legal_filter_keys = ('project_id', 'created_since', 'created_before') + legal_filter_keys = ('project_id', 'created_since', 'created_before') - if not filters: - filters = {} + if not filters: + filters = {} - query = exact_filter(query, model_sn, filters, legal_filter_keys) + query = exact_filter( + query, models.ShareNetwork, filters, legal_filter_keys, + ) + if 'security_service_id' in filters: + security_service_id = filters.get('security_service_id') + query = query.join( + models.ShareNetworkSecurityServiceAssociation, + models.ShareNetwork.id == models.ShareNetworkSecurityServiceAssociation.share_network_id, # noqa: E501 + ).filter_by( + security_service_id=security_service_id, + deleted=0, + ) - if 'security_service_id' in filters: - security_service_id = filters.get('security_service_id') - query = query.join( - models.ShareNetworkSecurityServiceAssociation, - models.ShareNetwork.id == models. - ShareNetworkSecurityServiceAssociation. - share_network_id).filter_by( - security_service_id=security_service_id, deleted=0) - - return query.all() + return query.all() @require_context +@context_manager.reader def share_network_get_all(context): - return _network_get_query(context).all() + return _share_network_get_query(context).all() @require_context +@context_manager.reader def share_network_get_all_by_project(context, project_id): - return _network_get_query(context).filter_by(project_id=project_id).all() + return _share_network_get_query( + context, + ).filter_by(project_id=project_id).all() @require_context +@context_manager.reader def share_network_get_all_by_security_service(context, security_service_id): - session = get_session() - return (model_query(context, models.ShareNetwork, session=session). - join(models.ShareNetworkSecurityServiceAssociation, - models.ShareNetwork.id == - models.ShareNetworkSecurityServiceAssociation.share_network_id). - filter_by(security_service_id=security_service_id, deleted=0) - .all()) + return model_query( + context, models.ShareNetwork, + ).join( + models.ShareNetworkSecurityServiceAssociation, + models.ShareNetwork.id == + models.ShareNetworkSecurityServiceAssociation.share_network_id, + ).filter_by(security_service_id=security_service_id, deleted=0).all() @require_context +@context_manager.writer def share_network_add_security_service(context, id, security_service_id): - session = get_session() + assoc_ref = model_query( + context, + models.ShareNetworkSecurityServiceAssociation, + ).filter_by( + share_network_id=id, + ).filter_by(security_service_id=security_service_id).first() - with session.begin(): - assoc_ref = (model_query( - context, - models.ShareNetworkSecurityServiceAssociation, - session=session). - filter_by(share_network_id=id). - filter_by( - security_service_id=security_service_id).first()) - - if assoc_ref: - msg = "Already associated" - raise exception.ShareNetworkSecurityServiceAssociationError( - share_network_id=id, - security_service_id=security_service_id, - reason=msg) - - share_nw_ref = share_network_get(context, id, session=session) - security_service_ref = _security_service_get( - context, security_service_id, session=session, + if assoc_ref: + msg = "Already associated" + raise exception.ShareNetworkSecurityServiceAssociationError( + share_network_id=id, + security_service_id=security_service_id, + reason=msg, ) - share_nw_ref.security_services += [security_service_ref] - share_nw_ref.save(session=session) + + share_nw_ref = _share_network_get(context, id) + security_service_ref = _security_service_get(context, security_service_id) + share_nw_ref.security_services += [security_service_ref] + share_nw_ref.save(session=context.session) return share_nw_ref @require_context +@context_manager.reader def share_network_security_service_association_get( - context, share_network_id, security_service_id): - session = get_session() - - with session.begin(): - association = (model_query( - context, - models.ShareNetworkSecurityServiceAssociation, - session=session).filter_by( - share_network_id=share_network_id).filter_by( - security_service_id=security_service_id).first()) - return association + context, share_network_id, security_service_id, +): + association = model_query( + context, + models.ShareNetworkSecurityServiceAssociation, + ).filter_by( + share_network_id=share_network_id, + ).filter_by( + security_service_id=security_service_id, + ).first() + return association @require_context +@context_manager.writer def share_network_remove_security_service(context, id, security_service_id): - session = get_session() + share_nw_ref = _share_network_get(context, id) + _security_service_get(context, security_service_id) - with session.begin(): - share_nw_ref = share_network_get(context, id, session=session) - _security_service_get(context, security_service_id, session=session) + assoc_ref = model_query( + context, + models.ShareNetworkSecurityServiceAssociation, + ).filter_by( + share_network_id=id, + ).filter_by(security_service_id=security_service_id).first() - assoc_ref = (model_query( - context, - models.ShareNetworkSecurityServiceAssociation, - session=session). - filter_by(share_network_id=id). - filter_by(security_service_id=security_service_id).first()) - - if assoc_ref: - assoc_ref.soft_delete(session) - else: - msg = "No association defined" - raise exception.ShareNetworkSecurityServiceDissociationError( - share_network_id=id, - security_service_id=security_service_id, - reason=msg) + if assoc_ref: + assoc_ref.soft_delete(session=context.session) + else: + msg = "No association defined" + raise exception.ShareNetworkSecurityServiceDissociationError( + share_network_id=id, + security_service_id=security_service_id, + reason=msg, + ) return share_nw_ref @require_context -def share_network_update_security_service(context, id, - current_security_service_id, - new_security_service_id): - session = get_session() +@context_manager.writer +def share_network_update_security_service( + context, id, current_security_service_id, new_security_service_id, +): + share_nw_ref = _share_network_get(context, id) + # Check if the old security service exists + _security_service_get(context, current_security_service_id) + new_security_service_ref = _security_service_get( + context, new_security_service_id, + ) - with session.begin(): - share_nw_ref = share_network_get(context, id, session=session) - # Check if the old security service exists - _security_service_get( - context, current_security_service_id, session=session, - ) - new_security_service_ref = _security_service_get( - context, new_security_service_id, session=session, - ) + assoc_ref = model_query( + context, + models.ShareNetworkSecurityServiceAssociation, + ).filter_by( + share_network_id=id, + ).filter_by( + security_service_id=current_security_service_id, + ).first() - assoc_ref = (model_query( - context, - models.ShareNetworkSecurityServiceAssociation, - session=session).filter_by( - share_network_id=id).filter_by( - security_service_id=current_security_service_id).first()) + if assoc_ref: + assoc_ref.soft_delete(session=context.session) + else: + msg = "No association defined" + raise exception.ShareNetworkSecurityServiceDissociationError( + share_network_id=id, + security_service_id=current_security_service_id, + reason=msg) - if assoc_ref: - assoc_ref.soft_delete(session) - else: - msg = "No association defined" - raise exception.ShareNetworkSecurityServiceDissociationError( - share_network_id=id, - security_service_id=current_security_service_id, - reason=msg) + # Add new association + share_nw_ref.security_services += [new_security_service_ref] + share_nw_ref.save(session=context.session) - # Add new association - share_nw_ref.security_services += [new_security_service_ref] - share_nw_ref.save(session=session) - - return share_nw_ref + return share_nw_ref @require_context