diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 36a3b837d..e3226121d 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -1125,57 +1125,60 @@ def db_version(engine): def db_encrypt_parameters_and_properties(ctxt, encryption_key): from heat.engine import template session = get_session() - session.begin() + with session.begin(): - raw_templates = session.query(models.RawTemplate).all() + raw_templates = session.query(models.RawTemplate).all() - for raw_template in raw_templates: - tmpl = template.Template.load(ctxt, raw_template.id, raw_template) + for raw_template in raw_templates: + tmpl = template.Template.load(ctxt, raw_template.id, raw_template) - encrypted_params = [] - for param_name, param in tmpl.param_schemata().items(): - if (param_name in encrypted_params) or (not param.hidden): - continue + encrypted_params = [] + for param_name, param in tmpl.param_schemata().items(): + if (param_name in encrypted_params) or (not param.hidden): + continue - try: - param_val = raw_template.environment['parameters'][ - param_name] - except KeyError: - param_val = param.default + try: + param_val = raw_template.environment['parameters'][ + param_name] + except KeyError: + param_val = param.default - encoded_val = encodeutils.safe_encode(param_val) - encrypted_val = crypt.encrypt(encoded_val, encryption_key) - raw_template.environment['parameters'][param_name] = \ - encrypted_val - encrypted_params.append(param_name) + encoded_val = encodeutils.safe_encode(param_val) + encrypted_val = crypt.encrypt(encoded_val, encryption_key) + raw_template.environment['parameters'][param_name] = \ + encrypted_val + encrypted_params.append(param_name) - raw_template.environment['encrypted_param_names'] = \ - encrypted_params - - session.commit() + if encrypted_params: + environment = raw_template.environment.copy() + environment['encrypted_param_names'] = encrypted_params + raw_template_update(ctxt, raw_template.id, + {'environment': environment}) def db_decrypt_parameters_and_properties(ctxt, encryption_key): session = get_session() - session.begin() - raw_templates = session.query(models.RawTemplate).all() - for raw_template in raw_templates: - parameters = raw_template.environment['parameters'] - encrypted_params = raw_template.environment[ - 'encrypted_param_names'] - for param_name in encrypted_params: - decrypt_function_name = parameters[param_name][0] - decrypt_function = getattr(crypt, decrypt_function_name) - decrypted_val = decrypt_function(parameters[param_name][1], - encryption_key) - try: - parameters[param_name] = encodeutils.safe_decode(decrypted_val) - except UnicodeDecodeError: - # if the incorrect encryption_key was used then we can - # get total gibberish here and safe_decode() will freak out. - parameters[param_name] = decrypted_val + with session.begin(): + raw_templates = session.query(models.RawTemplate).all() - raw_template.environment['encrypted_param_names'] = [] - - session.commit() + for raw_template in raw_templates: + parameters = raw_template.environment['parameters'] + encrypted_params = raw_template.environment[ + 'encrypted_param_names'] + for param_name in encrypted_params: + decrypt_function_name = parameters[param_name][0] + decrypt_function = getattr(crypt, decrypt_function_name) + decrypted_val = decrypt_function(parameters[param_name][1], + encryption_key) + try: + parameters[param_name] = encodeutils.safe_decode( + decrypted_val) + except UnicodeDecodeError: + # if the incorrect encryption_key was used then we can get + # total gibberish here and safe_decode() will freak out. + parameters[param_name] = decrypted_val + environment = raw_template.environment.copy() + environment['encrypted_param_names'] = [] + raw_template_update(ctxt, raw_template.id, + {'environment': environment}) diff --git a/heat/db/sqlalchemy/types.py b/heat/db/sqlalchemy/types.py index 5d6ff17ec..723d0bc4e 100644 --- a/heat/db/sqlalchemy/types.py +++ b/heat/db/sqlalchemy/types.py @@ -13,7 +13,6 @@ from oslo_serialization import jsonutils from sqlalchemy.dialects import mysql -from sqlalchemy.ext import mutable from sqlalchemy import types @@ -42,42 +41,6 @@ class Json(LongText): return loads(value) -class MutableList(mutable.Mutable, list): - @classmethod - def coerce(cls, key, value): - if not isinstance(value, cls): - if isinstance(value, list): - return cls(value) - return mutable.Mutable.coerce(key, value) - else: - return value - - def __delitem__(self, key): - list.__delitem__(self, key) - self.changed() - - def __setitem__(self, key, value): - list.__setitem__(self, key, value) - self.changed() - - def __getstate__(self): - return list(self) - - def __setstate__(self, state): - len = list.__len__(self) - list.__delslice__(self, 0, len) - list.__add__(self, state) - self.changed() - - def append(self, value): - list.append(self, value) - self.changed() - - def remove(self, value): - list.remove(self, value) - self.changed() - - class List(types.TypeDecorator): impl = types.Text @@ -94,8 +57,3 @@ class List(types.TypeDecorator): if value is None: return None return loads(value) - - -MutableList.associate_with(List) -mutable.MutableDict.associate_with(LongText) -mutable.MutableDict.associate_with(Json)