diff --git a/heat/engine/sync_point.py b/heat/engine/sync_point.py index 8e18eb5a2a..b0f08adc45 100644 --- a/heat/engine/sync_point.py +++ b/heat/engine/sync_point.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast from oslo_log import log as logging import six @@ -75,16 +76,47 @@ def update_input_data(context, entity_id, current_traversal, return rows_updated +def _str_pack_tuple(t): + return u'tuple:' + str(t) + + +def _str_unpack_tuple(s): + s = s[s.index(':') + 1:] + return ast.literal_eval(s) + + +def _deserialize(d): + d2 = {} + for k, v in d.items(): + if isinstance(k, six.string_types) and k.startswith(u'tuple:('): + k = _str_unpack_tuple(k) + if isinstance(v, dict): + v = _deserialize(v) + d2[k] = v + return d2 + + +def _serialize(d): + d2 = {} + for k, v in d.items(): + if isinstance(k, tuple): + k = _str_pack_tuple(k) + if isinstance(v, dict): + v = _serialize(v) + d2[k] = v + return d2 + + def deserialize_input_data(db_input_data): db_input_data = db_input_data.get('input_data') if not db_input_data: return {} - return {tuple(i): j for i, j in db_input_data} + return dict(_deserialize(db_input_data)) def serialize_input_data(input_data): - return {'input_data': [[list(i), j] for i, j in six.iteritems(input_data)]} + return {'input_data': _serialize(input_data)} def sync(cnxt, entity_id, current_traversal, is_update, propagate, @@ -95,7 +127,7 @@ def sync(cnxt, entity_id, current_traversal, is_update, propagate, while not rows_updated: # TODO(sirushtim): Add a conf option to add no. of retries sync_point = get(cnxt, entity_id, current_traversal, is_update) - input_data = dict(deserialize_input_data(sync_point.input_data)) + input_data = deserialize_input_data(sync_point.input_data) input_data.update(new_data) rows_updated = update_input_data( cnxt, entity_id, current_traversal, is_update, diff --git a/heat/tests/engine/test_sync_point.py b/heat/tests/engine/test_sync_point.py index 375d8103a0..f10aaf9b1e 100644 --- a/heat/tests/engine/test_sync_point.py +++ b/heat/tests/engine/test_sync_point.py @@ -65,4 +65,4 @@ class SyncPointTestCase(common.HeatTestCase): def test_serialize_input_data(self): res = sync_point.serialize_input_data({(3, 8): None}) - self.assertEqual({'input_data': [[[3, 8], None]]}, res) + self.assertEqual({'input_data': {u'tuple:(3, 8)': None}}, res) diff --git a/heat/tests/test_convg_stack.py b/heat/tests/test_convg_stack.py index 61c5a4f733..eecae5640f 100644 --- a/heat/tests/test_convg_stack.py +++ b/heat/tests/test_convg_stack.py @@ -70,7 +70,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): expected_calls.append( mock.call.worker_client.WorkerClient.check_resource( stack.context, rsrc_id, stack.current_traversal, - {'input_data': []}, + {'input_data': {}}, is_update, None)) self.assertEqual(expected_calls, mock_cr.mock_calls) @@ -127,7 +127,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): expected_calls.append( mock.call.worker_client.WorkerClient.check_resource( stack.context, rsrc_id, stack.current_traversal, - {'input_data': []}, + {'input_data': {}}, is_update, None)) self.assertEqual(expected_calls, mock_cr.mock_calls) @@ -265,7 +265,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): expected_calls.append( mock.call.worker_client.WorkerClient.check_resource( stack.context, rsrc_id, stack.current_traversal, - {'input_data': []}, + {'input_data': {}}, is_update, None)) leaves = curr_stack.convergence_dependencies.leaves() @@ -273,7 +273,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): expected_calls.append( mock.call.worker_client.WorkerClient.check_resource( curr_stack.context, rsrc_id, curr_stack.current_traversal, - {'input_data': []}, + {'input_data': {}}, is_update, None)) self.assertEqual(expected_calls, mock_cr.mock_calls) @@ -347,7 +347,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): expected_calls.append( mock.call.worker_client.WorkerClient.check_resource( stack.context, rsrc_id, stack.current_traversal, - {'input_data': []}, + {'input_data': {}}, is_update, None)) leaves = curr_stack.convergence_dependencies.leaves() @@ -355,7 +355,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): expected_calls.append( mock.call.worker_client.WorkerClient.check_resource( curr_stack.context, rsrc_id, curr_stack.current_traversal, - {'input_data': []}, + {'input_data': {}}, is_update, None)) self.assertEqual(expected_calls, mock_cr.mock_calls)