diff --git a/neutron/objects/qos/binding.py b/neutron/objects/qos/binding.py index 71034595b28..b0e63549f1f 100644 --- a/neutron/objects/qos/binding.py +++ b/neutron/objects/qos/binding.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import abc + from neutron_lib.objects import common_types from sqlalchemy import and_ from sqlalchemy import exists @@ -22,8 +24,21 @@ from neutron.db.qos import models as qos_db_model from neutron.objects import base +class _QosPolicyBindingMixin(object, metaclass=abc.ABCMeta): + + _bound_model_id = None + + @classmethod + def get_bound_ids(cls, context, policy_id): + if not cls._bound_model_id: + return [] + + return cls.get_values(context, cls._bound_model_id.name, + policy_id=policy_id) + + @base.NeutronObjectRegistry.register -class QosPolicyPortBinding(base.NeutronDbObject): +class QosPolicyPortBinding(base.NeutronDbObject, _QosPolicyBindingMixin): # Version 1.0: Initial version VERSION = '1.0' @@ -36,6 +51,7 @@ class QosPolicyPortBinding(base.NeutronDbObject): primary_keys = ['port_id'] fields_no_update = ['policy_id', 'port_id'] + _bound_model_id = db_model.port_id @classmethod def get_ports_by_network_id(cls, context, network_id, policy_id=None): @@ -53,7 +69,7 @@ class QosPolicyPortBinding(base.NeutronDbObject): @base.NeutronObjectRegistry.register -class QosPolicyNetworkBinding(base.NeutronDbObject): +class QosPolicyNetworkBinding(base.NeutronDbObject, _QosPolicyBindingMixin): # Version 1.0: Initial version VERSION = '1.0' @@ -66,10 +82,11 @@ class QosPolicyNetworkBinding(base.NeutronDbObject): primary_keys = ['network_id'] fields_no_update = ['policy_id', 'network_id'] + _bound_model_id = db_model.network_id @base.NeutronObjectRegistry.register -class QosPolicyFloatingIPBinding(base.NeutronDbObject): +class QosPolicyFloatingIPBinding(base.NeutronDbObject, _QosPolicyBindingMixin): # Version 1.0: Initial version VERSION = '1.0' @@ -82,10 +99,12 @@ class QosPolicyFloatingIPBinding(base.NeutronDbObject): primary_keys = ['policy_id', 'fip_id'] fields_no_update = ['policy_id', 'fip_id'] + _bound_model_id = db_model.fip_id @base.NeutronObjectRegistry.register -class QosPolicyRouterGatewayIPBinding(base.NeutronDbObject): +class QosPolicyRouterGatewayIPBinding(base.NeutronDbObject, + _QosPolicyBindingMixin): # Version 1.0: Initial version VERSION = '1.0' @@ -98,3 +117,4 @@ class QosPolicyRouterGatewayIPBinding(base.NeutronDbObject): primary_keys = ['policy_id', 'router_id'] fields_no_update = ['policy_id', 'router_id'] + _bound_model_id = db_model.router_id diff --git a/neutron/objects/qos/policy.py b/neutron/objects/qos/policy.py index 4a2f206ef85..14c22feb464 100644 --- a/neutron/objects/qos/policy.py +++ b/neutron/objects/qos/policy.py @@ -327,32 +327,20 @@ class QosPolicy(rbac_db.NeutronRbacObject): return qos_default_policy.qos_policy_id def get_bound_networks(self): - return [ - nb.network_id - for nb in binding.QosPolicyNetworkBinding.get_objects( - self.obj_context, policy_id=self.id) - ] + return binding.QosPolicyNetworkBinding.get_bound_ids(self.obj_context, + self.id) def get_bound_ports(self): - return [ - pb.port_id - for pb in binding.QosPolicyPortBinding.get_objects( - self.obj_context, policy_id=self.id) - ] + return binding.QosPolicyPortBinding.get_bound_ids(self.obj_context, + self.id) def get_bound_floatingips(self): - return [ - fb.fip_id - for fb in binding.QosPolicyFloatingIPBinding.get_objects( - self.obj_context, policy_id=self.id) - ] + return binding.QosPolicyFloatingIPBinding.get_objects(self.obj_context, + self.id) def get_bound_routers(self): - return [ - rb.router_id - for rb in binding.QosPolicyRouterGatewayIPBinding.get_objects( - self.obj_context, policy_id=self.id) - ] + return binding.QosPolicyRouterGatewayIPBinding.get_objects( + self.obj_context, self.id) @classmethod def _get_bound_tenant_ids(cls, session, binding_db, bound_db, diff --git a/neutron/tests/unit/objects/qos/test_binding.py b/neutron/tests/unit/objects/qos/test_binding.py index c8e20762ea2..b8be633fcb5 100644 --- a/neutron/tests/unit/objects/qos/test_binding.py +++ b/neutron/tests/unit/objects/qos/test_binding.py @@ -15,13 +15,25 @@ from neutron.tests.unit.objects import test_base from neutron.tests.unit import testlib_api +class _QosPolicyBindingMixinTestCase(object): + + def test_get_bound_ids(self): + [obj.create() for obj in self.objs] + for obj in self.objs: + obj_ids = obj.get_bound_ids(self.context, obj.policy_id) + self.assertEqual(1, len(obj_ids)) + self.assertEqual(obj[obj.__class__._bound_model_id.name], + obj_ids[0]) + + class QosPolicyPortBindingObjectTestCase(test_base.BaseObjectIfaceTestCase): _test_class = binding.QosPolicyPortBinding class QosPolicyPortBindingDbObjectTestCase(test_base.BaseDbObjectTestCase, - testlib_api.SqlTestCase): + testlib_api.SqlTestCase, + _QosPolicyBindingMixinTestCase): _test_class = binding.QosPolicyPortBinding @@ -40,7 +52,8 @@ class QosPolicyNetworkBindingObjectTestCase(test_base.BaseObjectIfaceTestCase): class QosPolicyNetworkBindingDbObjectTestCase(test_base.BaseDbObjectTestCase, - testlib_api.SqlTestCase): + testlib_api.SqlTestCase, + _QosPolicyBindingMixinTestCase): _test_class = binding.QosPolicyNetworkBinding @@ -59,7 +72,8 @@ class QosPolicyFloatingIPBindingObjectTestCase( class QosPolicyFloatingIPBindingDbObjectTestCase( test_base.BaseDbObjectTestCase, - testlib_api.SqlTestCase): + testlib_api.SqlTestCase, + _QosPolicyBindingMixinTestCase): _test_class = binding.QosPolicyFloatingIPBinding @@ -78,7 +92,8 @@ class QosPolicyRouterGatewayIPBindingObjectTestCase( class QosPolicyRouterGatewayIPBindingDbObjectTestCase( test_base.BaseDbObjectTestCase, - testlib_api.SqlTestCase): + testlib_api.SqlTestCase, + _QosPolicyBindingMixinTestCase): _test_class = binding.QosPolicyRouterGatewayIPBinding