Merge "db: Only call private methods from public methods"

Zuul authored 2023-10-16 07:23:34 +00:00, committed by Gerrit Code Review
commit d574808b07

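In short, this change applies one convention across heat/db/api.py: each public DB API function is (or becomes) a thin wrapper that delegates to an underscore-prefixed private implementation, and any function that opens a transaction via context.session.begin() calls only private helpers, never another public function, so transactions are not nested. The diff notes one deliberate exception, _reset_stack_status, which deletes nested stacks in separate transactions. The sketch below only illustrates the convention; the names (ExampleModel, example_get, example_delete) are placeholders, not part of the real Heat API.

# Illustrative sketch of the convention, not code from this commit.
from heat.common import exception
from heat.db import models  # module path assumed for the example


def example_get(context, obj_id):
    # Public entry point: delegates to the private helper and adds nothing else.
    return _example_get(context, obj_id)


def _example_get(context, obj_id):
    # Private implementation: runs the query but never opens a transaction,
    # so it is safe to call from inside another function's
    # 'with context.session.begin():' block.
    result = context.session.get(models.ExampleModel, obj_id)  # hypothetical model
    if result is None:
        raise exception.NotFound('ExampleModel %s not found' % obj_id)
    return result


def example_delete(context, obj_id):
    # Public functions own the transaction and call only private helpers
    # while it is open.
    obj = _example_get(context, obj_id)
    with context.session.begin():
        context.session.delete(obj)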

@@ -34,7 +34,6 @@ from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import orm
from sqlalchemy.orm import aliased as orm_aliased
from heat.common import crypt
from heat.common import exception
@@ -153,6 +152,10 @@ def _soft_delete_aware_query(context, *args, **kwargs):
def raw_template_get(context, template_id):
return _raw_template_get(context, template_id)
def _raw_template_get(context, template_id):
result = context.session.get(models.RawTemplate, template_id)
if not result:
@@ -169,7 +172,7 @@ def raw_template_create(context, values):
def raw_template_update(context, template_id, values):
raw_template_ref = raw_template_get(context, template_id)
raw_template_ref = _raw_template_get(context, template_id)
# get only the changed values
values = dict((k, v) for k, v in values.items()
if getattr(raw_template_ref, k) != v)
@@ -182,7 +185,7 @@ def raw_template_update(context, template_id, values):
def raw_template_delete(context, template_id):
try:
raw_template = raw_template_get(context, template_id)
raw_template = _raw_template_get(context, template_id)
except exception.NotFound:
# Ignore not found
return
@@ -196,7 +199,7 @@ def raw_template_delete(context, template_id):
if context.session.query(models.RawTemplate).filter_by(
files_id=raw_tmpl_files_id).first() is None:
try:
raw_tmpl_files = raw_template_files_get(
raw_tmpl_files = _raw_template_files_get(
context, raw_tmpl_files_id)
except exception.NotFound:
# Ignore not found
@@ -216,6 +219,10 @@ def raw_template_files_create(context, values):
def raw_template_files_get(context, files_id):
return _raw_template_files_get(context, files_id)
def _raw_template_files_get(context, files_id):
result = context.session.get(models.RawTemplateFiles, files_id)
if not result:
raise exception.NotFound(
@@ -228,6 +235,10 @@ def raw_template_files_get(context, files_id):
def resource_create(context, values):
return _resource_create(context, values)
def _resource_create(context, values):
resource_ref = models.Resource()
resource_ref.update(values)
resource_ref.save(context.session)
@@ -241,7 +252,7 @@ def resource_create_replacement(context,
atomic_key, expected_engine_id=None):
try:
with context.session.begin():
new_res = resource_create(context, new_res_values)
new_res = _resource_create(context, new_res_values)
update_data = {'replaced_by': new_res.id}
if not _try_resource_update(context,
existing_res_id, update_data,
@@ -354,6 +365,12 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id):
def resource_get_all_by_physical_resource_id(context, physical_resource_id):
return _resource_get_all_by_physical_resource_id(
context, physical_resource_id,
)
def _resource_get_all_by_physical_resource_id(context, physical_resource_id):
results = (context.session.query(models.Resource)
.filter_by(physical_resource_id=physical_resource_id)
.all())
@@ -365,8 +382,9 @@ def resource_get_all_by_physical_resource_id(context, physical_resource_id):
def resource_get_by_physical_resource_id(context, physical_resource_id):
results = resource_get_all_by_physical_resource_id(context,
physical_resource_id)
results = _resource_get_all_by_physical_resource_id(
context, physical_resource_id,
)
try:
return next(results)
except StopIteration:
@@ -515,15 +533,17 @@ def resource_data_get(context, resource_id, key):
Decrypts resource data if necessary.
"""
result = resource_data_get_by_key(context,
resource_id,
key)
result = _resource_data_get_by_key(context, resource_id, key)
if result.redact:
return crypt.decrypt(result.decrypt_method, result.value)
return result.value
def resource_data_get_by_key(context, resource_id, key):
return _resource_data_get_by_key(context, resource_id, key)
def _resource_data_get_by_key(context, resource_id, key):
"""Looks up resource_data by resource_id and key.
Does not decrypt resource_data.
@@ -544,7 +564,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
else:
method = ''
try:
current = resource_data_get_by_key(context, resource_id, key)
current = _resource_data_get_by_key(context, resource_id, key)
except exception.NotFound:
current = models.ResourceData()
current.key = key
@@ -557,7 +577,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
def resource_data_delete(context, resource_id, key):
result = resource_data_get_by_key(context, resource_id, key)
result = _resource_data_get_by_key(context, resource_id, key)
with context.session.begin():
context.session.delete(result)
@@ -566,6 +586,10 @@ def resource_data_delete(context, resource_id, key):
def resource_prop_data_create_or_update(context, values, rpd_id=None):
return _resource_prop_data_create_or_update(context, values, rpd_id=rpd_id)
def _resource_prop_data_create_or_update(context, values, rpd_id=None):
obj_ref = None
if rpd_id is not None:
obj_ref = context.session.query(
@@ -578,7 +602,7 @@ def resource_prop_data_create_or_update(context, values, rpd_id=None):
def resource_prop_data_create(context, values):
return resource_prop_data_create_or_update(context, values)
return _resource_prop_data_create_or_update(context, values)
def resource_prop_data_get(context, resource_prop_data_id):
@@ -605,6 +629,10 @@ def stack_get_by_name_and_owner_id(context, stack_name, owner_id):
def stack_get_by_name(context, stack_name):
return _stack_get_by_name(context, stack_name)
def _stack_get_by_name(context, stack_name):
query = _soft_delete_aware_query(
context, models.Stack
).filter(sqlalchemy.or_(
@@ -615,6 +643,12 @@ def stack_get_by_name(context, stack_name):
def stack_get(context, stack_id, show_deleted=False, eager_load=True):
return _stack_get(
context, stack_id, show_deleted=show_deleted, eager_load=eager_load
)
def _stack_get(context, stack_id, show_deleted=False, eager_load=True):
options = []
if eager_load:
options.append(orm.joinedload(models.Stack.raw_template))
@@ -653,6 +687,10 @@ def stack_get_status(context, stack_id):
def stack_get_all_by_owner_id(context, owner_id):
return _stack_get_all_by_owner_id(context, owner_id)
def _stack_get_all_by_owner_id(context, owner_id):
results = _soft_delete_aware_query(
context, models.Stack,
).filter_by(
@@ -662,9 +700,13 @@ def stack_get_all_by_owner_id(context, owner_id):
def stack_get_all_by_root_owner_id(context, owner_id):
for stack in stack_get_all_by_owner_id(context, owner_id):
return _stack_get_all_by_root_owner_id(context, owner_id)
def _stack_get_all_by_root_owner_id(context, owner_id):
for stack in _stack_get_all_by_owner_id(context, owner_id):
yield stack
for ch_st in stack_get_all_by_root_owner_id(context, stack.id):
for ch_st in _stack_get_all_by_root_owner_id(context, stack.id):
yield ch_st
@@ -722,7 +764,7 @@ def _query_stack_get_all(context, show_deleted=False,
query = query.options(orm.subqueryload(models.Stack.tags))
if tags:
for tag in tags:
tag_alias = orm_aliased(models.StackTag)
tag_alias = orm.aliased(models.StackTag)
query = query.join(tag_alias, models.Stack.tags)
query = query.filter(tag_alias.tag == tag)
@@ -736,7 +778,7 @@ def _query_stack_get_all(context, show_deleted=False,
context, models.Stack, show_deleted=show_deleted
)
for tag in not_tags:
tag_alias = orm_aliased(models.StackTag)
tag_alias = orm.aliased(models.StackTag)
subquery = subquery.join(tag_alias, models.Stack.tags)
subquery = subquery.filter(tag_alias.tag == tag)
not_stack_ids = [s.id for s in subquery.all()]
@@ -811,7 +853,7 @@ def stack_create(context, values):
# Even though we just created a stack with this name, we may not find
# it again because some unit tests create stacks with deleted_at set. Also
# some backup stacks may not be found, for reasons that are unclear.
earliest = stack_get_by_name(context, stack_name)
earliest = _stack_get_by_name(context, stack_name)
if earliest is not None and earliest.id != stack_ref.id:
with context.session.begin():
context.session.query(models.Stack).filter_by(
@@ -841,7 +883,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None):
'expected traversal: %(trav)s',
{'id': stack_id, 'vals': str(values),
'trav': str(exp_trvsl)})
if not stack_get(context, stack_id, eager_load=False):
if not _stack_get(context, stack_id, eager_load=False):
raise exception.NotFound(
_('Attempt to update a stack with id: '
'%(id)s %(msg)s') % {
@@ -852,7 +894,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None):
def stack_delete(context, stack_id):
s = stack_get(context, stack_id, eager_load=False)
s = _stack_get(context, stack_id, eager_load=False)
if not s:
raise exception.NotFound(_('Attempt to delete a stack with id: '
'%(id)s %(msg)s') % {
@@ -873,7 +915,14 @@ def stack_delete(context, stack_id):
_soft_delete(context, s)
def reset_stack_status(context, stack_id, stack=None):
def reset_stack_status(context, stack_id):
return _reset_stack_status(context, stack_id)
# NOTE(stephenfin): This method uses separate transactions to delete nested
# stacks, thus it's the only private method that is allowed to open a
# transaction (via 'context.session.begin')
def _reset_stack_status(context, stack_id, stack=None):
if stack is None:
stack = context.session.get(models.Stack, stack_id)
@@ -901,7 +950,7 @@ def reset_stack_status(context, stack_id, stack=None):
query = context.session.query(models.Stack).filter_by(owner_id=stack_id)
for child in query:
reset_stack_status(context, child.id, child)
_reset_stack_status(context, child.id, child)
with context.session.begin():
if stack.status == 'IN_PROGRESS':
@@ -915,7 +964,7 @@ def reset_stack_status(context, stack_id, stack=None):
def stack_tags_set(context, stack_id, tags):
with context.session.begin():
stack_tags_delete(context, stack_id)
_stack_tags_delete(context, stack_id)
result = []
for tag in tags:
stack_tag = models.StackTag()
@@ -928,13 +977,21 @@ def stack_tags_set(context, stack_id, tags):
def stack_tags_delete(context, stack_id):
with transaction(context):
result = stack_tags_get(context, stack_id)
if result:
for tag in result:
context.session.delete(tag)
return _stack_tags_delete(context, stack_id)
def _stack_tags_delete(context, stack_id):
result = _stack_tags_get(context, stack_id)
if result:
for tag in result:
context.session.delete(tag)
def stack_tags_get(context, stack_id):
return _stack_tags_get(context, stack_id)
def _stack_tags_get(context, stack_id):
result = (context.session.query(models.StackTag)
.filter_by(stack_id=stack_id)
.all())
@@ -1004,11 +1061,11 @@ def stack_lock_release(context, stack_id, engine_id):
def stack_get_root_id(context, stack_id):
s = stack_get(context, stack_id, eager_load=False)
s = _stack_get(context, stack_id, eager_load=False)
if not s:
return None
while s.owner_id:
s = stack_get(context, s.owner_id, eager_load=False)
s = _stack_get(context, s.owner_id, eager_load=False)
return s.id
@@ -1152,6 +1209,10 @@ def _events_filter_and_page_query(context, query,
def event_count_all_by_stack(context, stack_id):
return _event_count_all_by_stack(context, stack_id)
def _event_count_all_by_stack(context, stack_id):
query = context.session.query(func.count(models.Event.id))
return query.filter_by(stack_id=stack_id).scalar()
@@ -1212,47 +1273,46 @@ def _delete_event_rows(context, stack_id, limit):
# So we must manually supply the IN() values.
# pgsql SHOULD work with the pure DELETE/JOIN below but that must be
# confirmed via integration tests.
with context.session.begin():
query = context.session.query(models.Event).filter_by(
stack_id=stack_id,
)
query = query.order_by(models.Event.id).limit(limit)
id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()]
if not id_pairs:
return 0
(ids, rsrc_prop_ids) = zip(*id_pairs)
max_id = ids[-1]
# delete the events
retval = context.session.query(models.Event).filter(
models.Event.id <= max_id).filter(
models.Event.stack_id == stack_id).delete()
query = context.session.query(models.Event).filter_by(
stack_id=stack_id,
)
query = query.order_by(models.Event.id).limit(limit)
id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()]
if not id_pairs:
return 0
(ids, rsrc_prop_ids) = zip(*id_pairs)
max_id = ids[-1]
# delete the events
retval = context.session.query(models.Event).filter(
models.Event.id <= max_id).filter(
models.Event.stack_id == stack_id).delete()
# delete unreferenced resource_properties_data
def del_rpd(rpd_ids):
if not rpd_ids:
return
q_rpd = context.session.query(models.ResourcePropertiesData)
q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids))
q_rpd.delete(synchronize_session=False)
# delete unreferenced resource_properties_data
def del_rpd(rpd_ids):
if not rpd_ids:
return
q_rpd = context.session.query(models.ResourcePropertiesData)
q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids))
q_rpd.delete(synchronize_session=False)
if rsrc_prop_ids:
clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context,
stack_id)
clr_prop_ids.discard(None)
try:
del_rpd(clr_prop_ids)
except db_exception.DBReferenceError:
LOG.debug('Checking backup/stack pairs for RPD references')
found = False
for partner_stack_id in _all_backup_stack_ids(context,
stack_id):
found = True
clr_prop_ids -= _find_rpd_references(context,
partner_stack_id)
if not found:
LOG.debug('No backup/stack pairs found for %s', stack_id)
raise
del_rpd(clr_prop_ids)
if rsrc_prop_ids:
clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context,
stack_id)
clr_prop_ids.discard(None)
try:
del_rpd(clr_prop_ids)
except db_exception.DBReferenceError:
LOG.debug('Checking backup/stack pairs for RPD references')
found = False
for partner_stack_id in _all_backup_stack_ids(context,
stack_id):
found = True
clr_prop_ids -= _find_rpd_references(context,
partner_stack_id)
if not found:
LOG.debug('No backup/stack pairs found for %s', stack_id)
raise
del_rpd(clr_prop_ids)
return retval
@@ -1263,13 +1323,18 @@ def event_create(context, values):
# only count events and purge on average
# 200.0/cfg.CONF.event_purge_batch_size percent of the time.
check = (2.0 / cfg.CONF.event_purge_batch_size) > random.uniform(0, 1)
if (check and
(event_count_all_by_stack(context, values['stack_id']) >=
cfg.CONF.max_events_per_stack)):
if (
check and _event_count_all_by_stack(
context, values['stack_id']
) >= cfg.CONF.max_events_per_stack
):
# prune
try:
_delete_event_rows(context, values['stack_id'],
cfg.CONF.event_purge_batch_size)
with context.session.begin():
_delete_event_rows(
context, values['stack_id'],
cfg.CONF.event_purge_batch_size,
)
except db_exception.DBError as exc:
LOG.error('Failed to purge events: %s', str(exc))
event_ref = models.Event()
@@ -1289,6 +1354,10 @@ def software_config_create(context, values):
def software_config_get(context, config_id):
return _software_config_get(context, config_id)
def _software_config_get(context, config_id):
result = context.session.get(models.SoftwareConfig, config_id)
if (result is not None and context is not None and not context.is_admin and
result.tenant != context.tenant_id):
@@ -1309,7 +1378,7 @@ def software_config_get_all(context, limit=None, marker=None):
def software_config_delete(context, config_id):
config = software_config_get(context, config_id)
config = _software_config_get(context, config_id)
# Query if the software config has been referenced by deployment.
result = context.session.query(models.SoftwareDeployment).filter_by(
config_id=config_id).first()
@@ -1340,6 +1409,10 @@ def software_deployment_create(context, values):
def software_deployment_get(context, deployment_id):
return _software_deployment_get(context, deployment_id)
def _software_deployment_get(context, deployment_id):
result = context.session.get(models.SoftwareDeployment, deployment_id)
if (result is not None and context is not None and not context.is_admin and
context.tenant_id not in (result.tenant,
@@ -1366,7 +1439,7 @@ def software_deployment_get_all(context, server_id=None):
def software_deployment_update(context, deployment_id, values):
deployment = software_deployment_get(context, deployment_id)
deployment = _software_deployment_get(context, deployment_id)
try:
update_and_save(context, deployment, values)
except db_exception.DBReferenceError:
@@ -1377,7 +1450,7 @@ def software_deployment_update(context, deployment_id, values):
def software_deployment_delete(context, deployment_id):
deployment = software_deployment_get(context, deployment_id)
deployment = _software_deployment_get(context, deployment_id)
with context.session.begin():
context.session.delete(deployment)
@@ -1393,6 +1466,10 @@ def snapshot_create(context, values):
def snapshot_get(context, snapshot_id):
return _snapshot_get(context, snapshot_id)
def _snapshot_get(context, snapshot_id):
result = context.session.get(models.Snapshot, snapshot_id)
if (result is not None and context is not None and
context.tenant_id != result.tenant):
@@ -1405,7 +1482,7 @@ def snapshot_get(context, snapshot_id):
def snapshot_get_by_stack(context, snapshot_id, stack):
snapshot = snapshot_get(context, snapshot_id)
snapshot = _snapshot_get(context, snapshot_id)
if snapshot.stack_id != stack.id:
raise exception.SnapshotNotFound(snapshot=snapshot_id,
stack=stack.name)
@@ -1414,14 +1491,14 @@ def snapshot_get_by_stack(context, snapshot_id, stack):
def snapshot_update(context, snapshot_id, values):
snapshot = snapshot_get(context, snapshot_id)
snapshot = _snapshot_get(context, snapshot_id)
snapshot.update(values)
snapshot.save(context.session)
return snapshot
def snapshot_delete(context, snapshot_id):
snapshot = snapshot_get(context, snapshot_id)
snapshot = _snapshot_get(context, snapshot_id)
with context.session.begin():
context.session.delete(snapshot)
@@ -1442,7 +1519,7 @@ def service_create(context, values):
def service_update(context, service_id, values):
service = service_get(context, service_id)
service = _service_get(context, service_id)
values.update({'updated_at': timeutils.utcnow()})
service.update(values)
service.save(context.session)
@@ -1450,7 +1527,7 @@ def service_update(context, service_id, values):
def service_delete(context, service_id, soft_delete=True):
service = service_get(context, service_id)
service = _service_get(context, service_id)
with context.session.begin():
if soft_delete:
_soft_delete(context, service)
@@ -1459,6 +1536,10 @@ def service_delete(context, service_id, soft_delete=True):
def service_get(context, service_id):
return _service_get(context, service_id)
def _service_get(context, service_id):
result = context.session.get(models.Service, service_id)
if result is None:
raise exception.EntityNotFound(entity='Service', name=service_id)
@@ -1466,8 +1547,9 @@ def service_get(context, service_id):
def service_get_all(context):
return (context.session.query(models.Service).
filter_by(deleted_at=None).all())
return context.session.query(models.Service).filter_by(
deleted_at=None,
).all()
def service_get_all_by_args(context, host, binary, hostname):