del _session(), use context.session directly

There should always be a context supplied, lets use its session
directly so we can be sure that the session is the shared one from the
context.

Any remaining functions which call get_session() directly will be
handled in their own changes as some require special behaviour.

Change-Id: I3e52d2ff6bb1014f201715a1b1ba32c122c2391d
Related-Bug: #1479723
changes/99/330799/5
Steve Baker 6 years ago
parent 4ed43674c1
commit 419c9ab994
  1. 59
      heat/db/sqlalchemy/api.py
  2. 2
      heat/engine/stack.py
  3. 2
      heat/engine/template.py
  4. 8
      heat/tests/aws/test_waitcondition.py
  5. 20
      heat/tests/db/test_sqlalchemy_api.py
  6. 2
      heat/tests/engine/service/test_stack_create.py
  7. 2
      heat/tests/engine/service/test_stack_update.py
  8. 4
      heat/tests/openstack/heat/test_waitcondition.py
  9. 13
      heat/tests/test_resource.py
  10. 2
      heat/tests/test_stack.py
  11. 7
      heat/tests/test_template_files.py
  12. 2
      heat/tests/utils.py

@ -81,7 +81,7 @@ def get_backend():
def model_query(context, *args):
session = _session(context)
session = context.session
query = session.query(*args)
return query
@ -100,10 +100,6 @@ def soft_delete_aware_query(context, *args, **kwargs):
return query
def _session(context):
return (context and context.session) or get_session()
def raw_template_get(context, template_id):
result = model_query(context, models.RawTemplate).get(template_id)
@ -116,7 +112,7 @@ def raw_template_get(context, template_id):
def raw_template_create(context, values):
raw_template_ref = models.RawTemplate()
raw_template_ref.update(values)
raw_template_ref.save(_session(context))
raw_template_ref.save(context.session)
return raw_template_ref
@ -140,13 +136,13 @@ def raw_template_delete(context, template_id):
return
# If no other raw_template is referencing the same raw_template_files,
# delete that too
if _session(context).query(models.RawTemplate).filter_by(
if context.session.query(models.RawTemplate).filter_by(
files_id=raw_tmpl_files_id).first() is None:
raw_template_files_get(context, raw_tmpl_files_id).delete()
def raw_template_files_create(context, values):
session = _session(context)
session = context.session
raw_templ_files_ref = models.RawTemplateFiles()
raw_templ_files_ref.update(values)
with session.begin():
@ -155,7 +151,8 @@ def raw_template_files_create(context, values):
def raw_template_files_get(context, files_id):
result = model_query(context, models.RawTemplateFiles).get(files_id)
session = context.session if context else get_session()
result = session.query(models.RawTemplateFiles).get(files_id)
if not result:
raise exception.NotFound(
_("raw_template_files with files_id %d not found") %
@ -205,7 +202,7 @@ def resource_get_all(context):
def resource_update(context, resource_id, values, atomic_key,
expected_engine_id=None):
session = _session(context)
session = context.session
with session.begin():
if atomic_key is None:
values['atomic_key'] = 1
@ -254,7 +251,7 @@ def resource_data_get(context, resource_id, key):
def stack_tags_set(context, stack_id, tags):
session = _session(context)
session = context.session
with session.begin():
stack_tags_delete(context, stack_id)
result = []
@ -268,7 +265,7 @@ def stack_tags_set(context, stack_id, tags):
def stack_tags_delete(context, stack_id):
session = _session(context)
session = context.session
with session.begin(subtransactions=True):
result = stack_tags_get(context, stack_id)
if result:
@ -312,7 +309,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
current.redact = redact
current.value = value
current.decrypt_method = method
current.save(session=_session(context))
current.save(session=context.session)
return current
@ -337,7 +334,7 @@ def resource_data_delete(context, resource_id, key):
def resource_create(context, values):
resource_ref = models.Resource()
resource_ref.update(values)
resource_ref.save(_session(context))
resource_ref.save(context.session)
return resource_ref
@ -580,7 +577,7 @@ def stack_count_all(context, filters=None, tenant_safe=True,
def stack_create(context, values):
stack_ref = models.Stack()
stack_ref.update(values)
stack_ref.save(_session(context))
stack_ref.save(context.session)
return stack_ref
@ -600,7 +597,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None):
# stack updated by another update
return False
session = _session(context)
session = context.session
with session.begin():
rows_updated = (session.query(models.Stack)
@ -646,7 +643,7 @@ def stack_lock_get_engine_id(stack_id):
def persist_state_and_release_lock(context, stack_id, engine_id, values):
session = _session(context)
session = context.session
with session.begin():
rows_updated = (session.query(models.Stack)
.filter(models.Stack.id == stack_id)
@ -722,7 +719,7 @@ def user_creds_create(context):
" exceeds Heat limit (255 chars)"))
user_creds_ref.password = password
user_creds_ref.decrypt_method = method
user_creds_ref.save(_session(context))
user_creds_ref.save(context.session)
result = dict(user_creds_ref)
if values.get('trust_id'):
@ -855,7 +852,7 @@ def _delete_event_rows(context, stack_id, limit):
# pgsql SHOULD work with the pure DELETE/JOIN below but that must be
# confirmed via integration tests.
query = _query_all_by_stack(context, stack_id)
session = _session(context)
session = context.session
ids = [r.id for r in query.order_by(
models.Event.id).limit(limit).all()]
q = session.query(models.Event).filter(
@ -872,7 +869,7 @@ def event_create(context, values):
context, values['stack_id'], cfg.CONF.event_purge_batch_size)
event_ref = models.Event()
event_ref.update(values)
event_ref.save(_session(context))
event_ref.save(context.session)
return event_ref
@ -901,7 +898,7 @@ def watch_rule_get_all_by_stack(context, stack_id):
def watch_rule_create(context, values):
obj_ref = models.WatchRule()
obj_ref.update(values)
obj_ref.save(_session(context))
obj_ref.save(context.session)
return obj_ref
@ -914,7 +911,7 @@ def watch_rule_update(context, watch_id, values):
'id': watch_id,
'msg': 'that does not exist'})
wr.update(values)
wr.save(_session(context))
wr.save(context.session)
def watch_rule_delete(context, watch_id):
@ -934,7 +931,7 @@ def watch_rule_delete(context, watch_id):
def watch_data_create(context, values):
obj_ref = models.WatchData()
obj_ref.update(values)
obj_ref.save(_session(context))
obj_ref.save(context.session)
return obj_ref
@ -952,7 +949,7 @@ def watch_data_get_all_by_watch_rule_id(context, watch_rule_id):
def software_config_create(context, values):
obj_ref = models.SoftwareConfig()
obj_ref.update(values)
obj_ref.save(_session(context))
obj_ref.save(context.session)
return obj_ref
@ -994,7 +991,7 @@ def software_config_delete(context, config_id):
def software_deployment_create(context, values):
obj_ref = models.SoftwareDeployment()
obj_ref.update(values)
session = _session(context)
session = context.session
session.begin()
obj_ref.save(session)
session.commit()
@ -1041,7 +1038,7 @@ def software_deployment_delete(context, deployment_id):
def snapshot_create(context, values):
obj_ref = models.Snapshot()
obj_ref.update(values)
obj_ref.save(_session(context))
obj_ref.save(context.session)
return obj_ref
@ -1069,7 +1066,7 @@ def snapshot_get_by_stack(context, snapshot_id, stack):
def snapshot_update(context, snapshot_id, values):
snapshot = snapshot_get(context, snapshot_id)
snapshot.update(values)
snapshot.save(_session(context))
snapshot.save(context.session)
return snapshot
@ -1088,7 +1085,7 @@ def snapshot_get_all(context, stack_id):
def service_create(context, values):
service = models.Service()
service.update(values)
service.save(_session(context))
service.save(context.session)
return service
@ -1096,7 +1093,7 @@ def service_update(context, service_id, values):
service = service_get(context, service_id)
values.update({'updated_at': timeutils.utcnow()})
service.update(values)
service.save(_session(context))
service.save(context.session)
return service
@ -1264,7 +1261,7 @@ def sync_point_create(context, values):
values['entity_id'] = str(values['entity_id'])
sync_point_ref = models.SyncPoint()
sync_point_ref.update(values)
sync_point_ref.save(_session(context))
sync_point_ref.save(context.session)
return sync_point_ref
@ -1507,7 +1504,7 @@ def reset_stack_status(context, stack_id, stack=None):
if stack is None:
raise exception.NotFound(_('Stack with id %s not found') % stack_id)
session = _session(context)
session = context.session
with session.begin():
query = model_query(context, models.Resource).filter_by(
status='IN_PROGRESS', stack_id=stack_id)

@ -1370,7 +1370,7 @@ class Stack(collections.Mapping):
prev_tmpl_id = self.prev_raw_template_id
# newstack.t may have been pre-stored, so save with that one
bu_tmpl, newstack.t = newstack.t, copy.deepcopy(newstack.t)
self.prev_raw_template_id = bu_tmpl.store()
self.prev_raw_template_id = bu_tmpl.store(self.context)
self.action = action
self.status = self.IN_PROGRESS
self.status_reason = 'Stack %s started' % action

@ -137,7 +137,7 @@ class Template(collections.Mapping):
"""Store the Template in the database and return its ID."""
rt = {
'template': self.t,
'files_id': self.files.store(),
'files_id': self.files.store(context),
'environment': self.env.user_env_as_dict()
}
if self.id is None:

@ -128,7 +128,7 @@ class WaitConditionTest(common.HeatTestCase):
rsrc.state)
r = resource_objects.Resource.get_by_name_and_stack(
None, 'WaitHandle', self.stack.id)
self.stack.context, 'WaitHandle', self.stack.id)
self.assertEqual('WaitHandle', r.name)
self.m.VerifyAll()
@ -148,7 +148,7 @@ class WaitConditionTest(common.HeatTestCase):
self.assertTrue(reason.startswith('WaitConditionFailure:'))
r = resource_objects.Resource.get_by_name_and_stack(
None, 'WaitHandle', self.stack.id)
self.stack.context, 'WaitHandle', self.stack.id)
self.assertEqual('WaitHandle', r.name)
self.m.VerifyAll()
@ -171,7 +171,7 @@ class WaitConditionTest(common.HeatTestCase):
rsrc.state)
r = resource_objects.Resource.get_by_name_and_stack(
None, 'WaitHandle', self.stack.id)
self.stack.context, 'WaitHandle', self.stack.id)
self.assertEqual('WaitHandle', r.name)
self.m.VerifyAll()
@ -192,7 +192,7 @@ class WaitConditionTest(common.HeatTestCase):
self.assertTrue(reason.startswith('WaitConditionFailure:'))
r = resource_objects.Resource.get_by_name_and_stack(
None, 'WaitHandle', self.stack.id)
self.stack.context, 'WaitHandle', self.stack.id)
self.assertEqual('WaitHandle', r.name)
self.m.VerifyAll()

@ -1385,7 +1385,7 @@ def create_raw_template(context, **kwargs):
if 'files' not in kwargs and 'files_id' not in kwargs:
# modern raw_templates have associated raw_template_files db obj
tf = template_files.TemplateFiles({'foo': 'bar'})
tf.store()
tf.store(context)
kwargs['files_id'] = tf.files_id
template.update(kwargs)
return db_api.raw_template_create(context, template)
@ -2462,7 +2462,7 @@ class DBAPIResourceDataTest(common.HeatTestCase):
self.assertEqual('test_value', vals.get('encryped_resource_key'))
# get all by using associated resource data
vals = db_api.resource_data_get_all(None, None, self.resource.data)
vals = db_api.resource_data_get_all(self.ctx, None, self.resource.data)
self.assertEqual(2, len(vals))
self.assertEqual('foo', vals.get('test_resource_key'))
self.assertEqual('test_value', vals.get('encryped_resource_key'))
@ -3145,7 +3145,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
return db_api.raw_template_create(self.ctx, template)
def _test_db_encrypt_decrypt(self, batch_size=50):
session = db_api.get_session()
session = self.ctx.session
hidden_params_dict = {
'param2': 'bar',
'param_number': '456',
@ -3171,7 +3171,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
enc_key = cfg.CONF.auth_encryption_key
self.assertEqual([], db_api.db_encrypt_parameters_and_properties(
self.ctx, enc_key, batch_size=batch_size))
session = db_api.get_session()
session = self.ctx.session
enc_raw_templates = session.query(models.RawTemplate).all()
self.assertNotEqual([], enc_raw_templates)
for enc_tmpl in enc_raw_templates:
@ -3205,7 +3205,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
enc_key = cfg.CONF.auth_encryption_key
self.assertEqual([], db_api.db_decrypt_parameters_and_properties(
self.ctx, enc_key, batch_size=batch_size))
session = db_api.get_session()
session = self.ctx.session
dec_templates = session.query(models.RawTemplate).all()
self.assertNotEqual([], dec_templates)
for dec_tmpl in dec_templates:
@ -3305,7 +3305,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
tmpl2 = self._create_template()
self.addCleanup(self._delete_templates, [tmpl1, tmpl2])
session = db_api.get_session()
session = self.ctx.session
r_tmpls = session.query(models.RawTemplate).all()
self.assertEqual('', r_tmpls[1].environment)
@ -3314,7 +3314,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
self.ctx, cfg.CONF.auth_encryption_key, batch_size=50)
self.assertEqual(1, len(enc_result))
self.assertIs(AttributeError, type(enc_result[0]))
session = db_api.get_session()
session = self.ctx.session
enc_tmpls = session.query(models.RawTemplate).all()
self.assertEqual('', enc_tmpls[1].environment)
self.assertEqual('cryptography_decrypt_v1',
@ -3325,7 +3325,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
self.ctx, cfg.CONF.auth_encryption_key, batch_size=50)
self.assertEqual(len(dec_result), 1)
self.assertIs(TypeError, type(dec_result[0]))
session = db_api.get_session()
session = self.ctx.session
dec_tmpls = session.query(models.RawTemplate).all()
self.assertEqual('', dec_tmpls[1].environment)
self.assertEqual('bar',
@ -3467,14 +3467,14 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase):
db_api.raw_template_create(self.ctx, template)
self.assertEqual([], db_api.db_encrypt_parameters_and_properties(
self.ctx, cfg.CONF.auth_encryption_key))
session = db_api.get_session()
session = self.ctx.session
enc_raw_templates = session.query(models.RawTemplate).all()
self.assertNotEqual([], enc_raw_templates)
enc_params = enc_raw_templates[1].environment['parameters']
self.assertEqual([], db_api.db_decrypt_parameters_and_properties(
self.ctx, cfg.CONF.auth_encryption_key, batch_size=50))
session = db_api.get_session()
session = self.ctx.session
dec_tmpls = session.query(models.RawTemplate).all()
dec_params = dec_tmpls[1].environment['parameters']

@ -278,7 +278,7 @@ class StackCreateTest(common.HeatTestCase):
mock_tg.return_value = tools.DummyThreadGroup()
stk = tools.get_stack(stack_name, self.ctx, with_params=True)
tmpl_id = stk.t.store()
tmpl_id = stk.t.store(self.ctx)
mock_load = self.patchobject(templatem.Template, 'load',
return_value=stk.t)

@ -138,7 +138,7 @@ class ServiceStackUpdateTest(common.HeatTestCase):
s = stack_object.Stack.get_by_id(self.ctx, sid)
stk = tools.get_stack(stack_name, self.ctx)
tmpl_id = stk.t.store()
tmpl_id = stk.t.store(self.ctx)
# prepare mocks
mock_stack = self.patchobject(stack, 'Stack', return_value=stk)

@ -177,7 +177,7 @@ class HeatWaitConditionTest(common.HeatTestCase):
rsrc.state)
r = resource_objects.Resource.get_by_name_and_stack(
None, 'wait_handle', self.stack.id)
self.stack.context, 'wait_handle', self.stack.id)
self.assertEqual('wait_handle', r.name)
self.m.VerifyAll()
@ -201,7 +201,7 @@ class HeatWaitConditionTest(common.HeatTestCase):
self.assertTrue(reason.startswith('WaitConditionFailure:'))
r = resource_objects.Resource.get_by_name_and_stack(
None, 'wait_handle', self.stack.id)
self.stack.context, 'wait_handle', self.stack.id)
self.assertEqual('wait_handle', r.name)
self.m.VerifyAll()

@ -1809,7 +1809,7 @@ class ResourceTest(common.HeatTestCase):
'test_res': {'Type': 'ResourceWithPropsType',
'Properties': {'Foo': 'abc'}}
}}, env=self.env)
new_temp.store()
new_temp.store(stack.context)
new_stack = parser.Stack(utils.dummy_context(), 'test_stack',
new_temp, stack_id=self.stack.id)
@ -1843,7 +1843,7 @@ class ResourceTest(common.HeatTestCase):
'test_res': {'Type': 'ResourceWithPropsType',
'Properties': {'Foo': 'abc'}}
}}, env=self.env)
new_temp.store()
new_temp.store(stack.context)
new_stack = parser.Stack(utils.dummy_context(), 'test_stack',
new_temp, stack_id=self.stack.id)
@ -1864,8 +1864,9 @@ class ResourceTest(common.HeatTestCase):
'test_res': {'Type': 'ResourceWithPropsType',
'Properties': {'Foo': 'abc'}}
}}, env=self.env)
new_temp.store()
new_stack = parser.Stack(utils.dummy_context(), 'test_stack',
ctx = utils.dummy_context()
new_temp.store(ctx)
new_stack = parser.Stack(ctx, 'test_stack',
new_temp, stack_id=self.stack.id)
res_data = {}
@ -1917,7 +1918,7 @@ class ResourceTest(common.HeatTestCase):
'test_res': {'Type': 'ResourceWithPropsType',
'Properties': {'Foo': 'abc'}}
}}, env=self.env)
new_temp.store()
new_temp.store(stack.context)
res_data = {(1, True): {u'id': 4, u'name': 'A', 'attrs': {}},
(2, True): {u'id': 3, u'name': 'B', 'attrs': {}}}
@ -1959,7 +1960,7 @@ class ResourceTest(common.HeatTestCase):
'test_res': {'Type': 'ResourceWithPropsType',
'Properties': {'Foo': 'abc'}}
}}, env=self.env)
new_temp.store()
new_temp.store(stack.context)
res_data = {(1, True): {u'id': 4, u'name': 'A', 'attrs': {}},
(2, True): {u'id': 3, u'name': 'B', 'attrs': {}}}

@ -196,7 +196,7 @@ class StackTest(common.HeatTestCase):
def test_load_nonexistant_id(self):
self.assertRaises(exception.NotFound, stack.Stack.load,
None, -1)
self.ctx, -1)
def test_total_resources_empty(self):
self.stack = stack.Stack(self.ctx, 'test_stack', self.tmpl,

@ -12,6 +12,7 @@
from heat.engine import template_files
from heat.tests import common
from heat.tests import utils
template_files_1 = {'template file 1': 'Contents of template 1',
'template file 2': 'More template contents'}
@ -20,8 +21,9 @@ template_files_1 = {'template file 1': 'Contents of template 1',
class TestTemplateFiles(common.HeatTestCase):
def test_cache_miss(self):
ctx = utils.dummy_context()
tf1 = template_files.TemplateFiles(template_files_1)
tf1.store()
tf1.store(ctx)
# As this is the only reference to the value in _d, deleting
# t1.files will cause the value to get removed from _d (due to
# it being a WeakValueDictionary.
@ -33,8 +35,9 @@ class TestTemplateFiles(common.HeatTestCase):
self.assertEqual(template_files_1, template_files._d[tf1.files_id])
def test_d_weakref_behaviour(self):
ctx = utils.dummy_context()
tf1 = template_files.TemplateFiles(template_files_1)
tf1.store()
tf1.store(ctx)
tf2 = template_files.TemplateFiles(tf1)
del tf1.files
self.assertIn(tf2.files_id, template_files._d)

@ -94,7 +94,7 @@ def parse_stack(t, params=None, files=None, stack_name=None,
ctx = dummy_context()
templ = template.Template(t, files=files,
env=environment.Environment(params))
templ.store()
templ.store(ctx)
if stack_name is None:
stack_name = random_name()
stk = stack.Stack(ctx, stack_name, templ, stack_id=stack_id,

Loading…
Cancel
Save