Replace to_dict() calls with a function decorator
Up until now, API server functions would need to return simple iterable objects, such as dicts and lists of dicts. This patch introduces a decorator which allows such functions to return non-simple objects (as long as the returned object implements the 'to_dict()' method, or is a list of such objects) and converts them on its own, simplifying the user's code and removing code duplication. Change-Id: Ib30a9213b86b33826291197cf01f00bc1dd3db52
This commit is contained in:
parent
b0fd13daef
commit
12ff4d6b58
@ -29,16 +29,30 @@ from neutron.db import models_v2
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_result_to_dict(f):
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
elif isinstance(result, list):
|
||||
return [r.to_dict() for r in result]
|
||||
else:
|
||||
return result.to_dict()
|
||||
return inner
|
||||
|
||||
|
||||
def filter_fields(f):
|
||||
@functools.wraps(f)
|
||||
def inner_filter(*args, **kwargs):
|
||||
result = f(*args, **kwargs)
|
||||
fields = kwargs.get('fields')
|
||||
if not fields:
|
||||
pos = f.func_code.co_varnames.index('fields')
|
||||
try:
|
||||
pos = f.func_code.co_varnames.index('fields')
|
||||
fields = args[pos]
|
||||
except IndexError:
|
||||
except (IndexError, ValueError):
|
||||
return result
|
||||
|
||||
do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
|
||||
|
@ -42,18 +42,20 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||
self.notification_driver_manager = (
|
||||
driver_mgr.QosServiceNotificationDriverManager())
|
||||
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def create_policy(self, context, policy):
|
||||
policy = policy_object.QosPolicy(context, **policy['policy'])
|
||||
policy.create()
|
||||
self.notification_driver_manager.create_policy(policy)
|
||||
return policy.to_dict()
|
||||
return policy
|
||||
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def update_policy(self, context, policy_id, policy):
|
||||
policy = policy_object.QosPolicy(context, **policy['policy'])
|
||||
policy.id = policy_id
|
||||
policy.update()
|
||||
self.notification_driver_manager.update_policy(policy)
|
||||
return policy.to_dict()
|
||||
return policy
|
||||
|
||||
def delete_policy(self, context, policy_id):
|
||||
policy = policy_object.QosPolicy(context)
|
||||
@ -68,21 +70,23 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||
return obj
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def get_policy(self, context, policy_id, fields=None):
|
||||
return self._get_policy_obj(context, policy_id).to_dict()
|
||||
return self._get_policy_obj(context, policy_id)
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def get_policies(self, context, filters=None, fields=None,
|
||||
sorts=None, limit=None, marker=None,
|
||||
page_reverse=False):
|
||||
#TODO(QoS): Support all the optional parameters
|
||||
return [policy_obj.to_dict() for policy_obj in
|
||||
policy_object.QosPolicy.get_objects(context)]
|
||||
return policy_object.QosPolicy.get_objects(context)
|
||||
|
||||
#TODO(QoS): Consider adding a proxy catch-all for rules, so
|
||||
# we capture the API function call, and just pass
|
||||
# the rule type as a parameter removing lots of
|
||||
# future code duplication when we have more rules.
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def create_policy_bandwidth_limit_rule(self, context, policy_id,
|
||||
bandwidth_limit_rule):
|
||||
#TODO(QoS): avoid creation of severan bandwidth limit rules
|
||||
@ -96,8 +100,9 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||
**bandwidth_limit_rule['bandwidth_limit_rule'])
|
||||
rule.create()
|
||||
self.notification_driver_manager.update_policy(policy)
|
||||
return rule.to_dict()
|
||||
return rule
|
||||
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
|
||||
bandwidth_limit_rule):
|
||||
# validate that we have access to the policy
|
||||
@ -107,7 +112,7 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||
rule.id = rule_id
|
||||
rule.update()
|
||||
self.notification_driver_manager.update_policy(policy)
|
||||
return rule.to_dict()
|
||||
return rule
|
||||
|
||||
def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
|
||||
# validate that we have access to the policy
|
||||
@ -118,14 +123,16 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||
self.notification_driver_manager.update_policy(policy)
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def get_policy_bandwidth_limit_rule(self, context, rule_id,
|
||||
policy_id, fields=None):
|
||||
# validate that we have access to the policy
|
||||
self._get_policy_obj(context, policy_id)
|
||||
return rule_object.QosBandwidthLimitRule.get_by_id(context,
|
||||
rule_id).to_dict()
|
||||
rule_id)
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def get_policy_bandwidth_limit_rules(self, context, policy_id,
|
||||
filters=None, fields=None,
|
||||
sorts=None, limit=None,
|
||||
@ -133,12 +140,11 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||
#TODO(QoS): Support all the optional parameters
|
||||
# validate that we have access to the policy
|
||||
self._get_policy_obj(context, policy_id)
|
||||
return [rule_obj.to_dict() for rule_obj in
|
||||
rule_object.QosBandwidthLimitRule.get_objects(context)]
|
||||
return rule_object.QosBandwidthLimitRule.get_objects(context)
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def get_rule_types(self, context, filters=None, fields=None,
|
||||
sorts=None, limit=None,
|
||||
marker=None, page_reverse=False):
|
||||
return [rule_type_obj.to_dict() for rule_type_obj in
|
||||
rule_type_object.QosRuleType.get_objects()]
|
||||
return rule_type_object.QosRuleType.get_objects()
|
||||
|
@ -17,6 +17,35 @@ from neutron.db import db_base_plugin_common
|
||||
from neutron.tests import base
|
||||
|
||||
|
||||
class DummyObject(object):
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def to_dict(self):
|
||||
return self.kwargs
|
||||
|
||||
|
||||
class ConvertToDictTestCase(base.BaseTestCase):
|
||||
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def method_dict(self, fields=None):
|
||||
return DummyObject(one=1, two=2, three=3)
|
||||
|
||||
@db_base_plugin_common.convert_result_to_dict
|
||||
def method_list(self):
|
||||
return [DummyObject(one=1, two=2, three=3)] * 3
|
||||
|
||||
def test_simple_object(self):
|
||||
expected = {'one': 1, 'two': 2, 'three': 3}
|
||||
observed = self.method_dict()
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
def test_list_of_objects(self):
|
||||
expected = [{'one': 1, 'two': 2, 'three': 3}] * 3
|
||||
observed = self.method_list()
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
|
||||
class FilterFieldsTestCase(base.BaseTestCase):
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
|
Loading…
x
Reference in New Issue
Block a user