diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 23f482d3d8..e87334fdad 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -81,7 +81,7 @@ def get_backend(): def model_query(context, *args): - session = _session(context) + session = context.session query = session.query(*args) return query @@ -100,10 +100,6 @@ def soft_delete_aware_query(context, *args, **kwargs): return query -def _session(context): - return (context and context.session) or get_session() - - def raw_template_get(context, template_id): result = model_query(context, models.RawTemplate).get(template_id) @@ -116,7 +112,7 @@ def raw_template_get(context, template_id): def raw_template_create(context, values): raw_template_ref = models.RawTemplate() raw_template_ref.update(values) - raw_template_ref.save(_session(context)) + raw_template_ref.save(context.session) return raw_template_ref @@ -140,13 +136,13 @@ def raw_template_delete(context, template_id): return # If no other raw_template is referencing the same raw_template_files, # delete that too - if _session(context).query(models.RawTemplate).filter_by( + if context.session.query(models.RawTemplate).filter_by( files_id=raw_tmpl_files_id).first() is None: raw_template_files_get(context, raw_tmpl_files_id).delete() def raw_template_files_create(context, values): - session = _session(context) + session = context.session raw_templ_files_ref = models.RawTemplateFiles() raw_templ_files_ref.update(values) with session.begin(): @@ -155,7 +151,8 @@ def raw_template_files_create(context, values): def raw_template_files_get(context, files_id): - result = model_query(context, models.RawTemplateFiles).get(files_id) + session = context.session if context else get_session() + result = session.query(models.RawTemplateFiles).get(files_id) if not result: raise exception.NotFound( _("raw_template_files with files_id %d not found") % @@ -205,7 +202,7 @@ def resource_get_all(context): def resource_update(context, resource_id, values, atomic_key, expected_engine_id=None): - session = _session(context) + session = context.session with session.begin(): if atomic_key is None: values['atomic_key'] = 1 @@ -254,7 +251,7 @@ def resource_data_get(context, resource_id, key): def stack_tags_set(context, stack_id, tags): - session = _session(context) + session = context.session with session.begin(): stack_tags_delete(context, stack_id) result = [] @@ -268,7 +265,7 @@ def stack_tags_set(context, stack_id, tags): def stack_tags_delete(context, stack_id): - session = _session(context) + session = context.session with session.begin(subtransactions=True): result = stack_tags_get(context, stack_id) if result: @@ -312,7 +309,7 @@ def resource_data_set(context, resource_id, key, value, redact=False): current.redact = redact current.value = value current.decrypt_method = method - current.save(session=_session(context)) + current.save(session=context.session) return current @@ -337,7 +334,7 @@ def resource_data_delete(context, resource_id, key): def resource_create(context, values): resource_ref = models.Resource() resource_ref.update(values) - resource_ref.save(_session(context)) + resource_ref.save(context.session) return resource_ref @@ -580,7 +577,7 @@ def stack_count_all(context, filters=None, tenant_safe=True, def stack_create(context, values): stack_ref = models.Stack() stack_ref.update(values) - stack_ref.save(_session(context)) + stack_ref.save(context.session) return stack_ref @@ -600,7 +597,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None): # stack updated by another update return False - session = _session(context) + session = context.session with session.begin(): rows_updated = (session.query(models.Stack) @@ -646,7 +643,7 @@ def stack_lock_get_engine_id(stack_id): def persist_state_and_release_lock(context, stack_id, engine_id, values): - session = _session(context) + session = context.session with session.begin(): rows_updated = (session.query(models.Stack) .filter(models.Stack.id == stack_id) @@ -722,7 +719,7 @@ def user_creds_create(context): " exceeds Heat limit (255 chars)")) user_creds_ref.password = password user_creds_ref.decrypt_method = method - user_creds_ref.save(_session(context)) + user_creds_ref.save(context.session) result = dict(user_creds_ref) if values.get('trust_id'): @@ -855,7 +852,7 @@ def _delete_event_rows(context, stack_id, limit): # pgsql SHOULD work with the pure DELETE/JOIN below but that must be # confirmed via integration tests. query = _query_all_by_stack(context, stack_id) - session = _session(context) + session = context.session ids = [r.id for r in query.order_by( models.Event.id).limit(limit).all()] q = session.query(models.Event).filter( @@ -872,7 +869,7 @@ def event_create(context, values): context, values['stack_id'], cfg.CONF.event_purge_batch_size) event_ref = models.Event() event_ref.update(values) - event_ref.save(_session(context)) + event_ref.save(context.session) return event_ref @@ -901,7 +898,7 @@ def watch_rule_get_all_by_stack(context, stack_id): def watch_rule_create(context, values): obj_ref = models.WatchRule() obj_ref.update(values) - obj_ref.save(_session(context)) + obj_ref.save(context.session) return obj_ref @@ -914,7 +911,7 @@ def watch_rule_update(context, watch_id, values): 'id': watch_id, 'msg': 'that does not exist'}) wr.update(values) - wr.save(_session(context)) + wr.save(context.session) def watch_rule_delete(context, watch_id): @@ -934,7 +931,7 @@ def watch_rule_delete(context, watch_id): def watch_data_create(context, values): obj_ref = models.WatchData() obj_ref.update(values) - obj_ref.save(_session(context)) + obj_ref.save(context.session) return obj_ref @@ -952,7 +949,7 @@ def watch_data_get_all_by_watch_rule_id(context, watch_rule_id): def software_config_create(context, values): obj_ref = models.SoftwareConfig() obj_ref.update(values) - obj_ref.save(_session(context)) + obj_ref.save(context.session) return obj_ref @@ -994,7 +991,7 @@ def software_config_delete(context, config_id): def software_deployment_create(context, values): obj_ref = models.SoftwareDeployment() obj_ref.update(values) - session = _session(context) + session = context.session session.begin() obj_ref.save(session) session.commit() @@ -1041,7 +1038,7 @@ def software_deployment_delete(context, deployment_id): def snapshot_create(context, values): obj_ref = models.Snapshot() obj_ref.update(values) - obj_ref.save(_session(context)) + obj_ref.save(context.session) return obj_ref @@ -1069,7 +1066,7 @@ def snapshot_get_by_stack(context, snapshot_id, stack): def snapshot_update(context, snapshot_id, values): snapshot = snapshot_get(context, snapshot_id) snapshot.update(values) - snapshot.save(_session(context)) + snapshot.save(context.session) return snapshot @@ -1088,7 +1085,7 @@ def snapshot_get_all(context, stack_id): def service_create(context, values): service = models.Service() service.update(values) - service.save(_session(context)) + service.save(context.session) return service @@ -1096,7 +1093,7 @@ def service_update(context, service_id, values): service = service_get(context, service_id) values.update({'updated_at': timeutils.utcnow()}) service.update(values) - service.save(_session(context)) + service.save(context.session) return service @@ -1264,7 +1261,7 @@ def sync_point_create(context, values): values['entity_id'] = str(values['entity_id']) sync_point_ref = models.SyncPoint() sync_point_ref.update(values) - sync_point_ref.save(_session(context)) + sync_point_ref.save(context.session) return sync_point_ref @@ -1507,7 +1504,7 @@ def reset_stack_status(context, stack_id, stack=None): if stack is None: raise exception.NotFound(_('Stack with id %s not found') % stack_id) - session = _session(context) + session = context.session with session.begin(): query = model_query(context, models.Resource).filter_by( status='IN_PROGRESS', stack_id=stack_id) diff --git a/heat/engine/stack.py b/heat/engine/stack.py index 04a821869f..ad89ced969 100644 --- a/heat/engine/stack.py +++ b/heat/engine/stack.py @@ -1370,7 +1370,7 @@ class Stack(collections.Mapping): prev_tmpl_id = self.prev_raw_template_id # newstack.t may have been pre-stored, so save with that one bu_tmpl, newstack.t = newstack.t, copy.deepcopy(newstack.t) - self.prev_raw_template_id = bu_tmpl.store() + self.prev_raw_template_id = bu_tmpl.store(self.context) self.action = action self.status = self.IN_PROGRESS self.status_reason = 'Stack %s started' % action diff --git a/heat/engine/template.py b/heat/engine/template.py index fb941ce13f..b414618387 100644 --- a/heat/engine/template.py +++ b/heat/engine/template.py @@ -137,7 +137,7 @@ class Template(collections.Mapping): """Store the Template in the database and return its ID.""" rt = { 'template': self.t, - 'files_id': self.files.store(), + 'files_id': self.files.store(context), 'environment': self.env.user_env_as_dict() } if self.id is None: diff --git a/heat/tests/aws/test_waitcondition.py b/heat/tests/aws/test_waitcondition.py index 854390e6f0..9a4aed9fd9 100644 --- a/heat/tests/aws/test_waitcondition.py +++ b/heat/tests/aws/test_waitcondition.py @@ -128,7 +128,7 @@ class WaitConditionTest(common.HeatTestCase): rsrc.state) r = resource_objects.Resource.get_by_name_and_stack( - None, 'WaitHandle', self.stack.id) + self.stack.context, 'WaitHandle', self.stack.id) self.assertEqual('WaitHandle', r.name) self.m.VerifyAll() @@ -148,7 +148,7 @@ class WaitConditionTest(common.HeatTestCase): self.assertTrue(reason.startswith('WaitConditionFailure:')) r = resource_objects.Resource.get_by_name_and_stack( - None, 'WaitHandle', self.stack.id) + self.stack.context, 'WaitHandle', self.stack.id) self.assertEqual('WaitHandle', r.name) self.m.VerifyAll() @@ -171,7 +171,7 @@ class WaitConditionTest(common.HeatTestCase): rsrc.state) r = resource_objects.Resource.get_by_name_and_stack( - None, 'WaitHandle', self.stack.id) + self.stack.context, 'WaitHandle', self.stack.id) self.assertEqual('WaitHandle', r.name) self.m.VerifyAll() @@ -192,7 +192,7 @@ class WaitConditionTest(common.HeatTestCase): self.assertTrue(reason.startswith('WaitConditionFailure:')) r = resource_objects.Resource.get_by_name_and_stack( - None, 'WaitHandle', self.stack.id) + self.stack.context, 'WaitHandle', self.stack.id) self.assertEqual('WaitHandle', r.name) self.m.VerifyAll() diff --git a/heat/tests/db/test_sqlalchemy_api.py b/heat/tests/db/test_sqlalchemy_api.py index 75f0ee898f..02d0e0cc48 100644 --- a/heat/tests/db/test_sqlalchemy_api.py +++ b/heat/tests/db/test_sqlalchemy_api.py @@ -1385,7 +1385,7 @@ def create_raw_template(context, **kwargs): if 'files' not in kwargs and 'files_id' not in kwargs: # modern raw_templates have associated raw_template_files db obj tf = template_files.TemplateFiles({'foo': 'bar'}) - tf.store() + tf.store(context) kwargs['files_id'] = tf.files_id template.update(kwargs) return db_api.raw_template_create(context, template) @@ -2462,7 +2462,7 @@ class DBAPIResourceDataTest(common.HeatTestCase): self.assertEqual('test_value', vals.get('encryped_resource_key')) # get all by using associated resource data - vals = db_api.resource_data_get_all(None, None, self.resource.data) + vals = db_api.resource_data_get_all(self.ctx, None, self.resource.data) self.assertEqual(2, len(vals)) self.assertEqual('foo', vals.get('test_resource_key')) self.assertEqual('test_value', vals.get('encryped_resource_key')) @@ -3145,7 +3145,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): return db_api.raw_template_create(self.ctx, template) def _test_db_encrypt_decrypt(self, batch_size=50): - session = db_api.get_session() + session = self.ctx.session hidden_params_dict = { 'param2': 'bar', 'param_number': '456', @@ -3171,7 +3171,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): enc_key = cfg.CONF.auth_encryption_key self.assertEqual([], db_api.db_encrypt_parameters_and_properties( self.ctx, enc_key, batch_size=batch_size)) - session = db_api.get_session() + session = self.ctx.session enc_raw_templates = session.query(models.RawTemplate).all() self.assertNotEqual([], enc_raw_templates) for enc_tmpl in enc_raw_templates: @@ -3205,7 +3205,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): enc_key = cfg.CONF.auth_encryption_key self.assertEqual([], db_api.db_decrypt_parameters_and_properties( self.ctx, enc_key, batch_size=batch_size)) - session = db_api.get_session() + session = self.ctx.session dec_templates = session.query(models.RawTemplate).all() self.assertNotEqual([], dec_templates) for dec_tmpl in dec_templates: @@ -3305,7 +3305,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): tmpl2 = self._create_template() self.addCleanup(self._delete_templates, [tmpl1, tmpl2]) - session = db_api.get_session() + session = self.ctx.session r_tmpls = session.query(models.RawTemplate).all() self.assertEqual('', r_tmpls[1].environment) @@ -3314,7 +3314,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): self.ctx, cfg.CONF.auth_encryption_key, batch_size=50) self.assertEqual(1, len(enc_result)) self.assertIs(AttributeError, type(enc_result[0])) - session = db_api.get_session() + session = self.ctx.session enc_tmpls = session.query(models.RawTemplate).all() self.assertEqual('', enc_tmpls[1].environment) self.assertEqual('cryptography_decrypt_v1', @@ -3325,7 +3325,7 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): self.ctx, cfg.CONF.auth_encryption_key, batch_size=50) self.assertEqual(len(dec_result), 1) self.assertIs(TypeError, type(dec_result[0])) - session = db_api.get_session() + session = self.ctx.session dec_tmpls = session.query(models.RawTemplate).all() self.assertEqual('', dec_tmpls[1].environment) self.assertEqual('bar', @@ -3467,14 +3467,14 @@ class DBAPICryptParamsPropsTest(common.HeatTestCase): db_api.raw_template_create(self.ctx, template) self.assertEqual([], db_api.db_encrypt_parameters_and_properties( self.ctx, cfg.CONF.auth_encryption_key)) - session = db_api.get_session() + session = self.ctx.session enc_raw_templates = session.query(models.RawTemplate).all() self.assertNotEqual([], enc_raw_templates) enc_params = enc_raw_templates[1].environment['parameters'] self.assertEqual([], db_api.db_decrypt_parameters_and_properties( self.ctx, cfg.CONF.auth_encryption_key, batch_size=50)) - session = db_api.get_session() + session = self.ctx.session dec_tmpls = session.query(models.RawTemplate).all() dec_params = dec_tmpls[1].environment['parameters'] diff --git a/heat/tests/engine/service/test_stack_create.py b/heat/tests/engine/service/test_stack_create.py index 8879650c6b..761b9aea4f 100644 --- a/heat/tests/engine/service/test_stack_create.py +++ b/heat/tests/engine/service/test_stack_create.py @@ -278,7 +278,7 @@ class StackCreateTest(common.HeatTestCase): mock_tg.return_value = tools.DummyThreadGroup() stk = tools.get_stack(stack_name, self.ctx, with_params=True) - tmpl_id = stk.t.store() + tmpl_id = stk.t.store(self.ctx) mock_load = self.patchobject(templatem.Template, 'load', return_value=stk.t) diff --git a/heat/tests/engine/service/test_stack_update.py b/heat/tests/engine/service/test_stack_update.py index 9bade28b47..745e1065b8 100644 --- a/heat/tests/engine/service/test_stack_update.py +++ b/heat/tests/engine/service/test_stack_update.py @@ -138,7 +138,7 @@ class ServiceStackUpdateTest(common.HeatTestCase): s = stack_object.Stack.get_by_id(self.ctx, sid) stk = tools.get_stack(stack_name, self.ctx) - tmpl_id = stk.t.store() + tmpl_id = stk.t.store(self.ctx) # prepare mocks mock_stack = self.patchobject(stack, 'Stack', return_value=stk) diff --git a/heat/tests/openstack/heat/test_waitcondition.py b/heat/tests/openstack/heat/test_waitcondition.py index 85108483a4..4c570cc2f4 100644 --- a/heat/tests/openstack/heat/test_waitcondition.py +++ b/heat/tests/openstack/heat/test_waitcondition.py @@ -177,7 +177,7 @@ class HeatWaitConditionTest(common.HeatTestCase): rsrc.state) r = resource_objects.Resource.get_by_name_and_stack( - None, 'wait_handle', self.stack.id) + self.stack.context, 'wait_handle', self.stack.id) self.assertEqual('wait_handle', r.name) self.m.VerifyAll() @@ -201,7 +201,7 @@ class HeatWaitConditionTest(common.HeatTestCase): self.assertTrue(reason.startswith('WaitConditionFailure:')) r = resource_objects.Resource.get_by_name_and_stack( - None, 'wait_handle', self.stack.id) + self.stack.context, 'wait_handle', self.stack.id) self.assertEqual('wait_handle', r.name) self.m.VerifyAll() diff --git a/heat/tests/test_resource.py b/heat/tests/test_resource.py index ca4200302e..a9cca9907e 100644 --- a/heat/tests/test_resource.py +++ b/heat/tests/test_resource.py @@ -1809,7 +1809,7 @@ class ResourceTest(common.HeatTestCase): 'test_res': {'Type': 'ResourceWithPropsType', 'Properties': {'Foo': 'abc'}} }}, env=self.env) - new_temp.store() + new_temp.store(stack.context) new_stack = parser.Stack(utils.dummy_context(), 'test_stack', new_temp, stack_id=self.stack.id) @@ -1843,7 +1843,7 @@ class ResourceTest(common.HeatTestCase): 'test_res': {'Type': 'ResourceWithPropsType', 'Properties': {'Foo': 'abc'}} }}, env=self.env) - new_temp.store() + new_temp.store(stack.context) new_stack = parser.Stack(utils.dummy_context(), 'test_stack', new_temp, stack_id=self.stack.id) @@ -1864,8 +1864,9 @@ class ResourceTest(common.HeatTestCase): 'test_res': {'Type': 'ResourceWithPropsType', 'Properties': {'Foo': 'abc'}} }}, env=self.env) - new_temp.store() - new_stack = parser.Stack(utils.dummy_context(), 'test_stack', + ctx = utils.dummy_context() + new_temp.store(ctx) + new_stack = parser.Stack(ctx, 'test_stack', new_temp, stack_id=self.stack.id) res_data = {} @@ -1917,7 +1918,7 @@ class ResourceTest(common.HeatTestCase): 'test_res': {'Type': 'ResourceWithPropsType', 'Properties': {'Foo': 'abc'}} }}, env=self.env) - new_temp.store() + new_temp.store(stack.context) res_data = {(1, True): {u'id': 4, u'name': 'A', 'attrs': {}}, (2, True): {u'id': 3, u'name': 'B', 'attrs': {}}} @@ -1959,7 +1960,7 @@ class ResourceTest(common.HeatTestCase): 'test_res': {'Type': 'ResourceWithPropsType', 'Properties': {'Foo': 'abc'}} }}, env=self.env) - new_temp.store() + new_temp.store(stack.context) res_data = {(1, True): {u'id': 4, u'name': 'A', 'attrs': {}}, (2, True): {u'id': 3, u'name': 'B', 'attrs': {}}} diff --git a/heat/tests/test_stack.py b/heat/tests/test_stack.py index 27770990dd..69c5e19199 100644 --- a/heat/tests/test_stack.py +++ b/heat/tests/test_stack.py @@ -196,7 +196,7 @@ class StackTest(common.HeatTestCase): def test_load_nonexistant_id(self): self.assertRaises(exception.NotFound, stack.Stack.load, - None, -1) + self.ctx, -1) def test_total_resources_empty(self): self.stack = stack.Stack(self.ctx, 'test_stack', self.tmpl, diff --git a/heat/tests/test_template_files.py b/heat/tests/test_template_files.py index 6a29b5c70a..c0c6b0b75f 100644 --- a/heat/tests/test_template_files.py +++ b/heat/tests/test_template_files.py @@ -12,6 +12,7 @@ from heat.engine import template_files from heat.tests import common +from heat.tests import utils template_files_1 = {'template file 1': 'Contents of template 1', 'template file 2': 'More template contents'} @@ -20,8 +21,9 @@ template_files_1 = {'template file 1': 'Contents of template 1', class TestTemplateFiles(common.HeatTestCase): def test_cache_miss(self): + ctx = utils.dummy_context() tf1 = template_files.TemplateFiles(template_files_1) - tf1.store() + tf1.store(ctx) # As this is the only reference to the value in _d, deleting # t1.files will cause the value to get removed from _d (due to # it being a WeakValueDictionary. @@ -33,8 +35,9 @@ class TestTemplateFiles(common.HeatTestCase): self.assertEqual(template_files_1, template_files._d[tf1.files_id]) def test_d_weakref_behaviour(self): + ctx = utils.dummy_context() tf1 = template_files.TemplateFiles(template_files_1) - tf1.store() + tf1.store(ctx) tf2 = template_files.TemplateFiles(tf1) del tf1.files self.assertIn(tf2.files_id, template_files._d) diff --git a/heat/tests/utils.py b/heat/tests/utils.py index 657f1c2c2e..5fa721521f 100644 --- a/heat/tests/utils.py +++ b/heat/tests/utils.py @@ -94,7 +94,7 @@ def parse_stack(t, params=None, files=None, stack_name=None, ctx = dummy_context() templ = template.Template(t, files=files, env=environment.Environment(params)) - templ.store() + templ.store(ctx) if stack_name is None: stack_name = random_name() stk = stack.Stack(ctx, stack_name, templ, stack_id=stack_id,