Refactor duplicated implementation of _get_policy_obj

This patch moves the implementation of the function
_get_policy_obj to the policy object code scope.

Change-Id: I7057558a5ec32a55a37a6b93dabc997d69abfb98
This commit is contained in:
LIU Yulong 2018-06-11 08:52:26 +08:00
parent abbd534fdf
commit a034e8e0f8
7 changed files with 51 additions and 48 deletions

View File

@ -32,12 +32,6 @@ class QosCoreResourceExtension(base.CoreResourceExtension):
plugin_constants.QOS in directory.get_plugins())
return self._plugin_loaded
def _get_policy_obj(self, context, policy_id):
obj = policy_object.QosPolicy.get_object(context, id=policy_id)
if obj is None:
raise n_exc.QosPolicyNotFound(policy_id=policy_id)
return obj
def _check_policy_change_permission(self, context, old_policy):
"""An existing policy can be modified only if one of the following is
true:
@ -59,7 +53,8 @@ class QosCoreResourceExtension(base.CoreResourceExtension):
qos_policy_id = port_changes.get(qos_consts.QOS_POLICY_ID)
if qos_policy_id is not None:
policy = self._get_policy_obj(context, qos_policy_id)
policy = policy_object.QosPolicy.get_policy_obj(
context, qos_policy_id)
policy.attach_port(port['id'])
port[qos_consts.QOS_POLICY_ID] = qos_policy_id
@ -72,7 +67,8 @@ class QosCoreResourceExtension(base.CoreResourceExtension):
qos_policy_id = policy_obj.qos_policy_id
if qos_policy_id is not None:
policy = self._get_policy_obj(context, qos_policy_id)
policy = policy_object.QosPolicy.get_policy_obj(
context, qos_policy_id)
policy.attach_network(network['id'])
network[qos_consts.QOS_POLICY_ID] = qos_policy_id
@ -85,7 +81,8 @@ class QosCoreResourceExtension(base.CoreResourceExtension):
qos_policy_id = network_changes.get(qos_consts.QOS_POLICY_ID)
if qos_policy_id is not None:
policy = self._get_policy_obj(context, qos_policy_id)
policy = policy_object.QosPolicy.get_policy_obj(
context, qos_policy_id)
policy.attach_network(network['id'])
network[qos_consts.QOS_POLICY_ID] = qos_policy_id

View File

@ -15,7 +15,6 @@
from neutron_lib.api.definitions import l3 as l3_apidef
from neutron_lib.services.qos import constants as qos_consts
from neutron.common import exceptions as n_exc
from neutron.db import _resource_extend as resource_extend
from neutron.objects.qos import policy as policy_object
@ -34,18 +33,12 @@ class FloatingQoSDbMixin(object):
fip_res[qos_consts.QOS_POLICY_ID] = None
return fip_res
def _get_policy_obj(self, context, policy_id):
obj = policy_object.QosPolicy.get_object(context, id=policy_id)
if obj is None:
raise n_exc.QosPolicyNotFound(policy_id=policy_id)
return obj
def _create_fip_qos_db(self, context, fip_id, policy_id):
policy = self._get_policy_obj(context, policy_id)
policy = policy_object.QosPolicy.get_policy_obj(context, policy_id)
policy.attach_floatingip(fip_id)
def _delete_fip_qos_db(self, context, fip_id, policy_id):
policy = self._get_policy_obj(context, policy_id)
policy = policy_object.QosPolicy.get_policy_obj(context, policy_id)
policy.detach_floatingip(fip_id)
def _process_extra_fip_qos_create(self, context, fip_id, fip):

View File

@ -125,6 +125,24 @@ class QosPolicy(rbac_db.NeutronRbacObject):
pass
return _dict
@classmethod
def get_policy_obj(cls, context, policy_id):
"""Fetch a QoS policy.
:param context: neutron api request context
:type context: neutron.context.Context
:param policy_id: the id of the QosPolicy to fetch
:type policy_id: str uuid
:returns: a QosPolicy object
:raises: n_exc.QosPolicyNotFound
"""
obj = cls.get_object(context, id=policy_id)
if obj is None:
raise exceptions.QosPolicyNotFound(policy_id=policy_id)
return obj
@classmethod
def get_object(cls, context, **kwargs):
# We want to get the policy regardless of its tenant id. We'll make

View File

@ -192,7 +192,8 @@ class QoSPlugin(qos.QoSPluginBase):
"""
policy_data = policy['policy']
with db_api.context_manager.writer.using(context):
policy_obj = self._get_policy_obj(context, policy_id)
policy_obj = policy_object.QosPolicy.get_policy_obj(
context, policy_id)
policy_obj.update_fields(policy_data, reset_changes=True)
policy_obj.update()
self.driver_manager.call(qos_consts.UPDATE_POLICY_PRECOMMIT,
@ -223,22 +224,6 @@ class QoSPlugin(qos.QoSPluginBase):
self.driver_manager.call(qos_consts.DELETE_POLICY,
context, policy)
def _get_policy_obj(self, context, policy_id):
"""Fetch a QoS policy.
:param context: neutron api request context
:type context: neutron.context.Context
:param policy_id: the id of the QosPolicy to fetch
:type policy_id: str uuid
:returns: a QosPolicy object
:raises: n_exc.QosPolicyNotFound
"""
obj = policy_object.QosPolicy.get_object(context, id=policy_id)
if obj is None:
raise n_exc.QosPolicyNotFound(policy_id=policy_id)
return obj
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
def get_policy(self, context, policy_id, fields=None):
@ -251,7 +236,7 @@ class QoSPlugin(qos.QoSPluginBase):
:returns: a QosPolicy object
"""
return self._get_policy_obj(context, policy_id)
return policy_object.QosPolicy.get_policy_obj(context, policy_id)
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
@ -314,7 +299,7 @@ class QoSPlugin(qos.QoSPluginBase):
with db_api.autonested_transaction(context.session):
# Ensure that we have access to the policy.
policy = self._get_policy_obj(context, policy_id)
policy = policy_object.QosPolicy.get_policy_obj(context, policy_id)
checker.check_bandwidth_rule_conflict(policy, rule_data)
rule = rule_cls(context, qos_policy_id=policy_id, **rule_data)
checker.check_rules_conflict(policy, rule)
@ -351,7 +336,7 @@ class QoSPlugin(qos.QoSPluginBase):
with db_api.autonested_transaction(context.session):
# Ensure we have access to the policy.
policy = self._get_policy_obj(context, policy_id)
policy = policy_object.QosPolicy.get_policy_obj(context, policy_id)
# Ensure the rule belongs to the policy.
checker.check_bandwidth_rule_conflict(policy, rule_data)
policy.get_rule_by_id(rule_id)
@ -384,7 +369,7 @@ class QoSPlugin(qos.QoSPluginBase):
"""
with db_api.autonested_transaction(context.session):
# Ensure we have access to the policy.
policy = self._get_policy_obj(context, policy_id)
policy = policy_object.QosPolicy.get_policy_obj(context, policy_id)
rule = policy.get_rule_by_id(rule_id)
rule.delete()
policy.obj_load_attr('rules')
@ -413,7 +398,7 @@ class QoSPlugin(qos.QoSPluginBase):
"""
with db_api.autonested_transaction(context.session):
# Ensure we have access to the policy.
self._get_policy_obj(context, policy_id)
policy_object.QosPolicy.get_policy_obj(context, policy_id)
rule = rule_cls.get_object(context, id=rule_id)
if not rule:
raise n_exc.QosRuleNotFound(policy_id=policy_id, rule_id=rule_id)
@ -438,7 +423,7 @@ class QoSPlugin(qos.QoSPluginBase):
"""
with db_api.autonested_transaction(context.session):
# Ensure we have access to the policy.
self._get_policy_obj(context, policy_id)
policy_object.QosPolicy.get_policy_obj(context, policy_id)
filters = filters or dict()
filters[qos_consts.QOS_POLICY_ID] = policy_id
pager = base_obj.Pager(sorts, limit, page_reverse, marker)

View File

@ -66,7 +66,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
actual_port = {'id': mock.Mock(),
qos_consts.QOS_POLICY_ID: qos_policy_id}
qos_policy = mock.MagicMock()
self.policy_m.get_object = mock.Mock(return_value=qos_policy)
self.policy_m.get_policy_obj = mock.Mock(return_value=qos_policy)
self.core_extension.process_fields(
self.context, base_core.PORT, base_core.EVENT_UPDATE,
{qos_consts.QOS_POLICY_ID: qos_policy_id},
@ -85,7 +85,8 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_port_policy = mock.Mock(
return_value=old_qos_policy)
new_qos_policy = mock.MagicMock()
self.policy_m.get_object = mock.Mock(return_value=new_qos_policy)
self.policy_m.get_policy_obj = mock.Mock(
return_value=new_qos_policy)
self.core_extension.process_fields(
self.context, base_core.PORT, base_core.EVENT_UPDATE,
{qos_consts.QOS_POLICY_ID: qos_policy2_id},
@ -184,7 +185,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
actual_network = {'id': mock.Mock(),
qos_consts.QOS_POLICY_ID: qos_policy_id}
qos_policy = mock.MagicMock()
self.policy_m.get_object = mock.Mock(return_value=qos_policy)
self.policy_m.get_policy_obj = mock.Mock(return_value=qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK, base_core.EVENT_UPDATE,
{qos_consts.QOS_POLICY_ID: qos_policy_id}, actual_network)
@ -202,7 +203,8 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_network_policy = mock.Mock(
return_value=old_qos_policy)
new_qos_policy = mock.MagicMock()
self.policy_m.get_object = mock.Mock(return_value=new_qos_policy)
self.policy_m.get_policy_obj = mock.Mock(
return_value=new_qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK, base_core.EVENT_UPDATE,
{qos_consts.QOS_POLICY_ID: qos_policy_id}, actual_network)
@ -262,7 +264,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_network_policy = mock.Mock(
return_value=qos_policy_id)
qos_policy = mock.MagicMock()
self.policy_m.get_object = mock.Mock(return_value=qos_policy)
self.policy_m.get_policy_obj = mock.Mock(return_value=qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK, base_core.EVENT_CREATE,
actual_network, actual_network)
@ -279,7 +281,8 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
qos_policy = mock.MagicMock()
with mock.patch.object(policy.QosPolicyDefault, "get_object",
return_value=qos_policy_id) as mock_get_default_policy_id:
self.policy_m.get_object = mock.Mock(return_value=qos_policy)
self.policy_m.get_policy_obj = mock.Mock(
return_value=qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK, base_core.EVENT_CREATE,
actual_network, actual_network)

View File

@ -123,6 +123,12 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
(super(QosPolicyObjectTestCase, self).
test_to_dict_makes_primitive_field_value())
def test_get_policy_obj_not_found(self):
context = self.context.elevated()
self.assertRaises(n_exc.QosPolicyNotFound,
policy.QosPolicy.get_policy_obj,
context, "fake_id")
class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):

View File

@ -930,7 +930,8 @@ class TestQosPlugin(base.BaseQosTestCase):
# some actions get rule from policy
get_rule_mock_call = getattr(
mock.call.QosPolicy.get_object().get_rule_by_id(), action)()
mock.call.QosPolicy.get_policy_obj().get_rule_by_id(),
action)()
# some actions construct rule from class reference
rule_mock_call = getattr(mock.call.RuleCls(), action)()