Replace to_dict() calls with a function decorator
Up until now, API server functions would need to return simple iterable objects, such as dicts and lists of dicts. This patch introduces a decorator which allows such functions to return non-simple objects (as long as the returned object implements the 'to_dict()' method, or is a list of such objects) and converts them on its own, simplifying the user's code and removing code duplication. Change-Id: Ib30a9213b86b33826291197cf01f00bc1dd3db52
This commit is contained in:
parent
b0fd13daef
commit
12ff4d6b58
|
@ -29,16 +29,30 @@ from neutron.db import models_v2
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_result_to_dict(f):
|
||||||
|
@functools.wraps(f)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
result = f(*args, **kwargs)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
elif isinstance(result, list):
|
||||||
|
return [r.to_dict() for r in result]
|
||||||
|
else:
|
||||||
|
return result.to_dict()
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def filter_fields(f):
|
def filter_fields(f):
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def inner_filter(*args, **kwargs):
|
def inner_filter(*args, **kwargs):
|
||||||
result = f(*args, **kwargs)
|
result = f(*args, **kwargs)
|
||||||
fields = kwargs.get('fields')
|
fields = kwargs.get('fields')
|
||||||
if not fields:
|
if not fields:
|
||||||
pos = f.func_code.co_varnames.index('fields')
|
|
||||||
try:
|
try:
|
||||||
|
pos = f.func_code.co_varnames.index('fields')
|
||||||
fields = args[pos]
|
fields = args[pos]
|
||||||
except IndexError:
|
except (IndexError, ValueError):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
|
do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
|
||||||
|
|
|
@ -42,18 +42,20 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
self.notification_driver_manager = (
|
self.notification_driver_manager = (
|
||||||
driver_mgr.QosServiceNotificationDriverManager())
|
driver_mgr.QosServiceNotificationDriverManager())
|
||||||
|
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def create_policy(self, context, policy):
|
def create_policy(self, context, policy):
|
||||||
policy = policy_object.QosPolicy(context, **policy['policy'])
|
policy = policy_object.QosPolicy(context, **policy['policy'])
|
||||||
policy.create()
|
policy.create()
|
||||||
self.notification_driver_manager.create_policy(policy)
|
self.notification_driver_manager.create_policy(policy)
|
||||||
return policy.to_dict()
|
return policy
|
||||||
|
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def update_policy(self, context, policy_id, policy):
|
def update_policy(self, context, policy_id, policy):
|
||||||
policy = policy_object.QosPolicy(context, **policy['policy'])
|
policy = policy_object.QosPolicy(context, **policy['policy'])
|
||||||
policy.id = policy_id
|
policy.id = policy_id
|
||||||
policy.update()
|
policy.update()
|
||||||
self.notification_driver_manager.update_policy(policy)
|
self.notification_driver_manager.update_policy(policy)
|
||||||
return policy.to_dict()
|
return policy
|
||||||
|
|
||||||
def delete_policy(self, context, policy_id):
|
def delete_policy(self, context, policy_id):
|
||||||
policy = policy_object.QosPolicy(context)
|
policy = policy_object.QosPolicy(context)
|
||||||
|
@ -68,21 +70,23 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@db_base_plugin_common.filter_fields
|
@db_base_plugin_common.filter_fields
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def get_policy(self, context, policy_id, fields=None):
|
def get_policy(self, context, policy_id, fields=None):
|
||||||
return self._get_policy_obj(context, policy_id).to_dict()
|
return self._get_policy_obj(context, policy_id)
|
||||||
|
|
||||||
@db_base_plugin_common.filter_fields
|
@db_base_plugin_common.filter_fields
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def get_policies(self, context, filters=None, fields=None,
|
def get_policies(self, context, filters=None, fields=None,
|
||||||
sorts=None, limit=None, marker=None,
|
sorts=None, limit=None, marker=None,
|
||||||
page_reverse=False):
|
page_reverse=False):
|
||||||
#TODO(QoS): Support all the optional parameters
|
#TODO(QoS): Support all the optional parameters
|
||||||
return [policy_obj.to_dict() for policy_obj in
|
return policy_object.QosPolicy.get_objects(context)
|
||||||
policy_object.QosPolicy.get_objects(context)]
|
|
||||||
|
|
||||||
#TODO(QoS): Consider adding a proxy catch-all for rules, so
|
#TODO(QoS): Consider adding a proxy catch-all for rules, so
|
||||||
# we capture the API function call, and just pass
|
# we capture the API function call, and just pass
|
||||||
# the rule type as a parameter removing lots of
|
# the rule type as a parameter removing lots of
|
||||||
# future code duplication when we have more rules.
|
# future code duplication when we have more rules.
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def create_policy_bandwidth_limit_rule(self, context, policy_id,
|
def create_policy_bandwidth_limit_rule(self, context, policy_id,
|
||||||
bandwidth_limit_rule):
|
bandwidth_limit_rule):
|
||||||
#TODO(QoS): avoid creation of severan bandwidth limit rules
|
#TODO(QoS): avoid creation of severan bandwidth limit rules
|
||||||
|
@ -96,8 +100,9 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
**bandwidth_limit_rule['bandwidth_limit_rule'])
|
**bandwidth_limit_rule['bandwidth_limit_rule'])
|
||||||
rule.create()
|
rule.create()
|
||||||
self.notification_driver_manager.update_policy(policy)
|
self.notification_driver_manager.update_policy(policy)
|
||||||
return rule.to_dict()
|
return rule
|
||||||
|
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
|
def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
|
||||||
bandwidth_limit_rule):
|
bandwidth_limit_rule):
|
||||||
# validate that we have access to the policy
|
# validate that we have access to the policy
|
||||||
|
@ -107,7 +112,7 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
rule.id = rule_id
|
rule.id = rule_id
|
||||||
rule.update()
|
rule.update()
|
||||||
self.notification_driver_manager.update_policy(policy)
|
self.notification_driver_manager.update_policy(policy)
|
||||||
return rule.to_dict()
|
return rule
|
||||||
|
|
||||||
def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
|
def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
|
||||||
# validate that we have access to the policy
|
# validate that we have access to the policy
|
||||||
|
@ -118,14 +123,16 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
self.notification_driver_manager.update_policy(policy)
|
self.notification_driver_manager.update_policy(policy)
|
||||||
|
|
||||||
@db_base_plugin_common.filter_fields
|
@db_base_plugin_common.filter_fields
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def get_policy_bandwidth_limit_rule(self, context, rule_id,
|
def get_policy_bandwidth_limit_rule(self, context, rule_id,
|
||||||
policy_id, fields=None):
|
policy_id, fields=None):
|
||||||
# validate that we have access to the policy
|
# validate that we have access to the policy
|
||||||
self._get_policy_obj(context, policy_id)
|
self._get_policy_obj(context, policy_id)
|
||||||
return rule_object.QosBandwidthLimitRule.get_by_id(context,
|
return rule_object.QosBandwidthLimitRule.get_by_id(context,
|
||||||
rule_id).to_dict()
|
rule_id)
|
||||||
|
|
||||||
@db_base_plugin_common.filter_fields
|
@db_base_plugin_common.filter_fields
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def get_policy_bandwidth_limit_rules(self, context, policy_id,
|
def get_policy_bandwidth_limit_rules(self, context, policy_id,
|
||||||
filters=None, fields=None,
|
filters=None, fields=None,
|
||||||
sorts=None, limit=None,
|
sorts=None, limit=None,
|
||||||
|
@ -133,12 +140,11 @@ class QoSPlugin(qos.QoSPluginBase):
|
||||||
#TODO(QoS): Support all the optional parameters
|
#TODO(QoS): Support all the optional parameters
|
||||||
# validate that we have access to the policy
|
# validate that we have access to the policy
|
||||||
self._get_policy_obj(context, policy_id)
|
self._get_policy_obj(context, policy_id)
|
||||||
return [rule_obj.to_dict() for rule_obj in
|
return rule_object.QosBandwidthLimitRule.get_objects(context)
|
||||||
rule_object.QosBandwidthLimitRule.get_objects(context)]
|
|
||||||
|
|
||||||
@db_base_plugin_common.filter_fields
|
@db_base_plugin_common.filter_fields
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
def get_rule_types(self, context, filters=None, fields=None,
|
def get_rule_types(self, context, filters=None, fields=None,
|
||||||
sorts=None, limit=None,
|
sorts=None, limit=None,
|
||||||
marker=None, page_reverse=False):
|
marker=None, page_reverse=False):
|
||||||
return [rule_type_obj.to_dict() for rule_type_obj in
|
return rule_type_object.QosRuleType.get_objects()
|
||||||
rule_type_object.QosRuleType.get_objects()]
|
|
||||||
|
|
|
@ -17,6 +17,35 @@ from neutron.db import db_base_plugin_common
|
||||||
from neutron.tests import base
|
from neutron.tests import base
|
||||||
|
|
||||||
|
|
||||||
|
class DummyObject(object):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return self.kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertToDictTestCase(base.BaseTestCase):
|
||||||
|
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
|
def method_dict(self, fields=None):
|
||||||
|
return DummyObject(one=1, two=2, three=3)
|
||||||
|
|
||||||
|
@db_base_plugin_common.convert_result_to_dict
|
||||||
|
def method_list(self):
|
||||||
|
return [DummyObject(one=1, two=2, three=3)] * 3
|
||||||
|
|
||||||
|
def test_simple_object(self):
|
||||||
|
expected = {'one': 1, 'two': 2, 'three': 3}
|
||||||
|
observed = self.method_dict()
|
||||||
|
self.assertEqual(expected, observed)
|
||||||
|
|
||||||
|
def test_list_of_objects(self):
|
||||||
|
expected = [{'one': 1, 'two': 2, 'three': 3}] * 3
|
||||||
|
observed = self.method_list()
|
||||||
|
self.assertEqual(expected, observed)
|
||||||
|
|
||||||
|
|
||||||
class FilterFieldsTestCase(base.BaseTestCase):
|
class FilterFieldsTestCase(base.BaseTestCase):
|
||||||
|
|
||||||
@db_base_plugin_common.filter_fields
|
@db_base_plugin_common.filter_fields
|
||||||
|
|
Loading…
Reference in New Issue