Merge "Prevent all primary keys in Neutron OVOs from being updated"
This commit is contained in:
commit
8ac3f269df
|
@ -12,6 +12,7 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import copy
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
from neutron_lib import exceptions
|
from neutron_lib import exceptions
|
||||||
from oslo_db import exception as obj_exc
|
from oslo_db import exception as obj_exc
|
||||||
|
@ -107,6 +108,18 @@ class NeutronObject(obj_base.VersionedObject,
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarativeObject(abc.ABCMeta):
|
||||||
|
|
||||||
|
def __init__(cls, name, bases, dct):
|
||||||
|
super(DeclarativeObject, cls).__init__(name, bases, dct)
|
||||||
|
for base in itertools.chain([cls], bases):
|
||||||
|
if hasattr(base, 'primary_keys'):
|
||||||
|
cls.fields_no_update += base.primary_keys
|
||||||
|
# avoid duplicate entries
|
||||||
|
cls.fields_no_update = list(set(cls.fields_no_update))
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(DeclarativeObject)
|
||||||
class NeutronDbObject(NeutronObject):
|
class NeutronDbObject(NeutronObject):
|
||||||
|
|
||||||
# should be overridden for all persistent objects
|
# should be overridden for all persistent objects
|
||||||
|
@ -214,10 +227,6 @@ class NeutronDbObject(NeutronObject):
|
||||||
|
|
||||||
def _validate_changed_fields(self, fields):
|
def _validate_changed_fields(self, fields):
|
||||||
fields = fields.copy()
|
fields = fields.copy()
|
||||||
# We won't allow id update anyway, so let's pop it out not to trigger
|
|
||||||
# update on id field touched by the consumer
|
|
||||||
fields.pop('id', None)
|
|
||||||
|
|
||||||
forbidden_updates = set(self.fields_no_update) & set(fields.keys())
|
forbidden_updates = set(self.fields_no_update) & set(fields.keys())
|
||||||
if forbidden_updates:
|
if forbidden_updates:
|
||||||
raise NeutronObjectUpdateForbidden(fields=forbidden_updates)
|
raise NeutronObjectUpdateForbidden(fields=forbidden_updates)
|
||||||
|
|
|
@ -48,11 +48,14 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
|
|
||||||
@db_base_plugin_common.convert_result_to_dict
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def update_policy(self, context, policy_id, policy):
|
def update_policy(self, context, policy_id, policy):
|
||||||
policy = policy_object.QosPolicy(context, **policy['policy'])
|
obj = policy_object.QosPolicy(context, id=policy_id)
|
||||||
policy.id = policy_id
|
obj.obj_reset_changes()
|
||||||
policy.update()
|
for k, v in policy['policy'].items():
|
||||||
self.notification_driver_manager.update_policy(context, policy)
|
if k != 'id':
|
||||||
return policy
|
setattr(obj, k, v)
|
||||||
|
obj.update()
|
||||||
|
self.notification_driver_manager.update_policy(context, obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
def delete_policy(self, context, policy_id):
|
def delete_policy(self, context, policy_id):
|
||||||
policy = policy_object.QosPolicy(context)
|
policy = policy_object.QosPolicy(context)
|
||||||
|
@ -107,8 +110,11 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
# check if the rule belong to the policy
|
# check if the rule belong to the policy
|
||||||
policy.get_rule_by_id(rule_id)
|
policy.get_rule_by_id(rule_id)
|
||||||
rule = rule_object.QosBandwidthLimitRule(
|
rule = rule_object.QosBandwidthLimitRule(
|
||||||
context, **bandwidth_limit_rule['bandwidth_limit_rule'])
|
context, id=rule_id)
|
||||||
rule.id = rule_id
|
rule.obj_reset_changes()
|
||||||
|
for k, v in bandwidth_limit_rule['bandwidth_limit_rule'].items():
|
||||||
|
if k != 'id':
|
||||||
|
setattr(rule, k, v)
|
||||||
rule.update()
|
rule.update()
|
||||||
policy.reload_rules()
|
policy.reload_rules()
|
||||||
self.notification_driver_manager.update_policy(context, policy)
|
self.notification_driver_manager.update_policy(context, policy)
|
||||||
|
|
|
@ -51,7 +51,9 @@ class FakeNeutronObject(base.NeutronDbObject):
|
||||||
'field2': obj_fields.StringField()
|
'field2': obj_fields.StringField()
|
||||||
}
|
}
|
||||||
|
|
||||||
fields_no_update = ['id']
|
primary_keys = ['id']
|
||||||
|
|
||||||
|
fields_no_update = ['field1']
|
||||||
|
|
||||||
synthetic_fields = ['field2']
|
synthetic_fields = ['field2']
|
||||||
|
|
||||||
|
@ -115,8 +117,6 @@ class FakeNeutronObjectRenamedField(base.NeutronDbObject):
|
||||||
|
|
||||||
synthetic_fields = ['field2']
|
synthetic_fields = ['field2']
|
||||||
|
|
||||||
fields_no_update = ['id']
|
|
||||||
|
|
||||||
fields_need_translation = {'field_ovo': 'field_db'}
|
fields_need_translation = {'field_ovo': 'field_db'}
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,8 +137,6 @@ class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject):
|
||||||
|
|
||||||
synthetic_fields = ['field2']
|
synthetic_fields = ['field2']
|
||||||
|
|
||||||
fields_no_update = ['id']
|
|
||||||
|
|
||||||
|
|
||||||
FIELD_TYPE_VALUE_GENERATOR_MAP = {
|
FIELD_TYPE_VALUE_GENERATOR_MAP = {
|
||||||
obj_fields.BooleanField: tools.get_random_boolean,
|
obj_fields.BooleanField: tools.get_random_boolean,
|
||||||
|
@ -333,6 +331,10 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||||
@mock.patch.object(obj_db_api, 'update_object')
|
@mock.patch.object(obj_db_api, 'update_object')
|
||||||
def test_update_changes(self, update_mock):
|
def test_update_changes(self, update_mock):
|
||||||
fields_to_update = self.get_updatable_fields(self.db_obj)
|
fields_to_update = self.get_updatable_fields(self.db_obj)
|
||||||
|
if not fields_to_update:
|
||||||
|
self.skipTest('No updatable fields found in test class %r' %
|
||||||
|
self._test_class)
|
||||||
|
|
||||||
with mock.patch.object(base.NeutronDbObject,
|
with mock.patch.object(base.NeutronDbObject,
|
||||||
'_get_changed_persistent_fields',
|
'_get_changed_persistent_fields',
|
||||||
return_value=fields_to_update):
|
return_value=fields_to_update):
|
||||||
|
@ -360,6 +362,9 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||||
return_value=self.db_obj):
|
return_value=self.db_obj):
|
||||||
obj = self._test_class(self.context, **self.obj_fields[1])
|
obj = self._test_class(self.context, **self.obj_fields[1])
|
||||||
fields_to_update = self.get_updatable_fields(self.obj_fields[1])
|
fields_to_update = self.get_updatable_fields(self.obj_fields[1])
|
||||||
|
if not fields_to_update:
|
||||||
|
self.skipTest('No updatable fields found in test class %r' %
|
||||||
|
self._test_class)
|
||||||
with mock.patch.object(base.NeutronDbObject,
|
with mock.patch.object(base.NeutronDbObject,
|
||||||
'_get_changed_persistent_fields',
|
'_get_changed_persistent_fields',
|
||||||
return_value=fields_to_update):
|
return_value=fields_to_update):
|
||||||
|
@ -383,6 +388,21 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||||
self.assertIs(expected_obj, observed_obj)
|
self.assertIs(expected_obj, observed_obj)
|
||||||
self.assertTrue(observed_obj.obj_reset_changes.called)
|
self.assertTrue(observed_obj.obj_reset_changes.called)
|
||||||
|
|
||||||
|
def test_update_primary_key_forbidden_fail(self):
|
||||||
|
obj = self._test_class(self.context, **self.db_obj)
|
||||||
|
obj.obj_reset_changes()
|
||||||
|
|
||||||
|
if not self._test_class.primary_keys:
|
||||||
|
self.skipTest(
|
||||||
|
'All non-updatable fields found in test class %r '
|
||||||
|
'are primary keys' % self._test_class)
|
||||||
|
|
||||||
|
for key, val in self.db_obj.items():
|
||||||
|
if key in self._test_class.primary_keys:
|
||||||
|
setattr(obj, key, val)
|
||||||
|
|
||||||
|
self.assertRaises(base.NeutronObjectUpdateForbidden, obj.update)
|
||||||
|
|
||||||
|
|
||||||
class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase):
|
class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase):
|
||||||
|
|
||||||
|
@ -455,7 +475,11 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
|
||||||
obj = self._test_class(self.context, **self.obj_fields[0])
|
obj = self._test_class(self.context, **self.obj_fields[0])
|
||||||
obj.obj_reset_changes()
|
obj.obj_reset_changes()
|
||||||
|
|
||||||
for key, val in self.get_updatable_fields(self.obj_fields[0]).items():
|
fields_to_update = self.get_updatable_fields(self.obj_fields[0])
|
||||||
|
if not fields_to_update:
|
||||||
|
self.skipTest('No updatable fields found in test class %r' %
|
||||||
|
self._test_class)
|
||||||
|
for key, val in fields_to_update.items():
|
||||||
setattr(obj, key, val)
|
setattr(obj, key, val)
|
||||||
|
|
||||||
self.assertRaises(n_exc.ObjectNotFound, obj.update)
|
self.assertRaises(n_exc.ObjectNotFound, obj.update)
|
||||||
|
@ -474,7 +498,11 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
|
||||||
obj = self._test_class(self.context, **self.obj_fields[0])
|
obj = self._test_class(self.context, **self.obj_fields[0])
|
||||||
obj.create()
|
obj.create()
|
||||||
|
|
||||||
for key, val in self.get_updatable_fields(self.obj_fields[1]).items():
|
fields_to_update = self.get_updatable_fields(self.obj_fields[1])
|
||||||
|
if not fields_to_update:
|
||||||
|
self.skipTest('No updatable fields found in test class %r' %
|
||||||
|
self._test_class)
|
||||||
|
for key, val in fields_to_update.items():
|
||||||
setattr(obj, key, val)
|
setattr(obj, key, val)
|
||||||
|
|
||||||
with mock.patch(SQLALCHEMY_COMMIT) as mock_commit:
|
with mock.patch(SQLALCHEMY_COMMIT) as mock_commit:
|
||||||
|
|
Loading…
Reference in New Issue