diff --git a/bin/nova-manage b/bin/nova-manage index 46374673fbea..d0246ad9d343 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -1379,7 +1379,7 @@ class VsaCommands(object): raise is_admin = self.manager.is_admin(user_id) - ctxt = context.RequestContext(user_id, project_id, is_admin) + ctxt = context.RequestContext(user_id, project_id, is_admin=is_admin) if not is_admin and \ not self.manager.is_project_member(user_id, project_id): msg = _("%(user_id)s must be an admin or a " diff --git a/nova/context.py b/nova/context.py index 36d15ba08737..cd07ca629f55 100644 --- a/nova/context.py +++ b/nova/context.py @@ -19,6 +19,7 @@ """RequestContext: context for requests that persist through all of nova.""" +import copy import uuid from nova import local @@ -32,9 +33,14 @@ class RequestContext(object): """ - def __init__(self, user_id, project_id, is_admin=None, read_deleted=False, + def __init__(self, user_id, project_id, is_admin=None, read_deleted="no", roles=None, remote_address=None, timestamp=None, request_id=None, auth_token=None, strategy='noauth'): + """ + :param read_deleted: 'no' indicates deleted records are hidden, 'yes' + indicates deleted records are visible, 'only' indicates that + *only* deleted records are visible. + """ self.user_id = user_id self.project_id = project_id self.roles = roles or [] @@ -73,18 +79,17 @@ class RequestContext(object): def elevated(self, read_deleted=None): """Return a version of this context with admin flag set.""" - rd = self.read_deleted if read_deleted is None else read_deleted - return RequestContext(user_id=self.user_id, - project_id=self.project_id, - is_admin=True, - read_deleted=rd, - roles=self.roles, - remote_address=self.remote_address, - timestamp=self.timestamp, - request_id=self.request_id, - auth_token=self.auth_token, - strategy=self.strategy) + context = copy.copy(self) + context.is_admin = True + + if read_deleted is not None: + context.read_deleted = read_deleted + + return context -def get_admin_context(read_deleted=False): - return RequestContext(None, None, True, read_deleted) +def get_admin_context(read_deleted="no"): + return RequestContext(user_id=None, + project_id=None, + is_admin=True, + read_deleted=read_deleted) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 8cd8daf1c3e6..a8424686ff8d 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -82,13 +82,6 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() -def can_read_deleted(context): - """Indicates if the context has access to deleted objects.""" - if not context: - return False - return context.read_deleted - - def require_admin_context(f): """Decorator to require admin request context. @@ -149,6 +142,37 @@ def require_volume_exists(f): return wrapper +def model_query(context, *args, **kwargs): + """Query helper that accounts for context's `read_deleted` field. + + :param context: context to query under + :param session: if present, the session to use + :param read_deleted: if present, overrides context's read_deleted field. + :param project_only: if present and context is user-type, then restrict + query to match the context's project_id. + """ + session = kwargs.get('session') or get_session() + read_deleted = kwargs.get('read_deleted') or context.read_deleted + project_only = kwargs.get('project_only') + + query = session.query(*args) + + if read_deleted == 'no': + query = query.filter_by(deleted=False) + elif read_deleted == 'yes': + pass # omit the filter to include deleted and active + elif read_deleted == 'only': + query = query.filter_by(deleted=True) + else: + raise Exception( + _("Unrecognized read_deleted value '%s'") % read_deleted) + + if project_only and is_user_context(context): + query = query.filter_by(project_id=context.project_id) + + return query + + ################### @@ -167,15 +191,10 @@ def service_destroy(context, service_id): @require_admin_context def service_get(context, service_id, session=None): - if not session: - session = get_session() - - result = session.query(models.Service).\ + result = model_query(context, models.Service, session=session).\ options(joinedload('compute_node')).\ filter_by(id=service_id).\ - filter_by(deleted=can_read_deleted(context)).\ first() - if not result: raise exception.ServiceNotFound(service_id=service_id) @@ -184,9 +203,7 @@ def service_get(context, service_id, session=None): @require_admin_context def service_get_all(context, disabled=None): - session = get_session() - query = session.query(models.Service).\ - filter_by(deleted=can_read_deleted(context)) + query = model_query(context, models.Service) if disabled is not None: query = query.filter_by(disabled=disabled) @@ -196,44 +213,35 @@ def service_get_all(context, disabled=None): @require_admin_context def service_get_all_by_topic(context, topic): - session = get_session() - return session.query(models.Service).\ - filter_by(deleted=False).\ - filter_by(disabled=False).\ - filter_by(topic=topic).\ - all() + return model_query(context, models.Service, read_deleted="no").\ + filter_by(disabled=False).\ + filter_by(topic=topic).\ + all() @require_admin_context def service_get_by_host_and_topic(context, host, topic): - session = get_session() - return session.query(models.Service).\ - filter_by(deleted=False).\ - filter_by(disabled=False).\ - filter_by(host=host).\ - filter_by(topic=topic).\ - first() + return model_query(context, models.Service, read_deleted="no").\ + filter_by(disabled=False).\ + filter_by(host=host).\ + filter_by(topic=topic).\ + first() @require_admin_context def service_get_all_by_host(context, host): - session = get_session() - return session.query(models.Service).\ - filter_by(deleted=False).\ - filter_by(host=host).\ - all() + return model_query(context, models.Service, read_deleted="no").\ + filter_by(host=host).\ + all() @require_admin_context def service_get_all_compute_by_host(context, host): - topic = 'compute' - session = get_session() - result = session.query(models.Service).\ - options(joinedload('compute_node')).\ - filter_by(deleted=False).\ - filter_by(host=host).\ - filter_by(topic=topic).\ - all() + result = model_query(context, models.Service, read_deleted="no").\ + options(joinedload('compute_node')).\ + filter_by(host=host).\ + filter_by(topic="compute").\ + all() if not result: raise exception.ComputeHostNotFound(host=host) @@ -244,13 +252,14 @@ def service_get_all_compute_by_host(context, host): @require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): sort_value = getattr(subq.c, label) - return session.query(models.Service, func.coalesce(sort_value, 0)).\ - filter_by(topic=topic).\ - filter_by(deleted=False).\ - filter_by(disabled=False).\ - outerjoin((subq, models.Service.host == subq.c.host)).\ - order_by(sort_value).\ - all() + return model_query(context, models.Service, + func.coalesce(sort_value, 0), + session=session, read_deleted="no").\ + filter_by(topic=topic).\ + filter_by(disabled=False).\ + outerjoin((subq, models.Service.host == subq.c.host)).\ + order_by(sort_value).\ + all() @require_admin_context @@ -266,9 +275,9 @@ def service_get_all_compute_sorted(context): # ON services.host = inst_cores.host topic = 'compute' label = 'instance_cores' - subq = session.query(models.Instance.host, - func.sum(models.Instance.vcpus).label(label)).\ - filter_by(deleted=False).\ + subq = model_query(context, models.Instance.host, + func.sum(models.Instance.vcpus).label(label), + session=session, read_deleted="no").\ group_by(models.Instance.host).\ subquery() return _service_get_all_topic_subquery(context, @@ -284,9 +293,9 @@ def service_get_all_network_sorted(context): with session.begin(): topic = 'network' label = 'network_count' - subq = session.query(models.Network.host, - func.count(models.Network.id).label(label)).\ - filter_by(deleted=False).\ + subq = model_query(context, models.Network.host, + func.count(models.Network.id).label(label), + session=session, read_deleted="no").\ group_by(models.Network.host).\ subquery() return _service_get_all_topic_subquery(context, @@ -302,9 +311,9 @@ def service_get_all_volume_sorted(context): with session.begin(): topic = 'volume' label = 'volume_gigabytes' - subq = session.query(models.Volume.host, - func.sum(models.Volume.size).label(label)).\ - filter_by(deleted=False).\ + subq = model_query(context, models.Volume.host, + func.sum(models.Volume.size).label(label), + session=session, read_deleted="no").\ group_by(models.Volume.host).\ subquery() return _service_get_all_topic_subquery(context, @@ -316,12 +325,11 @@ def service_get_all_volume_sorted(context): @require_admin_context def service_get_by_args(context, host, binary): - session = get_session() - result = session.query(models.Service).\ + result = model_query(context, models.Service).\ filter_by(host=host).\ filter_by(binary=binary).\ - filter_by(deleted=can_read_deleted(context)).\ first() + if not result: raise exception.HostBinaryNotFound(host=host, binary=binary) @@ -352,12 +360,8 @@ def service_update(context, service_id, values): @require_admin_context def compute_node_get(context, compute_id, session=None): - if not session: - session = get_session() - - result = session.query(models.ComputeNode).\ + result = model_query(context, models.ComputeNode, session=session).\ filter_by(id=compute_id).\ - filter_by(deleted=can_read_deleted(context)).\ first() if not result: @@ -368,12 +372,9 @@ def compute_node_get(context, compute_id, session=None): @require_admin_context def compute_node_get_all(context, session=None): - if not session: - session = get_session() - - return session.query(models.ComputeNode).\ + return model_query(context, models.ComputeNode, session=session).\ options(joinedload('service')).\ - filter_by(deleted=can_read_deleted(context)) + all() @require_admin_context @@ -398,12 +399,8 @@ def compute_node_update(context, compute_id, values): @require_admin_context def certificate_get(context, certificate_id, session=None): - if not session: - session = get_session() - - result = session.query(models.Certificate).\ + result = model_query(context, models.Certificate, session=session).\ filter_by(id=certificate_id).\ - filter_by(deleted=can_read_deleted(context)).\ first() if not result: @@ -433,29 +430,23 @@ def certificate_destroy(context, certificate_id): @require_admin_context def certificate_get_all_by_project(context, project_id): - session = get_session() - return session.query(models.Certificate).\ + return model_query(context, models.Certificate, read_deleted="no").\ filter_by(project_id=project_id).\ - filter_by(deleted=False).\ all() @require_admin_context def certificate_get_all_by_user(context, user_id): - session = get_session() - return session.query(models.Certificate).\ + return model_query(context, models.Certificate, read_deleted="no").\ filter_by(user_id=user_id).\ - filter_by(deleted=False).\ all() @require_admin_context -def certificate_get_all_by_user_and_project(_context, user_id, project_id): - session = get_session() - return session.query(models.Certificate).\ +def certificate_get_all_by_user_and_project(context, user_id, project_id): + return model_query(context, models.Certificate, read_deleted="no").\ filter_by(user_id=user_id).\ filter_by(project_id=project_id).\ - filter_by(deleted=False).\ all() @@ -476,23 +467,12 @@ def certificate_update(context, certificate_id, values): @require_context def floating_ip_get(context, id): - session = get_session() - result = None - if is_admin_context(context): - result = session.query(models.FloatingIp).\ - options(joinedload('fixed_ip')).\ - options(joinedload_all('fixed_ip.instance')).\ - filter_by(id=id).\ - filter_by(deleted=can_read_deleted(context)).\ - first() - elif is_user_context(context): - result = session.query(models.FloatingIp).\ - options(joinedload('fixed_ip')).\ - options(joinedload_all('fixed_ip.instance')).\ - filter_by(project_id=context.project_id).\ - filter_by(id=id).\ - filter_by(deleted=False).\ - first() + result = model_query(context, models.FloatingIp, project_only=True).\ + options(joinedload('fixed_ip')).\ + options(joinedload_all('fixed_ip.instance')).\ + filter_by(id=id).\ + first() + if not result: raise exception.FloatingIpNotFound(id=id) @@ -504,10 +484,10 @@ def floating_ip_allocate_address(context, project_id): authorize_project_context(context, project_id) session = get_session() with session.begin(): - floating_ip_ref = session.query(models.FloatingIp).\ + floating_ip_ref = model_query(context, models.FloatingIp, + session=session, read_deleted="no").\ filter_by(fixed_ip_id=None).\ filter_by(project_id=None).\ - filter_by(deleted=False).\ with_lockmode('update').\ first() # NOTE(vish): if with_lockmode isn't supported, as in sqlite, @@ -530,12 +510,10 @@ def floating_ip_create(context, values): @require_context def floating_ip_count_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() # TODO(tr3buchet): why leave auto_assigned floating IPs out? - return session.query(models.FloatingIp).\ + return model_query(context, models.FloatingIp, read_deleted="no").\ filter_by(project_id=project_id).\ filter_by(auto_assigned=False).\ - filter_by(deleted=False).\ count() @@ -607,13 +585,15 @@ def floating_ip_set_auto_assigned(context, address): floating_ip_ref.save(session=session) +@require_admin_context +def _floating_ip_get_all(context): + return model_query(context, models.FloatingIp, read_deleted="no").\ + options(joinedload_all('fixed_ip.instance')) + + @require_admin_context def floating_ip_get_all(context): - session = get_session() - floating_ip_refs = session.query(models.FloatingIp).\ - options(joinedload_all('fixed_ip.instance')).\ - filter_by(deleted=False).\ - all() + floating_ip_refs = _floating_ip_get_all(context).all() if not floating_ip_refs: raise exception.NoFloatingIpsDefined() return floating_ip_refs @@ -621,12 +601,9 @@ def floating_ip_get_all(context): @require_admin_context def floating_ip_get_all_by_host(context, host): - session = get_session() - floating_ip_refs = session.query(models.FloatingIp).\ - options(joinedload_all('fixed_ip.instance')).\ - filter_by(host=host).\ - filter_by(deleted=False).\ - all() + floating_ip_refs = _floating_ip_get_all(context).\ + filter_by(host=host).\ + all() if not floating_ip_refs: raise exception.FloatingIpNotFoundForHost(host=host) return floating_ip_refs @@ -635,25 +612,18 @@ def floating_ip_get_all_by_host(context, host): @require_context def floating_ip_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() # TODO(tr3buchet): why do we not want auto_assigned floating IPs here? - return session.query(models.FloatingIp).\ - options(joinedload_all('fixed_ip.instance')).\ + return _floating_ip_get_all(context).\ filter_by(project_id=project_id).\ filter_by(auto_assigned=False).\ - filter_by(deleted=False).\ all() @require_context def floating_ip_get_by_address(context, address, session=None): - if not session: - session = get_session() - - result = session.query(models.FloatingIp).\ + result = model_query(context, models.FloatingIp, session=session).\ options(joinedload_all('fixed_ip.network')).\ filter_by(address=address).\ - filter_by(deleted=can_read_deleted(context)).\ first() if not result: @@ -675,10 +645,9 @@ def floating_ip_get_by_fixed_address(context, fixed_address, session=None): fixed_ip = fixed_ip_get_by_address(context, fixed_address, session) fixed_ip_id = fixed_ip['id'] - return session.query(models.FloatingIp).\ + return model_query(context, models.FloatingIp, session=session).\ options(joinedload_all('fixed_ip.network')).\ filter_by(fixed_ip_id=fixed_ip_id).\ - filter_by(deleted=can_read_deleted(context)).\ all() # NOTE(tr3buchet) please don't invent an exception here, empty list is fine @@ -708,10 +677,10 @@ def fixed_ip_associate(context, address, instance_id, network_id=None, with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, models.FixedIp.network_id == None) - fixed_ip_ref = session.query(models.FixedIp).\ + fixed_ip_ref = model_query(context, models.FixedIp, session=session, + read_deleted="no").\ filter(network_or_none).\ filter_by(reserved=reserved).\ - filter_by(deleted=False).\ filter_by(address=address).\ with_lockmode('update').\ first() @@ -740,10 +709,10 @@ def fixed_ip_associate_pool(context, network_id, instance_id=None, host=None): with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, models.FixedIp.network_id == None) - fixed_ip_ref = session.query(models.FixedIp).\ + fixed_ip_ref = model_query(context, models.FixedIp, session=session, + read_deleted="no").\ filter(network_or_none).\ filter_by(reserved=False).\ - filter_by(deleted=False).\ filter_by(instance=None).\ filter_by(host=None).\ with_lockmode('update').\ @@ -767,7 +736,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id=None, host=None): @require_context -def fixed_ip_create(_context, values): +def fixed_ip_create(context, values): fixed_ip_ref = models.FixedIp() fixed_ip_ref.update(values) fixed_ip_ref.save() @@ -775,7 +744,7 @@ def fixed_ip_create(_context, values): @require_context -def fixed_ip_bulk_create(_context, ips): +def fixed_ip_bulk_create(context, ips): session = get_session() with session.begin(): for ip in ips: @@ -796,12 +765,14 @@ def fixed_ip_disassociate(context, address): @require_admin_context -def fixed_ip_disassociate_all_by_timeout(_context, host, time): +def fixed_ip_disassociate_all_by_timeout(context, host, time): session = get_session() - inner_q = session.query(models.Network.id).\ + inner_q = model_query(context, models.Network.id, session=session, + read_deleted="yes").\ filter_by(host=host).\ subquery() - result = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, session=session, + read_deleted="yes").\ filter(models.FixedIp.network_id.in_(inner_q)).\ filter(models.FixedIp.updated_at < time).\ filter(models.FixedIp.instance_id != None).\ @@ -815,17 +786,16 @@ def fixed_ip_disassociate_all_by_timeout(_context, host, time): @require_context def fixed_ip_get(context, id, session=None): - if not session: - session = get_session() - result = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, session=session).\ filter_by(id=id).\ - filter_by(deleted=can_read_deleted(context)).\ options(joinedload('floating_ips')).\ options(joinedload('network')).\ first() if not result: raise exception.FixedIpNotFound(id=id) + # FIXME(sirp): shouldn't we just use project_only here to restrict the + # results? if is_user_context(context): authorize_project_context(context, result.instance.project_id) @@ -834,9 +804,8 @@ def fixed_ip_get(context, id, session=None): @require_admin_context def fixed_ip_get_all(context, session=None): - if not session: - session = get_session() - result = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, session=session, + read_deleted="yes").\ options(joinedload('floating_ips')).\ all() if not result: @@ -847,9 +816,7 @@ def fixed_ip_get_all(context, session=None): @require_admin_context def fixed_ip_get_all_by_instance_host(context, host=None): - session = get_session() - - result = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, read_deleted="yes").\ options(joinedload('floating_ips')).\ join(models.FixedIp.instance).\ filter_by(state=1).\ @@ -864,11 +831,9 @@ def fixed_ip_get_all_by_instance_host(context, host=None): @require_context def fixed_ip_get_by_address(context, address, session=None): - if not session: - session = get_session() - result = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, session=session, + read_deleted="yes").\ filter_by(address=address).\ - filter_by(deleted=can_read_deleted(context)).\ options(joinedload('floating_ips')).\ options(joinedload('network')).\ options(joinedload('instance')).\ @@ -876,6 +841,8 @@ def fixed_ip_get_by_address(context, address, session=None): if not result: raise exception.FixedIpNotFoundForAddress(address=address) + # NOTE(sirp): shouldn't we just use project_only here to restrict the + # results? if is_user_context(context): authorize_project_context(context, result.instance.project_id) @@ -884,42 +851,41 @@ def fixed_ip_get_by_address(context, address, session=None): @require_context def fixed_ip_get_by_instance(context, instance_id): - session = get_session() - rv = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, read_deleted="no").\ options(joinedload('floating_ips')).\ filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ all() - if not rv: + + if not result: raise exception.FixedIpNotFoundForInstance(instance_id=instance_id) - return rv + + return result @require_context def fixed_ip_get_by_network_host(context, network_id, host): - session = get_session() - rv = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, read_deleted="no").\ filter_by(network_id=network_id).\ filter_by(host=host).\ - filter_by(deleted=False).\ first() - if not rv: + + if not result: raise exception.FixedIpNotFoundForNetworkHost(network_id=network_id, host=host) - return rv + return result @require_context def fixed_ip_get_by_virtual_interface(context, vif_id): - session = get_session() - rv = session.query(models.FixedIp).\ + result = model_query(context, models.FixedIp, read_deleted="no").\ options(joinedload('floating_ips')).\ filter_by(virtual_interface_id=vif_id).\ - filter_by(deleted=False).\ all() - if not rv: + + if not result: raise exception.FixedIpNotFoundForVirtualInterface(vif_id=vif_id) - return rv + + return result @require_admin_context @@ -973,18 +939,21 @@ def virtual_interface_update(context, vif_id, values): return vif_ref +@require_context +def _virtual_interface_query(context, session=None): + return model_query(context, models.VirtualInterface, session=session, + read_deleted="yes").\ + options(joinedload('fixed_ips')) + + @require_context def virtual_interface_get(context, vif_id, session=None): """Gets a virtual interface from the table. :param vif_id: = id of the virtual interface """ - if not session: - session = get_session() - - vif_ref = session.query(models.VirtualInterface).\ + vif_ref = _virtual_interface_query(context, session=session).\ filter_by(id=vif_id).\ - options(joinedload('fixed_ips')).\ first() return vif_ref @@ -995,10 +964,8 @@ def virtual_interface_get_by_address(context, address): :param address: = the address of the interface you're looking to get """ - session = get_session() - vif_ref = session.query(models.VirtualInterface).\ + vif_ref = _virtual_interface_query(context).\ filter_by(address=address).\ - options(joinedload('fixed_ips')).\ first() return vif_ref @@ -1009,10 +976,8 @@ def virtual_interface_get_by_uuid(context, vif_uuid): :param vif_uuid: the uuid of the interface you're looking to get """ - session = get_session() - vif_ref = session.query(models.VirtualInterface).\ + vif_ref = _virtual_interface_query(context).\ filter_by(uuid=vif_uuid).\ - options(joinedload('fixed_ips')).\ first() return vif_ref @@ -1023,10 +988,8 @@ def virtual_interface_get_by_fixed_ip(context, fixed_ip_id): :param fixed_ip_id: = id of the fixed_ip """ - session = get_session() - vif_ref = session.query(models.VirtualInterface).\ + vif_ref = _virtual_interface_query(context).\ filter_by(fixed_ip_id=fixed_ip_id).\ - options(joinedload('fixed_ips')).\ first() return vif_ref @@ -1038,10 +1001,8 @@ def virtual_interface_get_by_instance(context, instance_id): :param instance_id: = id of the instance to retrieve vifs for """ - session = get_session() - vif_refs = session.query(models.VirtualInterface).\ + vif_refs = _virtual_interface_query(context).\ filter_by(instance_id=instance_id).\ - options(joinedload('fixed_ips')).\ all() return vif_refs @@ -1050,11 +1011,9 @@ def virtual_interface_get_by_instance(context, instance_id): def virtual_interface_get_by_instance_and_network(context, instance_id, network_id): """Gets virtual interface for instance that's associated with network.""" - session = get_session() - vif_ref = session.query(models.VirtualInterface).\ + vif_ref = _virtual_interface_query(context).\ filter_by(instance_id=instance_id).\ filter_by(network_id=network_id).\ - options(joinedload('fixed_ips')).\ first() return vif_ref @@ -1065,10 +1024,8 @@ def virtual_interface_get_by_network(context, network_id): :param network_id: = network to retrieve vifs for """ - session = get_session() - vif_refs = session.query(models.VirtualInterface).\ + vif_refs = _virtual_interface_query(context).\ filter_by(network_id=network_id).\ - options(joinedload('fixed_ips')).\ all() return vif_refs @@ -1100,10 +1057,7 @@ def virtual_interface_delete_by_instance(context, instance_id): @require_context def virtual_interface_get_all(context): """Get all vifs""" - session = get_session() - vif_refs = session.query(models.VirtualInterface).\ - options(joinedload('fixed_ips')).\ - all() + vif_refs = _virtual_interface_query(context).all() return vif_refs @@ -1143,12 +1097,12 @@ def instance_create(context, values): @require_admin_context def instance_data_get_for_project(context, project_id): - session = get_session() - result = session.query(func.count(models.Instance.id), - func.sum(models.Instance.vcpus), - func.sum(models.Instance.memory_mb)).\ + result = model_query(context, + func.count(models.Instance.id), + func.sum(models.Instance.vcpus), + func.sum(models.Instance.memory_mb), + read_deleted="no").\ filter_by(project_id=project_id).\ - filter_by(deleted=False).\ first() # NOTE(vish): convert None to 0 return (result[0] or 0, result[1] or 0, result[2] or 0) @@ -1200,32 +1154,34 @@ def instance_stop(context, instance_id): @require_context def instance_get_by_uuid(context, uuid, session=None): - partial = _build_instance_get(context, session=session) - result = partial.filter_by(uuid=uuid) - result = result.first() + result = _build_instance_get(context, session=session).\ + filter_by(uuid=uuid).\ + first() + if not result: # FIXME(sirp): it would be nice if InstanceNotFound would accept a # uuid parameter as well raise exception.InstanceNotFound(instance_id=uuid) + return result @require_context def instance_get(context, instance_id, session=None): - partial = _build_instance_get(context, session=session) - result = partial.filter_by(id=instance_id) - result = result.first() + result = _build_instance_get(context, session=session).\ + filter_by(id=instance_id).\ + first() + if not result: raise exception.InstanceNotFound(instance_id=instance_id) + return result @require_context def _build_instance_get(context, session=None): - if not session: - session = get_session() - - partial = session.query(models.Instance).\ + return model_query(context, models.Instance, session=session, + project_only=True).\ options(joinedload_all('fixed_ips.floating_ips')).\ options(joinedload_all('fixed_ips.network')).\ options(joinedload_all('fixed_ips.virtual_interface')).\ @@ -1234,24 +1190,15 @@ def _build_instance_get(context, session=None): options(joinedload('metadata')).\ options(joinedload('instance_type')) - if is_admin_context(context): - partial = partial.filter_by(deleted=can_read_deleted(context)) - elif is_user_context(context): - partial = partial.filter_by(project_id=context.project_id).\ - filter_by(deleted=False) - return partial - @require_admin_context def instance_get_all(context): - session = get_session() - return session.query(models.Instance).\ + return model_query(context, models.Instance).\ options(joinedload_all('fixed_ips.floating_ips')).\ options(joinedload('security_groups')).\ options(joinedload_all('fixed_ips.network')).\ options(joinedload('metadata')).\ options(joinedload('instance_type')).\ - filter_by(deleted=can_read_deleted(context)).\ all() @@ -1412,84 +1359,45 @@ def instance_get_active_by_window_joined(context, begin, end=None, @require_admin_context -def instance_get_all_by_user(context, user_id): - session = get_session() - return session.query(models.Instance).\ +def _instance_get_all_query(context, project_only=False): + return model_query(context, models.Instance, project_only=project_only).\ options(joinedload_all('fixed_ips.floating_ips')).\ options(joinedload('security_groups')).\ options(joinedload_all('fixed_ips.network')).\ options(joinedload('metadata')).\ - options(joinedload('instance_type')).\ - filter_by(deleted=can_read_deleted(context)).\ - filter_by(user_id=user_id).\ - all() + options(joinedload('instance_type')) + + +@require_admin_context +def instance_get_all_by_user(context, user_id): + return _instance_get_all_query(context).filter_by(user_id=user_id).all() @require_admin_context def instance_get_all_by_host(context, host): - session = get_session() - read_deleted = can_read_deleted(context) - return session.query(models.Instance).\ - options(joinedload_all('fixed_ips.floating_ips')).\ - options(joinedload('security_groups')).\ - options(joinedload_all('fixed_ips.network')).\ - options(joinedload('metadata')).\ - options(joinedload('instance_type')).\ - filter_by(host=host).\ - filter_by(deleted=read_deleted).\ - all() + return _instance_get_all_query(context).filter_by(host=host).all() @require_context def instance_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - - session = get_session() - return session.query(models.Instance).\ - options(joinedload_all('fixed_ips.floating_ips')).\ - options(joinedload('security_groups')).\ - options(joinedload_all('fixed_ips.network')).\ - options(joinedload('metadata')).\ - options(joinedload('instance_type')).\ - filter_by(project_id=project_id).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return _instance_get_all_query(context).\ + filter_by(project_id=project_id).\ + all() @require_context def instance_get_all_by_reservation(context, reservation_id): - session = get_session() - query = session.query(models.Instance).\ + return _instance_get_all_query(context, project_only=True).\ filter_by(reservation_id=reservation_id).\ - options(joinedload_all('fixed_ips.floating_ips')).\ - options(joinedload('security_groups')).\ - options(joinedload_all('fixed_ips.network')).\ - options(joinedload('metadata')).\ - options(joinedload('instance_type')) - - if is_admin_context(context): - return query.\ - filter_by(deleted=can_read_deleted(context)).\ - all() - elif is_user_context(context): - return query.\ - filter_by(project_id=context.project_id).\ - filter_by(deleted=False).\ - all() + all() @require_admin_context def instance_get_project_vpn(context, project_id): - session = get_session() - return session.query(models.Instance).\ - options(joinedload_all('fixed_ips.floating_ips')).\ - options(joinedload('security_groups')).\ - options(joinedload_all('fixed_ips.network')).\ - options(joinedload('metadata')).\ - options(joinedload('instance_type')).\ + return _instance_get_all_query(context).\ filter_by(project_id=project_id).\ filter_by(image_ref=str(FLAGS.vpn_image_id)).\ - filter_by(deleted=can_read_deleted(context)).\ first() @@ -1630,8 +1538,8 @@ def instance_get_actions(context, instance_id): instance_id = instance.id return session.query(models.InstanceActions).\ - filter_by(instance_id=instance_id).\ - all() + filter_by(instance_id=instance_id).\ + all() @require_context @@ -1681,27 +1589,22 @@ def key_pair_destroy_all_by_user(context, user_id): @require_context def key_pair_get(context, user_id, name, session=None): authorize_user_context(context, user_id) - - if not session: - session = get_session() - - result = session.query(models.KeyPair).\ + result = model_query(context, models.KeyPair, session=session).\ filter_by(user_id=user_id).\ filter_by(name=name).\ - filter_by(deleted=can_read_deleted(context)).\ first() + if not result: raise exception.KeypairNotFound(user_id=user_id, name=name) + return result @require_context def key_pair_get_all_by_user(context, user_id): authorize_user_context(context, user_id) - session = get_session() - return session.query(models.KeyPair).\ + return model_query(context, models.KeyPair, read_deleted="no").\ filter_by(user_id=user_id).\ - filter_by(deleted=False).\ all() @@ -1728,8 +1631,8 @@ def network_associate(context, project_id, force=False): with session.begin(): def network_query(project_filter): - return session.query(models.Network).\ - filter_by(deleted=False).\ + return model_query(context, models.Network, session=session, + read_deleted="no").\ filter_by(project_id=project_filter).\ with_lockmode('update').\ first() @@ -1757,41 +1660,35 @@ def network_associate(context, project_id, force=False): @require_admin_context def network_count(context): - session = get_session() - return session.query(models.Network).\ - filter_by(deleted=can_read_deleted(context)).\ - count() + return model_query(context, models.Network).count() + + +@require_admin_context +def _network_ips_query(context, network_id): + return model_query(context, models.FixedIp, read_deleted="no").\ + filter_by(network_id=network_id) @require_admin_context def network_count_allocated_ips(context, network_id): - session = get_session() - return session.query(models.FixedIp).\ - filter_by(network_id=network_id).\ - filter_by(allocated=True).\ - filter_by(deleted=False).\ - count() + return _network_ips_query(context, network_id).\ + filter_by(allocated=True).\ + count() @require_admin_context def network_count_available_ips(context, network_id): - session = get_session() - return session.query(models.FixedIp).\ - filter_by(network_id=network_id).\ - filter_by(allocated=False).\ - filter_by(reserved=False).\ - filter_by(deleted=False).\ - count() + return _network_ips_query(context, network_id).\ + filter_by(allocated=False).\ + filter_by(reserved=False).\ + count() @require_admin_context def network_count_reserved_ips(context, network_id): - session = get_session() - return session.query(models.FixedIp).\ - filter_by(network_id=network_id).\ - filter_by(reserved=True).\ - filter_by(deleted=False).\ - count() + return _network_ips_query(context, network_id).\ + filter_by(reserved=True).\ + count() @require_admin_context @@ -1810,7 +1707,7 @@ def network_create_safe(context, values): def network_delete_safe(context, network_id): session = get_session() with session.begin(): - network_ref = network_get(context, network_id=network_id, \ + network_ref = network_get(context, network_id=network_id, session=session) session.delete(network_ref) @@ -1831,21 +1728,11 @@ def network_disassociate_all(context): @require_context def network_get(context, network_id, session=None): - if not session: - session = get_session() - result = None + result = model_query(context, models.Network, session=session, + project_only=True).\ + filter_by(id=network_id).\ + first() - if is_admin_context(context): - result = session.query(models.Network).\ - filter_by(id=network_id).\ - filter_by(deleted=can_read_deleted(context)).\ - first() - elif is_user_context(context): - result = session.query(models.Network).\ - filter_by(project_id=context.project_id).\ - filter_by(id=network_id).\ - filter_by(deleted=False).\ - first() if not result: raise exception.NetworkNotFound(network_id=network_id) @@ -1854,23 +1741,23 @@ def network_get(context, network_id, session=None): @require_admin_context def network_get_all(context): - session = get_session() - result = session.query(models.Network).\ - filter_by(deleted=False).all() + result = model_query(context, models.Network, read_deleted="no").all() + if not result: raise exception.NoNetworksFound() + return result @require_admin_context def network_get_all_by_uuids(context, network_uuids, project_id=None): - session = get_session() project_or_none = or_(models.Network.project_id == project_id, - models.Network.project_id == None) - result = session.query(models.Network).\ - filter(models.Network.uuid.in_(network_uuids)).\ - filter(project_or_none).\ - filter_by(deleted=False).all() + models.Network.project_id == None) + result = model_query(context, models.Network, read_deleted="no").\ + filter(models.Network.uuid.in_(network_uuids)).\ + filter(project_or_none).\ + all() + if not result: raise exception.NoNetworksFound() @@ -1903,100 +1790,92 @@ def network_get_all_by_uuids(context, network_uuids, project_id=None): @require_admin_context def network_get_associated_fixed_ips(context, network_id): - session = get_session() - return session.query(models.FixedIp).\ - options(joinedload_all('instance')).\ - filter_by(network_id=network_id).\ - filter(models.FixedIp.instance_id != None).\ - filter(models.FixedIp.virtual_interface_id != None).\ - filter_by(deleted=False).\ - all() + # FIXME(sirp): since this returns fixed_ips, this would be better named + # fixed_ip_get_all_by_network. + return model_query(context, models.FixedIp, read_deleted="no").\ + options(joinedload_all('instance')).\ + filter_by(network_id=network_id).\ + filter(models.FixedIp.instance_id != None).\ + filter(models.FixedIp.virtual_interface_id != None).\ + all() + + +@require_admin_context +def _network_get_query(context, session=None): + return model_query(context, models.Network, session=session, + read_deleted="no") @require_admin_context def network_get_by_bridge(context, bridge): - session = get_session() - result = session.query(models.Network).\ - filter_by(bridge=bridge).\ - filter_by(deleted=False).\ - first() + result = _network_get_query(context).filter_by(bridge=bridge).first() if not result: raise exception.NetworkNotFoundForBridge(bridge=bridge) + return result @require_admin_context def network_get_by_uuid(context, uuid): - session = get_session() - result = session.query(models.Network).\ - filter_by(uuid=uuid).\ - filter_by(deleted=False).\ - first() + result = _network_get_query(context).filter_by(uuid=uuid).first() if not result: raise exception.NetworkNotFoundForUUID(uuid=uuid) + return result @require_admin_context def network_get_by_cidr(context, cidr): - session = get_session() - result = session.query(models.Network).\ + result = _network_get_query(context).\ filter(or_(models.Network.cidr == cidr, models.Network.cidr_v6 == cidr)).\ - filter_by(deleted=False).\ first() if not result: raise exception.NetworkNotFoundForCidr(cidr=cidr) + return result @require_admin_context -def network_get_by_instance(_context, instance_id): +def network_get_by_instance(context, instance_id): # note this uses fixed IP to get to instance # only works for networks the instance has an IP from - session = get_session() - rv = session.query(models.Network).\ - filter_by(deleted=False).\ + result = _network_get_query(context).\ join(models.Network.fixed_ips).\ filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ first() - if not rv: + + if not result: raise exception.NetworkNotFoundForInstance(instance_id=instance_id) - return rv + + return result @require_admin_context -def network_get_all_by_instance(_context, instance_id): - session = get_session() - rv = session.query(models.Network).\ - filter_by(deleted=False).\ +def network_get_all_by_instance(context, instance_id): + result = _network_get_query(context).\ join(models.Network.fixed_ips).\ filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ all() - if not rv: + + if not result: raise exception.NetworkNotFoundForInstance(instance_id=instance_id) - return rv + + return result @require_admin_context def network_get_all_by_host(context, host): - session = get_session() - with session.begin(): - # NOTE(vish): return networks that have host set - # or that have a fixed ip with host set - host_filter = or_(models.Network.host == host, - models.FixedIp.host == host) - - return session.query(models.Network).\ - filter_by(deleted=False).\ + # NOTE(vish): return networks that have host set + # or that have a fixed ip with host set + host_filter = or_(models.Network.host == host, + models.FixedIp.host == host) + return _network_get_query(context).\ join(models.Network.fixed_ips).\ filter(host_filter).\ - filter_by(deleted=False).\ all() @@ -2004,11 +1883,11 @@ def network_get_all_by_host(context, host): def network_set_host(context, network_id, host_id): session = get_session() with session.begin(): - network_ref = session.query(models.Network).\ + network_ref = _network_get_query(context, session=session).\ filter_by(id=network_id).\ - filter_by(deleted=False).\ with_lockmode('update').\ first() + if not network_ref: raise exception.NetworkNotFound(network_id=network_id) @@ -2034,7 +1913,7 @@ def network_update(context, network_id, values): ################### -def queue_get_for(_context, topic, physical_node_id): +def queue_get_for(context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) @@ -2044,9 +1923,7 @@ def queue_get_for(_context, topic, physical_node_id): @require_admin_context def iscsi_target_count_by_host(context, host): - session = get_session() - return session.query(models.IscsiTarget).\ - filter_by(deleted=can_read_deleted(context)).\ + return model_query(context, models.IscsiTarget).\ filter_by(host=host).\ count() @@ -2076,15 +1953,14 @@ def auth_token_destroy(context, token_id): @require_admin_context def auth_token_get(context, token_hash, session=None): - if session is None: - session = get_session() - tk = session.query(models.AuthToken).\ + result = model_query(context, models.AuthToken, session=session).\ filter_by(token_hash=token_hash).\ - filter_by(deleted=can_read_deleted(context)).\ first() - if not tk: + + if not result: raise exception.AuthTokenNotFound(token=token_hash) - return tk + + return result @require_admin_context @@ -2097,7 +1973,7 @@ def auth_token_update(context, token_hash, values): @require_admin_context -def auth_token_create(_context, token): +def auth_token_create(context, token): tk = models.AuthToken() tk.update(token) tk.save() @@ -2109,29 +1985,30 @@ def auth_token_create(_context, token): @require_context def quota_get(context, project_id, resource, session=None): - if not session: - session = get_session() - result = session.query(models.Quota).\ + result = model_query(context, models.Quota, session=session, + read_deleted="no").\ filter_by(project_id=project_id).\ filter_by(resource=resource).\ - filter_by(deleted=False).\ first() + if not result: raise exception.ProjectQuotaNotFound(project_id=project_id) + return result @require_context def quota_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() - result = {'project_id': project_id} - rows = session.query(models.Quota).\ + + rows = model_query(context, models.Quota, read_deleted="no").\ filter_by(project_id=project_id).\ - filter_by(deleted=False).\ all() + + result = {'project_id': project_id} for row in rows: result[row.resource] = row.hard_limit + return result @@ -2166,10 +2043,11 @@ def quota_destroy(context, project_id, resource): def quota_destroy_all_by_project(context, project_id): session = get_session() with session.begin(): - quotas = session.query(models.Quota).\ + quotas = model_query(context, models.Quota, session=session, + read_deleted="no").\ filter_by(project_id=project_id).\ - filter_by(deleted=False).\ all() + for quota_ref in quotas: quota_ref.delete(session=session) @@ -2181,18 +2059,21 @@ def quota_destroy_all_by_project(context, project_id): def volume_allocate_iscsi_target(context, volume_id, host): session = get_session() with session.begin(): - iscsi_target_ref = session.query(models.IscsiTarget).\ + iscsi_target_ref = model_query(context, models.IscsiTarget, + session=session, read_deleted="no").\ filter_by(volume=None).\ filter_by(host=host).\ - filter_by(deleted=False).\ with_lockmode('update').\ first() + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues if not iscsi_target_ref: raise db.NoMoreTargets() + iscsi_target_ref.volume_id = volume_id session.add(iscsi_target_ref) + return iscsi_target_ref.target_num @@ -2224,12 +2105,13 @@ def volume_create(context, values): @require_admin_context def volume_data_get_for_project(context, project_id): - session = get_session() - result = session.query(func.count(models.Volume.id), - func.sum(models.Volume.size)).\ + result = model_query(context, + func.count(models.Volume.id), + func.sum(models.Volume.size), + read_deleted="no").\ filter_by(project_id=project_id).\ - filter_by(deleted=False).\ first() + # NOTE(vish): convert None to 0 return (result[0] or 0, result[1] or 0) @@ -2266,28 +2148,20 @@ def volume_detached(context, volume_id): @require_context -def volume_get(context, volume_id, session=None): - if not session: - session = get_session() - result = None +def _volume_get_query(context, session=None, project_only=False): + return model_query(context, models.Volume, session=session, + project_only=project_only).\ + options(joinedload('instance')).\ + options(joinedload('volume_metadata')).\ + options(joinedload('volume_type')) + + +@require_context +def volume_get(context, volume_id, session=None): + result = _volume_get_query(context, session=session, project_only=True).\ + filter_by(id=volume_id).\ + first() - if is_admin_context(context): - result = session.query(models.Volume).\ - options(joinedload('instance')).\ - options(joinedload('volume_metadata')).\ - options(joinedload('volume_type')).\ - filter_by(id=volume_id).\ - filter_by(deleted=can_read_deleted(context)).\ - first() - elif is_user_context(context): - result = session.query(models.Volume).\ - options(joinedload('instance')).\ - options(joinedload('volume_metadata')).\ - options(joinedload('volume_type')).\ - filter_by(project_id=context.project_id).\ - filter_by(id=volume_id).\ - filter_by(deleted=False).\ - first() if not result: raise exception.VolumeNotFound(volume_id=volume_id) @@ -2296,65 +2170,38 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): - session = get_session() - return session.query(models.Volume).\ - options(joinedload('instance')).\ - options(joinedload('volume_metadata')).\ - options(joinedload('volume_type')).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return _volume_get_query(context).all() @require_admin_context def volume_get_all_by_host(context, host): - session = get_session() - return session.query(models.Volume).\ - options(joinedload('instance')).\ - options(joinedload('volume_metadata')).\ - options(joinedload('volume_type')).\ - filter_by(host=host).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return _volume_get_query(context).filter_by(host=host).all() @require_admin_context def volume_get_all_by_instance(context, instance_id): - session = get_session() - result = session.query(models.Volume).\ + result = model_query(context, models.Volume, read_deleted="no").\ options(joinedload('volume_metadata')).\ options(joinedload('volume_type')).\ filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ all() + if not result: raise exception.VolumeNotFoundForInstance(instance_id=instance_id) + return result @require_context def volume_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - - session = get_session() - return session.query(models.Volume).\ - options(joinedload('instance')).\ - options(joinedload('volume_metadata')).\ - options(joinedload('volume_type')).\ - filter_by(project_id=project_id).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return _volume_get_query(context).filter_by(project_id=project_id).all() @require_admin_context def volume_get_instance(context, volume_id): - session = get_session() - result = session.query(models.Volume).\ - filter_by(id=volume_id).\ - filter_by(deleted=can_read_deleted(context)).\ - options(joinedload('instance')).\ - options(joinedload('volume_metadata')).\ - options(joinedload('volume_type')).\ - first() + result = _volume_get_query(context).filter_by(id=volume_id).first() + if not result: raise exception.VolumeNotFound(volume_id=volume_id) @@ -2363,10 +2210,10 @@ def volume_get_instance(context, volume_id): @require_admin_context def volume_get_iscsi_target_num(context, volume_id): - session = get_session() - result = session.query(models.IscsiTarget).\ + result = model_query(context, models.IscsiTarget, read_deleted="yes").\ filter_by(volume_id=volume_id).\ first() + if not result: raise exception.ISCSITargetNotFoundForVolume(volume_id=volume_id) @@ -2390,31 +2237,28 @@ def volume_update(context, volume_id, values): #################### +def _volume_metadata_get_query(context, volume_id, session=None): + return model_query(context, models.VolumeMetadata, + session=session, read_deleted="no").\ + filter_by(volume_id=volume_id) + @require_context @require_volume_exists def volume_metadata_get(context, volume_id): - session = get_session() + rows = _volume_metadata_get_query(context, volume_id).all() + result = {} + for row in rows: + result[row['key']] = row['value'] - meta_results = session.query(models.VolumeMetadata).\ - filter_by(volume_id=volume_id).\ - filter_by(deleted=False).\ - all() - - meta_dict = {} - for i in meta_results: - meta_dict[i['key']] = i['value'] - return meta_dict + return result @require_context @require_volume_exists def volume_metadata_delete(context, volume_id, key): - session = get_session() - session.query(models.VolumeMetadata).\ - filter_by(volume_id=volume_id).\ + _volume_metadata_get_query(context, volume_id).\ filter_by(key=key).\ - filter_by(deleted=False).\ update({'deleted': True, 'deleted_at': utils.utcnow(), 'updated_at': literal_column('updated_at')}) @@ -2423,10 +2267,7 @@ def volume_metadata_delete(context, volume_id, key): @require_context @require_volume_exists def volume_metadata_delete_all(context, volume_id): - session = get_session() - session.query(models.VolumeMetadata).\ - filter_by(volume_id=volume_id).\ - filter_by(deleted=False).\ + _volume_metadata_get_query(context, volume_id).\ update({'deleted': True, 'deleted_at': utils.utcnow(), 'updated_at': literal_column('updated_at')}) @@ -2435,19 +2276,14 @@ def volume_metadata_delete_all(context, volume_id): @require_context @require_volume_exists def volume_metadata_get_item(context, volume_id, key, session=None): - if not session: - session = get_session() - - meta_result = session.query(models.VolumeMetadata).\ - filter_by(volume_id=volume_id).\ + result = _volume_metadata_get_query(context, volume_id, session=session).\ filter_by(key=key).\ - filter_by(deleted=False).\ first() - if not meta_result: + if not result: raise exception.VolumeMetadataNotFound(metadata_key=key, volume_id=volume_id) - return meta_result + return result @require_context @@ -2513,21 +2349,11 @@ def snapshot_destroy(context, snapshot_id): @require_context def snapshot_get(context, snapshot_id, session=None): - if not session: - session = get_session() - result = None + result = model_query(context, models.Snapshot, session=session, + project_only=True).\ + filter_by(id=snapshot_id).\ + first() - if is_admin_context(context): - result = session.query(models.Snapshot).\ - filter_by(id=snapshot_id).\ - filter_by(deleted=can_read_deleted(context)).\ - first() - elif is_user_context(context): - result = session.query(models.Snapshot).\ - filter_by(project_id=context.project_id).\ - filter_by(id=snapshot_id).\ - filter_by(deleted=False).\ - first() if not result: raise exception.SnapshotNotFound(snapshot_id=snapshot_id) @@ -2536,20 +2362,14 @@ def snapshot_get(context, snapshot_id, session=None): @require_admin_context def snapshot_get_all(context): - session = get_session() - return session.query(models.Snapshot).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return model_query(context, models.Snapshot).all() @require_context def snapshot_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - - session = get_session() - return session.query(models.Snapshot).\ + return model_query(context, models.Snapshot).\ filter_by(project_id=project_id).\ - filter_by(deleted=can_read_deleted(context)).\ all() @@ -2565,6 +2385,11 @@ def snapshot_update(context, snapshot_id, values): ################### +def _block_device_mapping_get_query(context, session=None): + return model_query(context, models.BlockDeviceMapping, session=session, + read_deleted="no") + + @require_context def block_device_mapping_create(context, values): bdm_ref = models.BlockDeviceMapping() @@ -2579,9 +2404,8 @@ def block_device_mapping_create(context, values): def block_device_mapping_update(context, bdm_id, values): session = get_session() with session.begin(): - session.query(models.BlockDeviceMapping).\ + _block_device_mapping_get_query(context, session=session).\ filter_by(id=bdm_id).\ - filter_by(deleted=False).\ update(values) @@ -2589,10 +2413,9 @@ def block_device_mapping_update(context, bdm_id, values): def block_device_mapping_update_or_create(context, values): session = get_session() with session.begin(): - result = session.query(models.BlockDeviceMapping).\ + result = _block_device_mapping_get_query(context, session=session).\ filter_by(instance_id=values['instance_id']).\ filter_by(device_name=values['device_name']).\ - filter_by(deleted=False).\ first() if not result: bdm_ref = models.BlockDeviceMapping() @@ -2607,25 +2430,20 @@ def block_device_mapping_update_or_create(context, values): if (virtual_name is not None and block_device.is_swap_or_ephemeral(virtual_name)): session.query(models.BlockDeviceMapping).\ - filter_by(instance_id=values['instance_id']).\ - filter_by(virtual_name=virtual_name).\ - filter(models.BlockDeviceMapping.device_name != - values['device_name']).\ - update({'deleted': True, - 'deleted_at': utils.utcnow(), - 'updated_at': literal_column('updated_at')}) + filter_by(instance_id=values['instance_id']).\ + filter_by(virtual_name=virtual_name).\ + filter(models.BlockDeviceMapping.device_name != + values['device_name']).\ + update({'deleted': True, + 'deleted_at': utils.utcnow(), + 'updated_at': literal_column('updated_at')}) @require_context def block_device_mapping_get_all_by_instance(context, instance_id): - session = get_session() - result = session.query(models.BlockDeviceMapping).\ - filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ - all() - if not result: - return [] - return result + return _block_device_mapping_get_query(context).\ + filter_by(instance_id=instance_id).\ + all() @require_context @@ -2644,87 +2462,70 @@ def block_device_mapping_destroy_by_instance_and_volume(context, instance_id, volume_id): session = get_session() with session.begin(): - session.query(models.BlockDeviceMapping).\ - filter_by(instance_id=instance_id).\ - filter_by(volume_id=volume_id).\ - filter_by(deleted=False).\ - update({'deleted': True, - 'deleted_at': utils.utcnow(), - 'updated_at': literal_column('updated_at')}) + _block_device_mapping_get_query(context, session=session).\ + filter_by(instance_id=instance_id).\ + filter_by(volume_id=volume_id).\ + update({'deleted': True, + 'deleted_at': utils.utcnow(), + 'updated_at': literal_column('updated_at')}) ################### +def _security_group_get_query(context, session=None, read_deleted=None, + project_only=False): + return model_query(context, models.SecurityGroup, session=session, + read_deleted=read_deleted, project_only=project_only).\ + options(joinedload_all('rules')) + @require_context def security_group_get_all(context): - session = get_session() - return session.query(models.SecurityGroup).\ - filter_by(deleted=can_read_deleted(context)).\ - options(joinedload_all('rules')).\ - all() + return _security_group_get_query(context).all() @require_context def security_group_get(context, security_group_id, session=None): - if not session: - session = get_session() - if is_admin_context(context): - result = session.query(models.SecurityGroup).\ - filter_by(deleted=can_read_deleted(context),).\ - filter_by(id=security_group_id).\ - options(joinedload_all('rules')).\ - options(joinedload_all('instances')).\ - first() - else: - result = session.query(models.SecurityGroup).\ - filter_by(deleted=False).\ - filter_by(id=security_group_id).\ - filter_by(project_id=context.project_id).\ - options(joinedload_all('rules')).\ - options(joinedload_all('instances')).\ - first() + result = _security_group_get_query(context, session=session, + project_only=True).\ + filter_by(id=security_group_id).\ + options(joinedload_all('instances')).\ + first() + if not result: raise exception.SecurityGroupNotFound( security_group_id=security_group_id) + return result @require_context def security_group_get_by_name(context, project_id, group_name): - session = get_session() - result = session.query(models.SecurityGroup).\ + result = _security_group_get_query(context, read_deleted="no").\ filter_by(project_id=project_id).\ filter_by(name=group_name).\ - filter_by(deleted=False).\ - options(joinedload_all('rules')).\ options(joinedload_all('instances')).\ first() + if not result: - raise exception.SecurityGroupNotFoundForProject(project_id=project_id, - security_group_id=group_name) + raise exception.SecurityGroupNotFoundForProject( + project_id=project_id, security_group_id=group_name) + return result @require_context def security_group_get_by_project(context, project_id): - session = get_session() - return session.query(models.SecurityGroup).\ - filter_by(project_id=project_id).\ - filter_by(deleted=False).\ - options(joinedload_all('rules')).\ - all() + return _security_group_get_query(context, read_deleted="no").\ + filter_by(project_id=project_id).\ + all() @require_context def security_group_get_by_instance(context, instance_id): - session = get_session() - return session.query(models.SecurityGroup).\ - filter_by(deleted=False).\ - options(joinedload_all('rules')).\ + return _security_group_get_query(context, read_deleted="no").\ join(models.SecurityGroup.instances).\ filter_by(id=instance_id).\ - filter_by(deleted=False).\ all() @@ -2787,64 +2588,41 @@ def security_group_destroy_all(context, session=None): ################### +def _security_group_rule_get_query(context, session=None): + return model_query(context, models.SecurityGroupIngressRule, + session=session) + + @require_context def security_group_rule_get(context, security_group_rule_id, session=None): - if not session: - session = get_session() - if is_admin_context(context): - result = session.query(models.SecurityGroupIngressRule).\ - filter_by(deleted=can_read_deleted(context)).\ - filter_by(id=security_group_rule_id).\ - first() - else: - # TODO(vish): Join to group and check for project_id - result = session.query(models.SecurityGroupIngressRule).\ - filter_by(deleted=False).\ + result = _security_group_rule_get_query(context, session=session).\ filter_by(id=security_group_rule_id).\ first() + if not result: raise exception.SecurityGroupNotFoundForRule( rule_id=security_group_rule_id) + return result @require_context def security_group_rule_get_by_security_group(context, security_group_id, session=None): - if not session: - session = get_session() - if is_admin_context(context): - result = session.query(models.SecurityGroupIngressRule).\ - filter_by(deleted=can_read_deleted(context)).\ + return _security_group_rule_get_query(context, session=session).\ filter_by(parent_group_id=security_group_id).\ options(joinedload_all('grantee_group.instances')).\ all() - else: - result = session.query(models.SecurityGroupIngressRule).\ - filter_by(deleted=False).\ - filter_by(parent_group_id=security_group_id).\ - options(joinedload_all('grantee_group.instances')).\ - all() - return result @require_context def security_group_rule_get_by_security_group_grantee(context, security_group_id, session=None): - if not session: - session = get_session() - if is_admin_context(context): - result = session.query(models.SecurityGroupIngressRule).\ - filter_by(deleted=can_read_deleted(context)).\ + + return _security_group_rule_get_query(context, session=session).\ filter_by(group_id=security_group_id).\ all() - else: - result = session.query(models.SecurityGroupIngressRule).\ - filter_by(deleted=False).\ - filter_by(group_id=security_group_id).\ - all() - return result @require_context @@ -2878,17 +2656,12 @@ def provider_fw_rule_create(context, rule): @require_admin_context def provider_fw_rule_get_all(context): - session = get_session() - return session.query(models.ProviderFirewallRule).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return model_query(context, models.ProviderFirewallRule).all() @require_admin_context def provider_fw_rule_get_all_by_cidr(context, cidr): - session = get_session() - return session.query(models.ProviderFirewallRule).\ - filter_by(deleted=can_read_deleted(context)).\ + return model_query(context, models.ProviderFirewallRule).\ filter_by(cidr=cidr).\ all() @@ -2909,12 +2682,8 @@ def provider_fw_rule_destroy(context, rule_id): @require_admin_context def user_get(context, id, session=None): - if not session: - session = get_session() - - result = session.query(models.User).\ + result = model_query(context, models.User, session=session).\ filter_by(id=id).\ - filter_by(deleted=can_read_deleted(context)).\ first() if not result: @@ -2925,12 +2694,8 @@ def user_get(context, id, session=None): @require_admin_context def user_get_by_access_key(context, access_key, session=None): - if not session: - session = get_session() - - result = session.query(models.User).\ + result = model_query(context, models.User, session=session).\ filter_by(access_key=access_key).\ - filter_by(deleted=can_read_deleted(context)).\ first() if not result: @@ -2940,7 +2705,7 @@ def user_get_by_access_key(context, access_key, session=None): @require_admin_context -def user_create(_context, values): +def user_create(context, values): user_ref = models.User() user_ref.update(values) user_ref.save() @@ -2965,10 +2730,7 @@ def user_delete(context, id): def user_get_all(context): - session = get_session() - return session.query(models.User).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return model_query(context, models.User).all() def user_get_roles(context, user_id): @@ -3038,7 +2800,7 @@ def user_update(context, user_id, values): ################### -def project_create(_context, values): +def project_create(context, values): project_ref = models.Project() project_ref.update(values) project_ref.save() @@ -3056,11 +2818,8 @@ def project_add_member(context, project_id, user_id): def project_get(context, id, session=None): - if not session: - session = get_session() - - result = session.query(models.Project).\ - filter_by(deleted=False).\ + result = model_query(context, models.Project, session=session, + read_deleted="no").\ filter_by(id=id).\ options(joinedload_all('members')).\ first() @@ -3072,22 +2831,20 @@ def project_get(context, id, session=None): def project_get_all(context): - session = get_session() - return session.query(models.Project).\ - filter_by(deleted=can_read_deleted(context)).\ + return model_query(context, models.Project).\ options(joinedload_all('members')).\ all() def project_get_by_user(context, user_id): - session = get_session() - user = session.query(models.User).\ - filter_by(deleted=can_read_deleted(context)).\ + user = model_query(context, models.User).\ filter_by(id=user_id).\ options(joinedload_all('projects')).\ first() + if not user: raise exception.UserNotFound(user_id=user_id) + return user.projects @@ -3127,15 +2884,16 @@ def project_get_networks(context, project_id, associate=True): # NOTE(tr3buchet): as before this function will associate # a project with a network if it doesn't have one and # associate is true - session = get_session() - result = session.query(models.Network).\ + result = model_query(context, models.Network, read_deleted="no").\ filter_by(project_id=project_id).\ - filter_by(deleted=False).all() + all() if not result: if not associate: return [] + return [network_associate(context, project_id)] + return result @@ -3167,24 +2925,28 @@ def migration_update(context, id, values): @require_admin_context def migration_get(context, id, session=None): - if not session: - session = get_session() - result = session.query(models.Migration).\ - filter_by(id=id).first() + result = model_query(context, models.Migration, session=session, + read_deleted="yes").\ + filter_by(id=id).\ + first() + if not result: raise exception.MigrationNotFound(migration_id=id) + return result @require_admin_context def migration_get_by_instance_and_status(context, instance_uuid, status): - session = get_session() - result = session.query(models.Migration).\ + result = model_query(context, models.Migration, read_deleted="yes").\ filter_by(instance_uuid=instance_uuid).\ - filter_by(status=status).first() + filter_by(status=status).\ + first() + if not result: raise exception.MigrationNotFoundByStatus(instance_id=instance_uuid, status=status) + return result @@ -3193,14 +2955,11 @@ def migration_get_all_unconfirmed(context, confirm_window, session=None): confirm_window = datetime.datetime.utcnow() - datetime.timedelta( seconds=confirm_window) - if not session: - session = get_session() - - results = session.query(models.Migration).\ + return model_query(context, models.Migration, session=session, + read_deleted="yes").\ filter(models.Migration.updated_at <= confirm_window).\ - filter_by(status="FINISHED").all() - - return results + filter_by(status="FINISHED").\ + all() ################## @@ -3214,11 +2973,10 @@ def console_pool_create(context, values): def console_pool_get(context, pool_id): - session = get_session() - result = session.query(models.ConsolePool).\ - filter_by(deleted=False).\ + result = model_query(context, models.ConsolePool, read_deleted="no").\ filter_by(id=pool_id).\ first() + if not result: raise exception.ConsolePoolNotFound(pool_id=pool_id) @@ -3227,27 +2985,26 @@ def console_pool_get(context, pool_id): def console_pool_get_by_host_type(context, compute_host, host, console_type): - session = get_session() - result = session.query(models.ConsolePool).\ + + result = model_query(context, models.ConsolePool, read_deleted="no").\ filter_by(host=host).\ filter_by(console_type=console_type).\ filter_by(compute_host=compute_host).\ - filter_by(deleted=False).\ options(joinedload('consoles')).\ first() + if not result: - raise exception.ConsolePoolNotFoundForHostType(host=host, - console_type=console_type, - compute_host=compute_host) + raise exception.ConsolePoolNotFoundForHostType( + host=host, console_type=console_type, + compute_host=compute_host) + return result def console_pool_get_all_by_host_type(context, host, console_type): - session = get_session() - return session.query(models.ConsolePool).\ + return model_query(context, models.ConsolePool, read_deleted="no").\ filter_by(host=host).\ filter_by(console_type=console_type).\ - filter_by(deleted=False).\ options(joinedload('consoles')).\ all() @@ -3262,55 +3019,57 @@ def console_create(context, values): def console_delete(context, console_id): session = get_session() with session.begin(): - # consoles are meant to be transient. (mdragon) + # NOTE(mdragon): consoles are meant to be transient. session.query(models.Console).\ filter_by(id=console_id).\ delete() def console_get_by_pool_instance(context, pool_id, instance_id): - session = get_session() - result = session.query(models.Console).\ + result = model_query(context, models.Console, read_deleted="yes").\ filter_by(pool_id=pool_id).\ filter_by(instance_id=instance_id).\ options(joinedload('pool')).\ first() + if not result: - raise exception.ConsoleNotFoundInPoolForInstance(pool_id=pool_id, - instance_id=instance_id) + raise exception.ConsoleNotFoundInPoolForInstance( + pool_id=pool_id, instance_id=instance_id) + return result def console_get_all_by_instance(context, instance_id): - session = get_session() - results = session.query(models.Console).\ + return model_query(context, models.Console, read_deleted="yes").\ filter_by(instance_id=instance_id).\ - options(joinedload('pool')).\ all() - return results def console_get(context, console_id, instance_id=None): - session = get_session() - query = session.query(models.Console).\ - filter_by(id=console_id) - if instance_id: + query = model_query(context, models.Console, read_deleted="yes").\ + filter_by(id=console_id).\ + options(joinedload('pool')) + + if instance_id is not None: query = query.filter_by(instance_id=instance_id) - result = query.options(joinedload('pool')).first() + + result = query.first() + if not result: if instance_id: - raise exception.ConsoleNotFoundForInstance(console_id=console_id, - instance_id=instance_id) + raise exception.ConsoleNotFoundForInstance( + console_id=console_id, instance_id=instance_id) else: raise exception.ConsoleNotFound(console_id=console_id) + return result - ################## +################## @require_admin_context -def instance_type_create(_context, values): +def instance_type_create(context, values): """Create a new instance type. In order to pass in extra specs, the values dict should contain a 'extra_specs' key/value pair: @@ -3354,26 +3113,29 @@ def _dict_with_extra_specs(inst_type_query): return inst_type_dict +def _instance_type_get_query(context, session=None, read_deleted=None): + return model_query(context, models.InstanceTypes, session=session, + read_deleted=read_deleted).\ + options(joinedload('extra_specs')) + + @require_context def instance_type_get_all(context, inactive=False, filters=None): """ Returns all instance types. """ filters = filters or {} - session = get_session() - partial = session.query(models.InstanceTypes)\ - .options(joinedload('extra_specs')) - if not inactive: - partial = partial.filter_by(deleted=False) + read_deleted = "yes" if inactive else "no" + query = _instance_type_get_query(context, read_deleted=read_deleted) if 'min_memory_mb' in filters: - partial = partial.filter( + query = query.filter( models.InstanceTypes.memory_mb >= filters['min_memory_mb']) if 'min_local_gb' in filters: - partial = partial.filter( + query = query.filter( models.InstanceTypes.local_gb >= filters['min_local_gb']) - inst_types = partial.order_by("name").all() + inst_types = query.order_by("name").all() return [_dict_with_extra_specs(i) for i in inst_types] @@ -3381,53 +3143,52 @@ def instance_type_get_all(context, inactive=False, filters=None): @require_context def instance_type_get(context, id): """Returns a dict describing specific instance_type""" - session = get_session() - inst_type = session.query(models.InstanceTypes).\ - options(joinedload('extra_specs')).\ + result = _instance_type_get_query(context, read_deleted="yes").\ filter_by(id=id).\ first() - if not inst_type: + if not result: raise exception.InstanceTypeNotFound(instance_type_id=id) - else: - return _dict_with_extra_specs(inst_type) + + return _dict_with_extra_specs(result) @require_context def instance_type_get_by_name(context, name): """Returns a dict describing specific instance_type""" - session = get_session() - inst_type = session.query(models.InstanceTypes).\ - options(joinedload('extra_specs')).\ + result = _instance_type_get_query(context, read_deleted="yes").\ filter_by(name=name).\ first() - if not inst_type: + + if not result: raise exception.InstanceTypeNotFoundByName(instance_type_name=name) - else: - return _dict_with_extra_specs(inst_type) + + return _dict_with_extra_specs(result) @require_context def instance_type_get_by_flavor_id(context, flavor_id): """Returns a dict describing specific flavor_id""" - session = get_session() - inst_type = session.query(models.InstanceTypes).\ - options(joinedload('extra_specs')).\ - filter_by(flavorid=flavor_id).\ - first() - if not inst_type: + result = _instance_type_get_query(context, read_deleted="yes").\ + filter_by(flavorid=flavor_id).\ + first() + + if not result: raise exception.FlavorNotFound(flavor_id=flavor_id) - else: - return _dict_with_extra_specs(inst_type) + + return _dict_with_extra_specs(result) @require_admin_context def instance_type_destroy(context, name): """ Marks specific instance_type as deleted""" - session = get_session() - instance_type_ref = session.query(models.InstanceTypes).\ - filter_by(name=name) + instance_type_ref = model_query(context, models.InstanceTypes, + read_deleted="yes").\ + filter_by(name=name) + + # FIXME(sirp): this should update deleted_at and updated_at as well records = instance_type_ref.update(dict(deleted=True)) + if records == 0: raise exception.InstanceTypeNotFoundByName(instance_type_name=name) else: @@ -3439,10 +3200,12 @@ def instance_type_purge(context, name): """ Removes specific instance_type from DB Usually instance_type_destroy should be used """ - session = get_session() - instance_type_ref = session.query(models.InstanceTypes).\ - filter_by(name=name) + instance_type_ref = model_query(context, models.InstanceTypes, + read_deleted="yes").\ + filter_by(name=name) + records = instance_type_ref.delete() + if records == 0: raise exception.InstanceTypeNotFoundByName(instance_type_name=name) else: @@ -3460,12 +3223,14 @@ def zone_create(context, values): return zone +def _zone_get_by_id_query(context, zone_id, session=None): + return model_query(context, models.Zone, session=session).\ + filter_by(id=zone_id) + + @require_admin_context def zone_update(context, zone_id, values): - session = get_session() - zone = session.query(models.Zone).filter_by(id=zone_id).first() - if not zone: - raise exception.ZoneNotFound(zone_id=zone_id) + zone = zone_get(context, zone_id) zone.update(values) zone.save(session=session) return zone @@ -3475,53 +3240,51 @@ def zone_update(context, zone_id, values): def zone_delete(context, zone_id): session = get_session() with session.begin(): - session.query(models.Zone).\ - filter_by(id=zone_id).\ + _zone_get_by_id_query(context, zone_id, session=session).\ delete() @require_admin_context def zone_get(context, zone_id): - session = get_session() - result = session.query(models.Zone).filter_by(id=zone_id).first() + result = _zone_get_by_id_query(context, zone_id).first() + if not result: raise exception.ZoneNotFound(zone_id=zone_id) + return result @require_admin_context def zone_get_all(context): - session = get_session() - return session.query(models.Zone).all() + return model_query(context, models.Zone, read_deleted="yes").all() #################### +def _instance_metadata_get_query(context, instance_id, session=None): + return model_query(context, models.InstanceMetadata, session=session, + read_deleted="no").\ + filter_by(instance_id=instance_id) + + @require_context @require_instance_exists def instance_metadata_get(context, instance_id): - session = get_session() + rows = _instance_metadata_get_query(context, instance_id).all() - meta_results = session.query(models.InstanceMetadata).\ - filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ - all() + result = {} + for row in rows: + result[row['key']] = row['value'] - meta_dict = {} - for i in meta_results: - meta_dict[i['key']] = i['value'] - return meta_dict + return result @require_context @require_instance_exists def instance_metadata_delete(context, instance_id, key): - session = get_session() - session.query(models.InstanceMetadata).\ - filter_by(instance_id=instance_id).\ + _instance_metadata_get_query(context, instance_id).\ filter_by(key=key).\ - filter_by(deleted=False).\ update({'deleted': True, 'deleted_at': utils.utcnow(), 'updated_at': literal_column('updated_at')}) @@ -3530,10 +3293,7 @@ def instance_metadata_delete(context, instance_id, key): @require_context @require_instance_exists def instance_metadata_delete_all(context, instance_id): - session = get_session() - session.query(models.InstanceMetadata).\ - filter_by(instance_id=instance_id).\ - filter_by(deleted=False).\ + _instance_metadata_get_query(context, instance_id).\ update({'deleted': True, 'deleted_at': utils.utcnow(), 'updated_at': literal_column('updated_at')}) @@ -3542,19 +3302,16 @@ def instance_metadata_delete_all(context, instance_id): @require_context @require_instance_exists def instance_metadata_get_item(context, instance_id, key, session=None): - if not session: - session = get_session() - - meta_result = session.query(models.InstanceMetadata).\ - filter_by(instance_id=instance_id).\ + result = _instance_metadata_get_query( + context, instance_id, session=session).\ filter_by(key=key).\ - filter_by(deleted=False).\ first() - if not meta_result: + if not result: raise exception.InstanceMetadataNotFound(metadata_key=key, instance_id=instance_id) - return meta_result + + return result @require_context @@ -3607,21 +3364,17 @@ def agent_build_create(context, values): @require_admin_context def agent_build_get_by_triple(context, hypervisor, os, architecture, session=None): - if not session: - session = get_session() - return session.query(models.AgentBuild).\ + return model_query(context, models.AgentBuild, session=session, + read_deleted="no").\ filter_by(hypervisor=hypervisor).\ filter_by(os=os).\ filter_by(architecture=architecture).\ - filter_by(deleted=False).\ first() @require_admin_context def agent_build_get_all(context): - session = get_session() - return session.query(models.AgentBuild).\ - filter_by(deleted=False).\ + return model_query(context, models.AgentBuild, read_deleted="no").\ all() @@ -3629,7 +3382,8 @@ def agent_build_get_all(context): def agent_build_destroy(context, agent_build_id): session = get_session() with session.begin(): - session.query(models.AgentBuild).\ + model_query(context, models.AgentBuild, session=session, + read_deleted="yes").\ filter_by(id=agent_build_id).\ update({'deleted': True, 'deleted_at': utils.utcnow(), @@ -3640,9 +3394,11 @@ def agent_build_destroy(context, agent_build_id): def agent_build_update(context, agent_build_id, values): session = get_session() with session.begin(): - agent_build_ref = session.query(models.AgentBuild).\ - filter_by(id=agent_build_id). \ + agent_build_ref = model_query(context, models.AgentBuild, + session=session, read_deleted="yes").\ + filter_by(id=agent_build_id).\ first() + agent_build_ref.update(values) agent_build_ref.save(session=session) @@ -3651,8 +3407,7 @@ def agent_build_update(context, agent_build_id, values): @require_context def bw_usage_get_by_instance(context, instance_id, start_period): - session = get_session() - return session.query(models.BandwidthUsage).\ + return model_query(context, models.BandwidthUsage, read_deleted="yes").\ filter_by(instance_id=instance_id).\ filter_by(start_period=start_period).\ all() @@ -3665,18 +3420,23 @@ def bw_usage_update(context, start_period, bw_in, bw_out, session=None): - session = session if session else get_session() + if not session: + session = get_session() + with session.begin(): - bwusage = session.query(models.BandwidthUsage).\ + bwusage = model_query(context, models.BandwidthUsage, + read_deleted="yes").\ filter_by(instance_id=instance_id).\ filter_by(start_period=start_period).\ filter_by(network_label=network_label).\ first() + if not bwusage: bwusage = models.BandwidthUsage() bwusage.instance_id = instance_id bwusage.start_period = start_period bwusage.network_label = network_label + bwusage.last_refreshed = utils.utcnow() bwusage.bw_in = bw_in bwusage.bw_out = bw_out @@ -3686,28 +3446,31 @@ def bw_usage_update(context, #################### +def _instance_type_extra_specs_get_query(context, instance_type_id, + session=None): + return model_query(context, models.InstanceTypeExtraSpecs, + session=session, read_deleted="no").\ + filter_by(instance_type_id=instance_type_id) + + @require_context def instance_type_extra_specs_get(context, instance_type_id): - session = get_session() - - spec_results = session.query(models.InstanceTypeExtraSpecs).\ - filter_by(instance_type_id=instance_type_id).\ - filter_by(deleted=False).\ + rows = _instance_type_extra_specs_get_query( + context, instance_type_id).\ all() - spec_dict = {} - for i in spec_results: - spec_dict[i['key']] = i['value'] - return spec_dict + result = {} + for row in rows: + result[row['key']] = row['value'] + + return result @require_context def instance_type_extra_specs_delete(context, instance_type_id, key): - session = get_session() - session.query(models.InstanceTypeExtraSpecs).\ - filter_by(instance_type_id=instance_type_id).\ + _instance_type_extra_specs_get_query( + context, instance_type_id).\ filter_by(key=key).\ - filter_by(deleted=False).\ update({'deleted': True, 'deleted_at': utils.utcnow(), 'updated_at': literal_column('updated_at')}) @@ -3716,21 +3479,16 @@ def instance_type_extra_specs_delete(context, instance_type_id, key): @require_context def instance_type_extra_specs_get_item(context, instance_type_id, key, session=None): - - if not session: - session = get_session() - - spec_result = session.query(models.InstanceTypeExtraSpecs).\ - filter_by(instance_type_id=instance_type_id).\ + result = _instance_type_extra_specs_get_query( + context, instance_type_id, session=session).\ filter_by(key=key).\ - filter_by(deleted=False).\ first() - if not spec_result: - raise exception.\ - InstanceTypeExtraSpecsNotFound(extra_specs_key=key, - instance_type_id=instance_type_id) - return spec_result + if not result: + raise exception.InstanceTypeExtraSpecsNotFound( + extra_specs_key=key, instance_type_id=instance_type_id) + + return result @require_context @@ -3755,7 +3513,7 @@ def instance_type_extra_specs_update_or_create(context, instance_type_id, @require_admin_context -def volume_type_create(_context, values): +def volume_type_create(context, values): """Create a new instance type. In order to pass in extra specs, the values dict should contain a 'extra_specs' key/value pair: @@ -3776,64 +3534,64 @@ def volume_type_create(_context, values): @require_context -def volume_type_get_all(context, inactive=False, filters={}): +def volume_type_get_all(context, inactive=False, filters=None): """ Returns a dict describing all volume_types with name as key. """ - session = get_session() - if inactive: - vol_types = session.query(models.VolumeTypes).\ + filters = filters or {} + + read_deleted = "yes" if inactive else "no" + rows = model_query(context, models.VolumeTypes, + read_deleted=read_deleted).\ options(joinedload('extra_specs')).\ order_by("name").\ all() - else: - vol_types = session.query(models.VolumeTypes).\ - options(joinedload('extra_specs')).\ - filter_by(deleted=False).\ - order_by("name").\ - all() - vol_dict = {} - if vol_types: - for i in vol_types: - vol_dict[i['name']] = _dict_with_extra_specs(i) - return vol_dict + + # TODO(sirp): this patern of converting rows to a result with extra_specs + # is repeated quite a bit, might be worth creating a method for it + result = {} + for row in rows: + result[row['name']] = _dict_with_extra_specs(row) + + return result @require_context def volume_type_get(context, id): """Returns a dict describing specific volume_type""" - session = get_session() - vol_type = session.query(models.VolumeTypes).\ + result = model_query(context, models.VolumeTypes, read_deleted="yes").\ options(joinedload('extra_specs')).\ filter_by(id=id).\ first() - if not vol_type: + if not result: raise exception.VolumeTypeNotFound(volume_type=id) - else: - return _dict_with_extra_specs(vol_type) + + return _dict_with_extra_specs(result) @require_context def volume_type_get_by_name(context, name): """Returns a dict describing specific volume_type""" - session = get_session() - vol_type = session.query(models.VolumeTypes).\ + result = model_query(context, models.VolumeTypes, read_deleted="yes").\ options(joinedload('extra_specs')).\ filter_by(name=name).\ first() - if not vol_type: + + if not result: raise exception.VolumeTypeNotFoundByName(volume_type_name=name) else: - return _dict_with_extra_specs(vol_type) + return _dict_with_extra_specs(result) @require_admin_context def volume_type_destroy(context, name): """ Marks specific volume_type as deleted""" - session = get_session() - volume_type_ref = session.query(models.VolumeTypes).\ + volume_type_ref = model_query(context, models.VolumeTypes, + read_deleted="yes").\ filter_by(name=name) + + # FIXME(sirp): we should be setting deleted_at and updated_at here records = volume_type_ref.update(dict(deleted=True)) if records == 0: raise exception.VolumeTypeNotFoundByName(volume_type_name=name) @@ -3846,8 +3604,8 @@ def volume_type_purge(context, name): """ Removes specific volume_type from DB Usually volume_type_destroy should be used """ - session = get_session() - volume_type_ref = session.query(models.VolumeTypes).\ + volume_type_ref = model_query(context, models.VolumeTypes, + read_deleted="yes").\ filter_by(name=name) records = volume_type_ref.delete() if records == 0: @@ -3859,28 +3617,28 @@ def volume_type_purge(context, name): #################### +def _volume_type_extra_specs_query(context, volume_type_id, session=None): + return model_query(context, models.VolumeTypeExtraSpecs, session=session, + read_deleted="no").\ + filter_by(volume_type_id=volume_type_id) + + @require_context def volume_type_extra_specs_get(context, volume_type_id): - session = get_session() - - spec_results = session.query(models.VolumeTypeExtraSpecs).\ - filter_by(volume_type_id=volume_type_id).\ - filter_by(deleted=False).\ + rows = _volume_type_extra_specs_query(context, volume_type_id).\ all() - spec_dict = {} - for i in spec_results: - spec_dict[i['key']] = i['value'] - return spec_dict + result = {} + for row in rows: + result[row['key']] = row['value'] + + return result @require_context def volume_type_extra_specs_delete(context, volume_type_id, key): - session = get_session() - session.query(models.VolumeTypeExtraSpecs).\ - filter_by(volume_type_id=volume_type_id).\ + _volume_type_extra_specs_query(context, volume_type_id).\ filter_by(key=key).\ - filter_by(deleted=False).\ update({'deleted': True, 'deleted_at': utils.utcnow(), 'updated_at': literal_column('updated_at')}) @@ -3888,22 +3646,17 @@ def volume_type_extra_specs_delete(context, volume_type_id, key): @require_context def volume_type_extra_specs_get_item(context, volume_type_id, key, - session=None): - - if not session: - session = get_session() - - spec_result = session.query(models.VolumeTypeExtraSpecs).\ - filter_by(volume_type_id=volume_type_id).\ + session=None): + result = _volume_type_extra_specs_query( + context, volume_type_id, session=session).\ filter_by(key=key).\ - filter_by(deleted=False).\ first() - if not spec_result: - raise exception.\ - VolumeTypeExtraSpecsNotFound(extra_specs_key=key, - volume_type_id=volume_type_id) - return spec_result + if not result: + raise exception.VolumeTypeExtraSpecsNotFound( + extra_specs_key=key, volume_type_id=volume_type_id) + + return result @require_context @@ -3924,7 +3677,13 @@ def volume_type_extra_specs_update_or_create(context, volume_type_id, return specs - #################### +#################### + + +def _vsa_get_query(context, session=None, project_only=False): + return model_query(context, models.VirtualStorageArray, session=session, + project_only=project_only).\ + options(joinedload('vsa_instance_type')) @require_admin_context @@ -3973,23 +3732,10 @@ def vsa_get(context, vsa_id, session=None): """ Get Virtual Storage Array record by ID. """ - if not session: - session = get_session() - result = None + result = _vsa_get_query(context, session=session, project_only=True).\ + filter_by(id=vsa_id).\ + first() - if is_admin_context(context): - result = session.query(models.VirtualStorageArray).\ - options(joinedload('vsa_instance_type')).\ - filter_by(id=vsa_id).\ - filter_by(deleted=can_read_deleted(context)).\ - first() - elif is_user_context(context): - result = session.query(models.VirtualStorageArray).\ - options(joinedload('vsa_instance_type')).\ - filter_by(project_id=context.project_id).\ - filter_by(id=vsa_id).\ - filter_by(deleted=False).\ - first() if not result: raise exception.VirtualStorageArrayNotFound(id=vsa_id) @@ -4001,11 +3747,7 @@ def vsa_get_all(context): """ Get all Virtual Storage Array records. """ - session = get_session() - return session.query(models.VirtualStorageArray).\ - options(joinedload('vsa_instance_type')).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return _vsa_get_query(context).all() @require_context @@ -4014,13 +3756,7 @@ def vsa_get_all_by_project(context, project_id): Get all Virtual Storage Array records by project ID. """ authorize_project_context(context, project_id) - - session = get_session() - return session.query(models.VirtualStorageArray).\ - options(joinedload('vsa_instance_type')).\ - filter_by(project_id=project_id).\ - filter_by(deleted=can_read_deleted(context)).\ - all() + return _vsa_get_query(context).filter_by(project_id=project_id).all() #################### @@ -4028,28 +3764,26 @@ def vsa_get_all_by_project(context, project_id): def s3_image_get(context, image_id): """Find local s3 image represented by the provided id""" - session = get_session() - res = session.query(models.S3Image)\ - .filter_by(id=image_id)\ - .first() + result = model_query(context, models.S3Image, read_deleted="yes").\ + filter_by(id=image_id).\ + first() - if not res: + if not result: raise exception.ImageNotFound(image_id=image_id) - return res + return result def s3_image_get_by_uuid(context, image_uuid): """Find local s3 image represented by the provided uuid""" - session = get_session() - res = session.query(models.S3Image)\ - .filter_by(uuid=image_uuid)\ - .first() + result = model_query(context, models.S3Image, read_deleted="yes").\ + filter_by(uuid=image_uuid).\ + first() - if not res: + if not result: raise exception.ImageNotFound(image_id=image_uuid) - return res + return result def s3_image_create(context, image_uuid): @@ -4077,12 +3811,15 @@ def sm_backend_conf_create(context, values): @require_admin_context def sm_backend_conf_update(context, sm_backend_id, values): - session = get_session() - backend_conf = session.query(models.SMBackendConf).\ - filter_by(id=sm_backend_id).first() + backend_conf = model_query(context, models.SMBackendConf, + read_deleted="yes").\ + filter_by(id=sm_backend_id).\ + first() + if not backend_conf: - raise exception.NotFound(_("No backend config with id "\ - "%(sm_backend_id)s") % locals()) + raise exception.NotFound( + _("No backend config with id %(sm_backend_id)s") % locals()) + backend_conf.update(values) backend_conf.save(session=session) return backend_conf @@ -4090,40 +3827,52 @@ def sm_backend_conf_update(context, sm_backend_id, values): @require_admin_context def sm_backend_conf_delete(context, sm_backend_id): + # FIXME(sirp): for consistency, shouldn't this just mark as deleted with + # `purge` actually deleting the record? session = get_session() with session.begin(): - session.query(models.SMBackendConf).\ + model_query(context, models.SMBackendConf, session=session, + read_deleted="yes").\ filter_by(id=sm_backend_id).\ delete() @require_admin_context def sm_backend_conf_get(context, sm_backend_id): - session = get_session() - result = session.query(models.SMBackendConf).\ - filter_by(id=sm_backend_id).first() + result = model_query(context, models.SMBackendConf, read_deleted="yes").\ + filter_by(id=sm_backend_id).\ + first() + if not result: raise exception.NotFound(_("No backend config with id "\ "%(sm_backend_id)s") % locals()) + return result @require_admin_context def sm_backend_conf_get_by_sr(context, sr_uuid): session = get_session() - result = session.query(models.SMBackendConf).filter_by(sr_uuid=sr_uuid) - return result + # FIXME(sirp): shouldn't this have a `first()` qualifier attached? + return model_query(context, models.SMBackendConf, read_deleted="yes").\ + filter_by(sr_uuid=sr_uuid) @require_admin_context def sm_backend_conf_get_all(context): - session = get_session() - return session.query(models.SMBackendConf).all() + return model_query(context, models.SMBackendConf, read_deleted="yes").\ + all() #################### +def _sm_flavor_get_query(context, sm_flavor_label, session=None): + return model_query(context, models.SMFlavors, session=session, + read_deleted="yes").\ + filter_by(label=sm_flavor_label) + + @require_admin_context def sm_flavor_create(context, values): sm_flavor = models.SMFlavors() @@ -4134,12 +3883,7 @@ def sm_flavor_create(context, values): @require_admin_context def sm_flavor_update(context, sm_flavor_label, values): - session = get_session() - sm_flavor = session.query(models.SMFlavors).\ - filter_by(label=sm_flavor_label) - if not sm_flavor: - raise exception.NotFound(_("No sm_flavor with id "\ - "%(sm_flavor_id)s") % locals()) + sm_flavor = sm_flavor_get(context, sm_flavor_label) sm_flavor.update(values) sm_flavor.save() return sm_flavor @@ -4149,30 +3893,34 @@ def sm_flavor_update(context, sm_flavor_label, values): def sm_flavor_delete(context, sm_flavor_label): session = get_session() with session.begin(): - session.query(models.SMFlavors).\ - filter_by(label=sm_flavor_label).\ - delete() + _sm_flavor_get_query(context, sm_flavor_label).delete() @require_admin_context -def sm_flavor_get(context, sm_flavor): - session = get_session() - result = session.query(models.SMFlavors).filter_by(label=sm_flavor) +def sm_flavor_get(context, sm_flavor_label): + result = _sm_flavor_get_query(context, sm_flavor_label).first() + if not result: - raise exception.NotFound(_("No sm_flavor called %(sm_flavor)s") \ - % locals()) + raise exception.NotFound( + _("No sm_flavor called %(sm_flavor)s") % locals()) + return result @require_admin_context def sm_flavor_get_all(context): - session = get_session() - return session.query(models.SMFlavors).all() + return model_query(context, models.SMFlavors, read_deleted="yes").all() ############################### +def _sm_volume_get_query(context, volume_id, session=None): + return model_query(context, models.SMVolume, session=session, + read_deleted="yes").\ + filter_by(id=volume_id) + + def sm_volume_create(context, values): sm_volume = models.SMVolume() sm_volume.update(values) @@ -4181,11 +3929,7 @@ def sm_volume_create(context, values): def sm_volume_update(context, volume_id, values): - session = get_session() - sm_volume = session.query(models.SMVolume).filter_by(id=volume_id).first() - if not sm_volume: - raise exception.NotFound(_("No sm_volume with id %(volume_id)s") \ - % locals()) + sm_volume = sm_volume_get(context, volume_id) sm_volume.update(values) sm_volume.save() return sm_volume @@ -4194,20 +3938,18 @@ def sm_volume_update(context, volume_id, values): def sm_volume_delete(context, volume_id): session = get_session() with session.begin(): - session.query(models.SMVolume).\ - filter_by(id=volume_id).\ - delete() + _sm_volume_get_query(context, volume_id, session=session).delete() def sm_volume_get(context, volume_id): - session = get_session() - result = session.query(models.SMVolume).filter_by(id=volume_id).first() + result = _sm_volume_get_query(context, volume_id).first() + if not result: - raise exception.NotFound(_("No sm_volume with id %(volume_id)s") \ - % locals()) + raise exception.NotFound( + _("No sm_volume with id %(volume_id)s") % locals()) + return result def sm_volume_get_all(context): - session = get_session() - return session.query(models.SMVolume).all() + return model_query(context, models.SMVolume, read_deleted="yes").all() diff --git a/nova/tests/api/ec2/test_cloud.py b/nova/tests/api/ec2/test_cloud.py index caa6ff68d925..78e7f96fe98b 100644 --- a/nova/tests/api/ec2/test_cloud.py +++ b/nova/tests/api/ec2/test_cloud.py @@ -106,7 +106,7 @@ class CloudTestCase(test.TestCase): self.project_id = 'fake' self.context = context.RequestContext(self.user_id, self.project_id, - True) + is_admin=True) def fake_show(meh, context, id): return {'id': id, @@ -1564,12 +1564,12 @@ class CloudTestCase(test.TestCase): self.cloud.terminate_instances(self.context, [ec2_instance_id]) - admin_ctxt = context.get_admin_context(read_deleted=False) + admin_ctxt = context.get_admin_context(read_deleted="no") vol = db.volume_get(admin_ctxt, vol1['id']) self.assertFalse(vol['deleted']) db.volume_destroy(self.context, vol1['id']) - admin_ctxt = context.get_admin_context(read_deleted=True) + admin_ctxt = context.get_admin_context(read_deleted="only") vol = db.volume_get(admin_ctxt, vol2['id']) self.assertTrue(vol['deleted']) @@ -1689,13 +1689,13 @@ class CloudTestCase(test.TestCase): self.cloud.terminate_instances(self.context, [ec2_instance_id]) - admin_ctxt = context.get_admin_context(read_deleted=False) + admin_ctxt = context.get_admin_context(read_deleted="no") vol = db.volume_get(admin_ctxt, vol1_id) self._assert_volume_detached(vol) self.assertFalse(vol['deleted']) db.volume_destroy(self.context, vol1_id) - admin_ctxt = context.get_admin_context(read_deleted=True) + admin_ctxt = context.get_admin_context(read_deleted="only") vol = db.volume_get(admin_ctxt, vol2_id) self.assertTrue(vol['deleted']) diff --git a/nova/tests/scheduler/test_scheduler.py b/nova/tests/scheduler/test_scheduler.py index 9938b5dd98c4..32ff67f1d79b 100644 --- a/nova/tests/scheduler/test_scheduler.py +++ b/nova/tests/scheduler/test_scheduler.py @@ -312,7 +312,7 @@ class SimpleDriverTestCase(test.TestCase): FLAGS.compute_manager) compute1.start() _create_instance() - ctxt = context.RequestContext('fake', 'fake', False) + ctxt = context.RequestContext('fake', 'fake', is_admin=False) global instance_uuids instance_uuids = [] self.stubs.Set(SimpleScheduler, diff --git a/nova/tests/test_adminapi.py b/nova/tests/test_adminapi.py index 08c8f707a32d..41dbc3e925ef 100644 --- a/nova/tests/test_adminapi.py +++ b/nova/tests/test_adminapi.py @@ -52,7 +52,7 @@ class AdminApiTestCase(test.TestCase): self.project_id = 'admin' self.context = context.RequestContext(self.user_id, self.project_id, - True) + is_admin=True) def fake_show(meh, context, id): return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1, diff --git a/nova/tests/test_compute.py b/nova/tests/test_compute.py index 262f0a9687f5..e4c60e0682f2 100644 --- a/nova/tests/test_compute.py +++ b/nova/tests/test_compute.py @@ -222,7 +222,7 @@ class ComputeTestCase(BaseTestCase): self.assertEqual(instance['deleted_at'], None) terminate = utils.utcnow() self.compute.terminate_instance(self.context, instance['uuid']) - context = self.context.elevated(True) + context = self.context.elevated(read_deleted="only") instance = db.instance_get_by_uuid(context, instance['uuid']) self.assert_(instance['launched_at'] < terminate) self.assert_(instance['deleted_at'] > terminate) @@ -674,7 +674,7 @@ class ComputeTestCase(BaseTestCase): instance_uuid = instance['uuid'] self.compute.run_instance(self.context, instance_uuid) - non_admin_context = context.RequestContext(None, None, False, False) + non_admin_context = context.RequestContext(None, None, is_admin=False) # decorator should return False (fail) with locked nonadmin context self.compute.lock_instance(self.context, instance_uuid) @@ -1231,8 +1231,9 @@ class ComputeAPITestCase(BaseTestCase): try: db.security_group_destroy(self.context, group['id']) - group = db.security_group_get(context.get_admin_context( - read_deleted=True), group['id']) + admin_deleted_context = context.get_admin_context( + read_deleted="only") + group = db.security_group_get(admin_deleted_context, group['id']) self.assert_(len(group.instances) == 0) finally: db.instance_destroy(self.context, ref[0]['id']) diff --git a/nova/tests/test_quota.py b/nova/tests/test_quota.py index 449f37572d92..83c9e36dfe11 100644 --- a/nova/tests/test_quota.py +++ b/nova/tests/test_quota.py @@ -53,7 +53,7 @@ class QuotaTestCase(test.TestCase): self.project_id = 'admin' self.context = context.RequestContext(self.user_id, self.project_id, - True) + is_admin=True) orig_rpc_call = rpc.call def rpc_call_wrapper(context, topic, msg): diff --git a/nova/tests/test_vmwareapi.py b/nova/tests/test_vmwareapi.py index 9dacda4b393d..b1b297107afe 100644 --- a/nova/tests/test_vmwareapi.py +++ b/nova/tests/test_vmwareapi.py @@ -40,7 +40,7 @@ class VMWareAPIVMTestCase(test.TestCase): def setUp(self): super(VMWareAPIVMTestCase, self).setUp() - self.context = context.RequestContext('fake', 'fake', False) + self.context = context.RequestContext('fake', 'fake', is_admin=False) self.flags(vmwareapi_host_ip='test_url', vmwareapi_host_username='test_username', vmwareapi_host_password='test_pass') diff --git a/tools/xenserver/vm_vdi_cleaner.py b/tools/xenserver/vm_vdi_cleaner.py index 24553fd75e9f..6b9213e404a6 100755 --- a/tools/xenserver/vm_vdi_cleaner.py +++ b/tools/xenserver/vm_vdi_cleaner.py @@ -129,8 +129,7 @@ def get_instance_id_from_name_label(name_label, template): def find_orphaned_instances(session, verbose=False): """Find and return a list of orphaned instances.""" - ctxt = context.get_admin_context() - ctxt.read_deleted = True + ctxt = context.get_admin_context(read_deleted="only") orphaned_instances = []