objects.qos.policy: support per type rule lists as synthetic fields

This is a significant piece of work.

It enables neutron objects to define fields that are lazily loaded on
field access. To achieve that,

- field should be mentioned in cls.synthetic_fields
- obj_load_attr should be extended to lazily fetch and cache the field

Based on this work, we define per type rule fields that are lists of
appropriate neutron objects. (At the moment, we have only single type
supported, but I tried hard to make it easily extendable, with little or
no coding needed when a new rule type object definition is added to
rule.py: for example, we inspect object definitions based on
VALID_RULE_TYPES, and define appropriate fields for the policy object).

To implement lazy loading for those fields, I redefined get_by_id for
rules that now meld fields from both base and subtype db models into the
corresponding neutron object.

Added a simple test that checks bandwidth_rules attribute behaves for
policies.

Some objects unit test framework rework was needed to accomodate
synthetic fields that are not propagated to db layer.

Change-Id: Ia16393453b1ed48651fbd778bbe0ac6427560117
This commit is contained in:
Ihar Hrachyshka 2015-07-10 18:00:34 +02:00
parent a28769ff7e
commit 0a33e355bc
10 changed files with 183 additions and 30 deletions

View File

@ -470,3 +470,7 @@ class DeviceNotFoundError(NeutronException):
class NetworkSubnetPoolAffinityError(BadRequest):
message = _("Subnets hosted on the same network must be allocated from "
"the same subnet pool")
class ObjectActionError(NeutronException):
message = _('Object action %(action)s failed because: %(reason)s')

View File

@ -423,3 +423,7 @@ class DelayedStringRenderer(object):
def __str__(self):
return str(self.function(*self.args, **self.kwargs))
def camelize(s):
return ''.join(s.replace('_', ' ').title().split())

View File

@ -91,7 +91,7 @@ class convert_db_exception_to_retry(object):
# Common database operation implementations
# TODO(QoS): consider handling multiple objects found, or no objects at all
# TODO(QoS): consider reusing get_objects below
# TODO(QoS): consider changing the name and making it public, officially
def _find_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
@ -101,15 +101,18 @@ def _find_object(context, model, **kwargs):
def get_object(context, model, id):
# TODO(QoS): consider reusing get_objects below
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
.filter_by(id=id)
.first())
def get_objects(context, model):
def get_objects(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return common_db_mixin.model_query(context, model).all()
return (common_db_mixin.model_query(context, model)
.filter_by(**kwargs)
.all())
def create_object(context, model, values):

View File

@ -32,6 +32,8 @@ class NeutronObject(obj_base.VersionedObject,
# fields that are not allowed to update
fields_no_update = []
synthetic_fields = []
def from_db_object(self, *objs):
for field in self.fields:
for db_obj in objs:
@ -53,21 +55,27 @@ class NeutronObject(obj_base.VersionedObject,
return obj
@classmethod
def get_objects(cls, context):
db_objs = db_api.get_objects(context, cls.db_model)
def get_objects(cls, context, **kwargs):
db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
objs = [cls(context, **db_obj) for db_obj in db_objs]
for obj in objs:
obj.obj_reset_changes()
return objs
def create(self):
def _get_changed_persistent_fields(self):
fields = self.obj_get_changes()
for field in self.synthetic_fields:
if field in fields:
del fields[field]
return fields
def create(self):
fields = self._get_changed_persistent_fields()
db_obj = db_api.create_object(self._context, self.db_model, fields)
self.from_db_object(db_obj)
def update(self):
# TODO(QoS): enforce fields_no_update
updates = self.obj_get_changes()
updates = self._get_changed_persistent_fields()
if updates:
db_obj = db_api.update_object(self._context, self.db_model,
self.id, updates)

View File

@ -13,20 +13,41 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
from oslo_versionedobjects import base as obj_base
from oslo_versionedobjects import fields as obj_fields
import six
from neutron.common import exceptions
from neutron.common import utils
from neutron.db import api as db_api
from neutron.db.qos import api as qos_db_api
from neutron.db.qos import models as qos_db_model
from neutron.extensions import qos as qos_extension
from neutron.objects import base
from neutron.objects.qos import rule as rule_obj_impl
# TODO(QoS): add rule lists to object fields
# TODO(QoS): implement something for binding networks and ports with policies
class QosRulesExtenderMeta(abc.ABCMeta):
def __new__(cls, *args, **kwargs):
cls_ = super(QosRulesExtenderMeta, cls).__new__(cls, *args, **kwargs)
cls_.rule_fields = {}
for rule in qos_extension.VALID_RULE_TYPES:
rule_cls_name = 'Qos%sRule' % utils.camelize(rule)
field = '%s_rules' % rule
cls_.fields[field] = obj_fields.ListOfObjectsField(rule_cls_name)
cls_.rule_fields[field] = rule_cls_name
cls_.synthetic_fields = list(cls_.rule_fields.keys())
return cls_
@obj_base.VersionedObjectRegistry.register
@six.add_metaclass(QosRulesExtenderMeta)
class QosPolicy(base.NeutronObject):
db_model = qos_db_model.QosPolicy
@ -44,6 +65,16 @@ class QosPolicy(base.NeutronObject):
fields_no_update = ['id', 'tenant_id']
def obj_load_attr(self, attrname):
if attrname not in self.rule_fields:
raise exceptions.ObjectActionError(
action='obj_load_attr', reason='unable to load %s' % attrname)
rule_cls = getattr(rule_obj_impl, self.rule_fields[attrname])
rules = rule_cls.get_rules_by_policy(self._context, self.id)
setattr(self, attrname, rules)
self.obj_reset_changes([attrname])
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
# TODO(QoS): we should make sure we use public functions

View File

@ -21,6 +21,7 @@ import six
from neutron.db import api as db_api
from neutron.db.qos import models as qos_db_model
from neutron.extensions import qos as qos_extension
from neutron.objects import base
@ -37,6 +38,9 @@ class QosRule(base.NeutronObject):
fields_no_update = ['id', 'tenant_id', 'qos_policy_id']
# each rule subclass should redefine it
rule_type = None
_core_fields = list(fields.keys())
_common_fields = ['id']
@ -60,8 +64,6 @@ class QosRule(base.NeutronObject):
if func(key)
}
# TODO(QoS): reimplement get_by_id to merge both core and addn fields
def _get_changed_core_fields(self):
fields = self.obj_get_changes()
return self._filter_fields(fields, self._is_core_field)
@ -75,9 +77,32 @@ class QosRule(base.NeutronObject):
for field in self._common_fields:
to_[field] = from_[field]
@classmethod
def get_objects(cls, context, **kwargs):
# TODO(QoS): support searching for subtype fields
db_objs = db_api.get_objects(context, cls.base_db_model, **kwargs)
return [cls.get_by_id(context, db_obj['id']) for db_obj in db_objs]
@classmethod
def get_by_id(cls, context, id):
obj = super(QosRule, cls).get_by_id(context, id)
if obj:
# the object above does not contain fields from base QosRule yet,
# so fetch it and mix its fields into the object
base_db_obj = db_api.get_object(context, cls.base_db_model, id)
for field in cls._core_fields:
setattr(obj, field, base_db_obj[field])
obj.obj_reset_changes()
return obj
# TODO(QoS): create and update are not transactional safe
def create(self):
# TODO(QoS): enforce that type field value is bound to specific class
self.type = self.rule_type
# create base qos_rule
core_fields = self._get_changed_core_fields()
base_db_obj = db_api.create_object(
@ -95,6 +120,8 @@ class QosRule(base.NeutronObject):
def update(self):
updated_db_objs = []
# TODO(QoS): enforce that type field cannot be changed
# update base qos_rule, if needed
core_fields = self._get_changed_core_fields()
if core_fields:
@ -113,13 +140,19 @@ class QosRule(base.NeutronObject):
# delete is the same, additional rule object cleanup is done thru cascading
@classmethod
def get_rules_by_policy(cls, context, policy_id):
return cls.get_objects(context, qos_policy_id=policy_id)
@obj_base.VersionedObjectRegistry.register
class QosBandwidthLimitRule(QosRule):
db_model = qos_db_model.QosBandwidthLimitRule
rule_type = qos_extension.RULE_TYPE_BANDWIDTH_LIMIT
fields = {
'max_kbps': obj_fields.IntegerField(),
'max_burst_kbps': obj_fields.IntegerField()
'max_kbps': obj_fields.IntegerField(nullable=True),
'max_burst_kbps': obj_fields.IntegerField(nullable=True)
}

View File

@ -663,3 +663,14 @@ class TestDelayedStringRenderer(base.BaseTestCase):
LOG.logger.setLevel(logging.logging.DEBUG)
LOG.debug("Hello %s", delayed)
self.assertTrue(my_func.called)
class TestCamelize(base.BaseTestCase):
def test_camelize(self):
data = {'bandwidth_limit': 'BandwidthLimit',
'test': 'Test',
'some__more__dashes': 'SomeMoreDashes',
'a_penguin_walks_into_a_bar': 'APenguinWalksIntoABar'}
for s, expected in data.items():
self.assertEqual(expected, utils.camelize(s))

View File

@ -13,6 +13,7 @@
from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.objects.qos import policy
from neutron.objects.qos import rule
from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
@ -112,3 +113,18 @@ class QosPolicyDbObjectTestCase(QosPolicyBaseTestCase,
policy_obj = policy.QosPolicy.get_network_policy(self.context,
self._network['id'])
self.assertIsNone(policy_obj)
def test_synthetic_rule_fields(self):
obj = policy.QosPolicy(self.context, **self.db_obj)
obj.create()
rule_fields = self.get_random_fields(
obj_cls=rule.QosBandwidthLimitRule)
rule_fields['qos_policy_id'] = obj.id
rule_fields['tenant_id'] = obj.tenant_id
rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
rule_obj.create()
obj = policy.QosPolicy.get_by_id(self.context, obj.id)
self.assertEqual([rule_obj], obj.bandwidth_limit_rules)

View File

@ -21,6 +21,15 @@ class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
_test_class = rule.QosBandwidthLimitRule
@classmethod
def get_random_fields(cls):
# object middleware should not allow random types, so override it with
# proper type
fields = (super(QosBandwidthLimitPolicyObjectTestCase, cls)
.get_random_fields())
fields['type'] = cls._test_class.rule_type
return fields
def _filter_db_object(self, func):
return {
field: self.db_obj[field]
@ -36,6 +45,36 @@ class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
return self._filter_db_object(
lambda field: self._test_class._is_addn_field(field))
def test_get_by_id(self):
with mock.patch.object(db_api, 'get_object',
return_value=self.db_obj) as get_object_mock:
obj = self._test_class.get_by_id(self.context, id='fake_id')
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
get_object_mock.assert_has_calls([
mock.call(self.context, model, 'fake_id')
for model in (self._test_class.db_model,
self._test_class.base_db_model)
], any_order=True)
def test_get_objects(self):
with mock.patch.object(db_api, 'get_objects',
return_value=self.db_objs):
@classmethod
def _get_by_id(cls, context, id):
for db_obj in self.db_objs:
if db_obj['id'] == id:
return self._test_class(context, **db_obj)
with mock.patch.object(rule.QosRule, 'get_by_id', new=_get_by_id):
objs = self._test_class.get_objects(self.context)
self.assertFalse(
filter(lambda obj: not self._is_test_class(obj), objs))
self.assertEqual(
sorted(self.db_objs),
sorted(test_base.get_obj_db_fields(obj) for obj in objs))
def test_create(self):
with mock.patch.object(db_api, 'create_object',
return_value=self.db_obj) as create_mock:
@ -46,13 +85,13 @@ class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
self._check_equal(obj, self.db_obj)
core_db_obj = self._get_core_db_obj()
create_mock.assert_any_call(
self.context, self._test_class.base_db_model, core_db_obj)
addn_db_obj = self._get_addn_db_obj()
create_mock.assert_any_call(
self.context, self._test_class.db_model,
addn_db_obj)
create_mock.assert_has_calls(
[mock.call(self.context, self._test_class.base_db_model,
core_db_obj),
mock.call(self.context, self._test_class.db_model,
addn_db_obj)]
)
def test_update_changes(self):
with mock.patch.object(db_api, 'update_object',

View File

@ -52,11 +52,13 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = {
obj_fields.IntegerField: _random_integer,
obj_fields.StringField: _random_string,
obj_fields.UUIDField: _random_string,
obj_fields.ListOfObjectsField: lambda: []
}
def get_obj_fields(obj):
return {field: getattr(obj, field) for field in obj.fields}
def get_obj_db_fields(obj):
return {field: getattr(obj, field) for field in obj.fields
if field not in obj.synthetic_fields}
class _BaseObjectTestCase(object):
@ -66,15 +68,17 @@ class _BaseObjectTestCase(object):
def setUp(self):
super(_BaseObjectTestCase, self).setUp()
self.context = context.get_admin_context()
self.db_objs = list(self._get_random_fields() for _ in range(3))
self.db_objs = list(self.get_random_fields() for _ in range(3))
self.db_obj = self.db_objs[0]
@classmethod
def _get_random_fields(cls):
def get_random_fields(cls, obj_cls=None):
obj_cls = obj_cls or cls._test_class
fields = {}
for field in cls._test_class.fields:
field_obj = cls._test_class.fields[field]
fields[field] = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]()
for field, field_obj in obj_cls.fields.items():
if field not in obj_cls.synthetic_fields:
generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]
fields[field] = generator()
return fields
@classmethod
@ -89,7 +93,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
return_value=self.db_obj) as get_object_mock:
obj = self._test_class.get_by_id(self.context, id='fake_id')
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, get_obj_fields(obj))
self.assertEqual(self.db_obj, get_obj_db_fields(obj))
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model, 'fake_id')
@ -106,14 +110,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
filter(lambda obj: not self._is_test_class(obj), objs))
self.assertEqual(
sorted(self.db_objs),
sorted(get_obj_fields(obj) for obj in objs))
sorted(get_obj_db_fields(obj) for obj in objs))
get_objects_mock.assert_called_once_with(
self.context, self._test_class.db_model)
def _check_equal(self, obj, db_obj):
self.assertEqual(
sorted(db_obj),
sorted(get_obj_fields(obj)))
sorted(get_obj_db_fields(obj)))
def test_create(self):
with mock.patch.object(db_api, 'create_object',