diff --git a/heat/db/api.py b/heat/db/api.py index 06d868a935..21879b6c9d 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -34,7 +34,6 @@ from sqlalchemy import and_ from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import orm -from sqlalchemy.orm import aliased as orm_aliased from heat.common import crypt from heat.common import exception @@ -153,6 +152,10 @@ def _soft_delete_aware_query(context, *args, **kwargs): def raw_template_get(context, template_id): + return _raw_template_get(context, template_id) + + +def _raw_template_get(context, template_id): result = context.session.get(models.RawTemplate, template_id) if not result: @@ -169,7 +172,7 @@ def raw_template_create(context, values): def raw_template_update(context, template_id, values): - raw_template_ref = raw_template_get(context, template_id) + raw_template_ref = _raw_template_get(context, template_id) # get only the changed values values = dict((k, v) for k, v in values.items() if getattr(raw_template_ref, k) != v) @@ -182,7 +185,7 @@ def raw_template_update(context, template_id, values): def raw_template_delete(context, template_id): try: - raw_template = raw_template_get(context, template_id) + raw_template = _raw_template_get(context, template_id) except exception.NotFound: # Ignore not found return @@ -196,7 +199,7 @@ def raw_template_delete(context, template_id): if context.session.query(models.RawTemplate).filter_by( files_id=raw_tmpl_files_id).first() is None: try: - raw_tmpl_files = raw_template_files_get( + raw_tmpl_files = _raw_template_files_get( context, raw_tmpl_files_id) except exception.NotFound: # Ignore not found @@ -216,6 +219,10 @@ def raw_template_files_create(context, values): def raw_template_files_get(context, files_id): + return _raw_template_files_get(context, files_id) + + +def _raw_template_files_get(context, files_id): result = context.session.get(models.RawTemplateFiles, files_id) if not result: raise exception.NotFound( @@ -228,6 +235,10 @@ def raw_template_files_get(context, files_id): def resource_create(context, values): + return _resource_create(context, values) + + +def _resource_create(context, values): resource_ref = models.Resource() resource_ref.update(values) resource_ref.save(context.session) @@ -241,7 +252,7 @@ def resource_create_replacement(context, atomic_key, expected_engine_id=None): try: with context.session.begin(): - new_res = resource_create(context, new_res_values) + new_res = _resource_create(context, new_res_values) update_data = {'replaced_by': new_res.id} if not _try_resource_update(context, existing_res_id, update_data, @@ -354,6 +365,12 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id): def resource_get_all_by_physical_resource_id(context, physical_resource_id): + return _resource_get_all_by_physical_resource_id( + context, physical_resource_id, + ) + + +def _resource_get_all_by_physical_resource_id(context, physical_resource_id): results = (context.session.query(models.Resource) .filter_by(physical_resource_id=physical_resource_id) .all()) @@ -365,8 +382,9 @@ def resource_get_all_by_physical_resource_id(context, physical_resource_id): def resource_get_by_physical_resource_id(context, physical_resource_id): - results = resource_get_all_by_physical_resource_id(context, - physical_resource_id) + results = _resource_get_all_by_physical_resource_id( + context, physical_resource_id, + ) try: return next(results) except StopIteration: @@ -515,15 +533,17 @@ def resource_data_get(context, resource_id, key): Decrypts resource data if necessary. """ - result = resource_data_get_by_key(context, - resource_id, - key) + result = _resource_data_get_by_key(context, resource_id, key) if result.redact: return crypt.decrypt(result.decrypt_method, result.value) return result.value def resource_data_get_by_key(context, resource_id, key): + return _resource_data_get_by_key(context, resource_id, key) + + +def _resource_data_get_by_key(context, resource_id, key): """Looks up resource_data by resource_id and key. Does not decrypt resource_data. @@ -544,7 +564,7 @@ def resource_data_set(context, resource_id, key, value, redact=False): else: method = '' try: - current = resource_data_get_by_key(context, resource_id, key) + current = _resource_data_get_by_key(context, resource_id, key) except exception.NotFound: current = models.ResourceData() current.key = key @@ -557,7 +577,7 @@ def resource_data_set(context, resource_id, key, value, redact=False): def resource_data_delete(context, resource_id, key): - result = resource_data_get_by_key(context, resource_id, key) + result = _resource_data_get_by_key(context, resource_id, key) with context.session.begin(): context.session.delete(result) @@ -566,6 +586,10 @@ def resource_data_delete(context, resource_id, key): def resource_prop_data_create_or_update(context, values, rpd_id=None): + return _resource_prop_data_create_or_update(context, values, rpd_id=rpd_id) + + +def _resource_prop_data_create_or_update(context, values, rpd_id=None): obj_ref = None if rpd_id is not None: obj_ref = context.session.query( @@ -578,7 +602,7 @@ def resource_prop_data_create_or_update(context, values, rpd_id=None): def resource_prop_data_create(context, values): - return resource_prop_data_create_or_update(context, values) + return _resource_prop_data_create_or_update(context, values) def resource_prop_data_get(context, resource_prop_data_id): @@ -605,6 +629,10 @@ def stack_get_by_name_and_owner_id(context, stack_name, owner_id): def stack_get_by_name(context, stack_name): + return _stack_get_by_name(context, stack_name) + + +def _stack_get_by_name(context, stack_name): query = _soft_delete_aware_query( context, models.Stack ).filter(sqlalchemy.or_( @@ -615,6 +643,12 @@ def stack_get_by_name(context, stack_name): def stack_get(context, stack_id, show_deleted=False, eager_load=True): + return _stack_get( + context, stack_id, show_deleted=show_deleted, eager_load=eager_load + ) + + +def _stack_get(context, stack_id, show_deleted=False, eager_load=True): options = [] if eager_load: options.append(orm.joinedload(models.Stack.raw_template)) @@ -653,6 +687,10 @@ def stack_get_status(context, stack_id): def stack_get_all_by_owner_id(context, owner_id): + return _stack_get_all_by_owner_id(context, owner_id) + + +def _stack_get_all_by_owner_id(context, owner_id): results = _soft_delete_aware_query( context, models.Stack, ).filter_by( @@ -662,9 +700,13 @@ def stack_get_all_by_owner_id(context, owner_id): def stack_get_all_by_root_owner_id(context, owner_id): - for stack in stack_get_all_by_owner_id(context, owner_id): + return _stack_get_all_by_root_owner_id(context, owner_id) + + +def _stack_get_all_by_root_owner_id(context, owner_id): + for stack in _stack_get_all_by_owner_id(context, owner_id): yield stack - for ch_st in stack_get_all_by_root_owner_id(context, stack.id): + for ch_st in _stack_get_all_by_root_owner_id(context, stack.id): yield ch_st @@ -722,7 +764,7 @@ def _query_stack_get_all(context, show_deleted=False, query = query.options(orm.subqueryload(models.Stack.tags)) if tags: for tag in tags: - tag_alias = orm_aliased(models.StackTag) + tag_alias = orm.aliased(models.StackTag) query = query.join(tag_alias, models.Stack.tags) query = query.filter(tag_alias.tag == tag) @@ -736,7 +778,7 @@ def _query_stack_get_all(context, show_deleted=False, context, models.Stack, show_deleted=show_deleted ) for tag in not_tags: - tag_alias = orm_aliased(models.StackTag) + tag_alias = orm.aliased(models.StackTag) subquery = subquery.join(tag_alias, models.Stack.tags) subquery = subquery.filter(tag_alias.tag == tag) not_stack_ids = [s.id for s in subquery.all()] @@ -811,7 +853,7 @@ def stack_create(context, values): # Even though we just created a stack with this name, we may not find # it again because some unit tests create stacks with deleted_at set. Also # some backup stacks may not be found, for reasons that are unclear. - earliest = stack_get_by_name(context, stack_name) + earliest = _stack_get_by_name(context, stack_name) if earliest is not None and earliest.id != stack_ref.id: with context.session.begin(): context.session.query(models.Stack).filter_by( @@ -841,7 +883,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None): 'expected traversal: %(trav)s', {'id': stack_id, 'vals': str(values), 'trav': str(exp_trvsl)}) - if not stack_get(context, stack_id, eager_load=False): + if not _stack_get(context, stack_id, eager_load=False): raise exception.NotFound( _('Attempt to update a stack with id: ' '%(id)s %(msg)s') % { @@ -852,7 +894,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None): def stack_delete(context, stack_id): - s = stack_get(context, stack_id, eager_load=False) + s = _stack_get(context, stack_id, eager_load=False) if not s: raise exception.NotFound(_('Attempt to delete a stack with id: ' '%(id)s %(msg)s') % { @@ -873,7 +915,14 @@ def stack_delete(context, stack_id): _soft_delete(context, s) -def reset_stack_status(context, stack_id, stack=None): +def reset_stack_status(context, stack_id): + return _reset_stack_status(context, stack_id) + + +# NOTE(stephenfin): This method uses separate transactions to delete nested +# stacks, thus it's the only private method that is allowed to open a +# transaction (via 'context.session.begin') +def _reset_stack_status(context, stack_id, stack=None): if stack is None: stack = context.session.get(models.Stack, stack_id) @@ -901,7 +950,7 @@ def reset_stack_status(context, stack_id, stack=None): query = context.session.query(models.Stack).filter_by(owner_id=stack_id) for child in query: - reset_stack_status(context, child.id, child) + _reset_stack_status(context, child.id, child) with context.session.begin(): if stack.status == 'IN_PROGRESS': @@ -915,7 +964,7 @@ def reset_stack_status(context, stack_id, stack=None): def stack_tags_set(context, stack_id, tags): with context.session.begin(): - stack_tags_delete(context, stack_id) + _stack_tags_delete(context, stack_id) result = [] for tag in tags: stack_tag = models.StackTag() @@ -928,13 +977,21 @@ def stack_tags_set(context, stack_id, tags): def stack_tags_delete(context, stack_id): with transaction(context): - result = stack_tags_get(context, stack_id) - if result: - for tag in result: - context.session.delete(tag) + return _stack_tags_delete(context, stack_id) + + +def _stack_tags_delete(context, stack_id): + result = _stack_tags_get(context, stack_id) + if result: + for tag in result: + context.session.delete(tag) def stack_tags_get(context, stack_id): + return _stack_tags_get(context, stack_id) + + +def _stack_tags_get(context, stack_id): result = (context.session.query(models.StackTag) .filter_by(stack_id=stack_id) .all()) @@ -1004,11 +1061,11 @@ def stack_lock_release(context, stack_id, engine_id): def stack_get_root_id(context, stack_id): - s = stack_get(context, stack_id, eager_load=False) + s = _stack_get(context, stack_id, eager_load=False) if not s: return None while s.owner_id: - s = stack_get(context, s.owner_id, eager_load=False) + s = _stack_get(context, s.owner_id, eager_load=False) return s.id @@ -1152,6 +1209,10 @@ def _events_filter_and_page_query(context, query, def event_count_all_by_stack(context, stack_id): + return _event_count_all_by_stack(context, stack_id) + + +def _event_count_all_by_stack(context, stack_id): query = context.session.query(func.count(models.Event.id)) return query.filter_by(stack_id=stack_id).scalar() @@ -1212,47 +1273,46 @@ def _delete_event_rows(context, stack_id, limit): # So we must manually supply the IN() values. # pgsql SHOULD work with the pure DELETE/JOIN below but that must be # confirmed via integration tests. - with context.session.begin(): - query = context.session.query(models.Event).filter_by( - stack_id=stack_id, - ) - query = query.order_by(models.Event.id).limit(limit) - id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()] - if not id_pairs: - return 0 - (ids, rsrc_prop_ids) = zip(*id_pairs) - max_id = ids[-1] - # delete the events - retval = context.session.query(models.Event).filter( - models.Event.id <= max_id).filter( - models.Event.stack_id == stack_id).delete() + query = context.session.query(models.Event).filter_by( + stack_id=stack_id, + ) + query = query.order_by(models.Event.id).limit(limit) + id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()] + if not id_pairs: + return 0 + (ids, rsrc_prop_ids) = zip(*id_pairs) + max_id = ids[-1] + # delete the events + retval = context.session.query(models.Event).filter( + models.Event.id <= max_id).filter( + models.Event.stack_id == stack_id).delete() - # delete unreferenced resource_properties_data - def del_rpd(rpd_ids): - if not rpd_ids: - return - q_rpd = context.session.query(models.ResourcePropertiesData) - q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids)) - q_rpd.delete(synchronize_session=False) + # delete unreferenced resource_properties_data + def del_rpd(rpd_ids): + if not rpd_ids: + return + q_rpd = context.session.query(models.ResourcePropertiesData) + q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids)) + q_rpd.delete(synchronize_session=False) - if rsrc_prop_ids: - clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context, - stack_id) - clr_prop_ids.discard(None) - try: - del_rpd(clr_prop_ids) - except db_exception.DBReferenceError: - LOG.debug('Checking backup/stack pairs for RPD references') - found = False - for partner_stack_id in _all_backup_stack_ids(context, - stack_id): - found = True - clr_prop_ids -= _find_rpd_references(context, - partner_stack_id) - if not found: - LOG.debug('No backup/stack pairs found for %s', stack_id) - raise - del_rpd(clr_prop_ids) + if rsrc_prop_ids: + clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context, + stack_id) + clr_prop_ids.discard(None) + try: + del_rpd(clr_prop_ids) + except db_exception.DBReferenceError: + LOG.debug('Checking backup/stack pairs for RPD references') + found = False + for partner_stack_id in _all_backup_stack_ids(context, + stack_id): + found = True + clr_prop_ids -= _find_rpd_references(context, + partner_stack_id) + if not found: + LOG.debug('No backup/stack pairs found for %s', stack_id) + raise + del_rpd(clr_prop_ids) return retval @@ -1263,13 +1323,18 @@ def event_create(context, values): # only count events and purge on average # 200.0/cfg.CONF.event_purge_batch_size percent of the time. check = (2.0 / cfg.CONF.event_purge_batch_size) > random.uniform(0, 1) - if (check and - (event_count_all_by_stack(context, values['stack_id']) >= - cfg.CONF.max_events_per_stack)): + if ( + check and _event_count_all_by_stack( + context, values['stack_id'] + ) >= cfg.CONF.max_events_per_stack + ): # prune try: - _delete_event_rows(context, values['stack_id'], - cfg.CONF.event_purge_batch_size) + with context.session.begin(): + _delete_event_rows( + context, values['stack_id'], + cfg.CONF.event_purge_batch_size, + ) except db_exception.DBError as exc: LOG.error('Failed to purge events: %s', str(exc)) event_ref = models.Event() @@ -1289,6 +1354,10 @@ def software_config_create(context, values): def software_config_get(context, config_id): + return _software_config_get(context, config_id) + + +def _software_config_get(context, config_id): result = context.session.get(models.SoftwareConfig, config_id) if (result is not None and context is not None and not context.is_admin and result.tenant != context.tenant_id): @@ -1309,7 +1378,7 @@ def software_config_get_all(context, limit=None, marker=None): def software_config_delete(context, config_id): - config = software_config_get(context, config_id) + config = _software_config_get(context, config_id) # Query if the software config has been referenced by deployment. result = context.session.query(models.SoftwareDeployment).filter_by( config_id=config_id).first() @@ -1340,6 +1409,10 @@ def software_deployment_create(context, values): def software_deployment_get(context, deployment_id): + return _software_deployment_get(context, deployment_id) + + +def _software_deployment_get(context, deployment_id): result = context.session.get(models.SoftwareDeployment, deployment_id) if (result is not None and context is not None and not context.is_admin and context.tenant_id not in (result.tenant, @@ -1366,7 +1439,7 @@ def software_deployment_get_all(context, server_id=None): def software_deployment_update(context, deployment_id, values): - deployment = software_deployment_get(context, deployment_id) + deployment = _software_deployment_get(context, deployment_id) try: update_and_save(context, deployment, values) except db_exception.DBReferenceError: @@ -1377,7 +1450,7 @@ def software_deployment_update(context, deployment_id, values): def software_deployment_delete(context, deployment_id): - deployment = software_deployment_get(context, deployment_id) + deployment = _software_deployment_get(context, deployment_id) with context.session.begin(): context.session.delete(deployment) @@ -1393,6 +1466,10 @@ def snapshot_create(context, values): def snapshot_get(context, snapshot_id): + return _snapshot_get(context, snapshot_id) + + +def _snapshot_get(context, snapshot_id): result = context.session.get(models.Snapshot, snapshot_id) if (result is not None and context is not None and context.tenant_id != result.tenant): @@ -1405,7 +1482,7 @@ def snapshot_get(context, snapshot_id): def snapshot_get_by_stack(context, snapshot_id, stack): - snapshot = snapshot_get(context, snapshot_id) + snapshot = _snapshot_get(context, snapshot_id) if snapshot.stack_id != stack.id: raise exception.SnapshotNotFound(snapshot=snapshot_id, stack=stack.name) @@ -1414,14 +1491,14 @@ def snapshot_get_by_stack(context, snapshot_id, stack): def snapshot_update(context, snapshot_id, values): - snapshot = snapshot_get(context, snapshot_id) + snapshot = _snapshot_get(context, snapshot_id) snapshot.update(values) snapshot.save(context.session) return snapshot def snapshot_delete(context, snapshot_id): - snapshot = snapshot_get(context, snapshot_id) + snapshot = _snapshot_get(context, snapshot_id) with context.session.begin(): context.session.delete(snapshot) @@ -1442,7 +1519,7 @@ def service_create(context, values): def service_update(context, service_id, values): - service = service_get(context, service_id) + service = _service_get(context, service_id) values.update({'updated_at': timeutils.utcnow()}) service.update(values) service.save(context.session) @@ -1450,7 +1527,7 @@ def service_update(context, service_id, values): def service_delete(context, service_id, soft_delete=True): - service = service_get(context, service_id) + service = _service_get(context, service_id) with context.session.begin(): if soft_delete: _soft_delete(context, service) @@ -1459,6 +1536,10 @@ def service_delete(context, service_id, soft_delete=True): def service_get(context, service_id): + return _service_get(context, service_id) + + +def _service_get(context, service_id): result = context.session.get(models.Service, service_id) if result is None: raise exception.EntityNotFound(entity='Service', name=service_id) @@ -1466,8 +1547,9 @@ def service_get(context, service_id): def service_get_all(context): - return (context.session.query(models.Service). - filter_by(deleted_at=None).all()) + return context.session.query(models.Service).filter_by( + deleted_at=None, + ).all() def service_get_all_by_args(context, host, binary, hostname):