Replace to_dict() calls with a function decorator

Up until now, API server functions would need to return simple iterable
objects, such as dicts and lists of dicts. This patch introduces a
decorator which allows such functions to return non-simple objects (as
long as the returned object implements the 'to_dict()' method, or is a
list of such objects) and converts them on its own, simplifying the
user's code and removing code duplication.

Change-Id: Ib30a9213b86b33826291197cf01f00bc1dd3db52
This commit is contained in:
John Schwarz 2015-07-27 12:09:10 +03:00
parent b0fd13daef
commit 12ff4d6b58
3 changed files with 63 additions and 14 deletions

View File

@ -29,16 +29,30 @@ from neutron.db import models_v2
LOG = logging.getLogger(__name__)
def convert_result_to_dict(f):
@functools.wraps(f)
def inner(*args, **kwargs):
result = f(*args, **kwargs)
if result is None:
return None
elif isinstance(result, list):
return [r.to_dict() for r in result]
else:
return result.to_dict()
return inner
def filter_fields(f):
@functools.wraps(f)
def inner_filter(*args, **kwargs):
result = f(*args, **kwargs)
fields = kwargs.get('fields')
if not fields:
pos = f.func_code.co_varnames.index('fields')
try:
pos = f.func_code.co_varnames.index('fields')
fields = args[pos]
except IndexError:
except (IndexError, ValueError):
return result
do_filter = lambda d: {k: v for k, v in d.items() if k in fields}

View File

@ -42,18 +42,20 @@ class QoSPlugin(qos.QoSPluginBase):
self.notification_driver_manager = (
driver_mgr.QosServiceNotificationDriverManager())
@db_base_plugin_common.convert_result_to_dict
def create_policy(self, context, policy):
policy = policy_object.QosPolicy(context, **policy['policy'])
policy.create()
self.notification_driver_manager.create_policy(policy)
return policy.to_dict()
return policy
@db_base_plugin_common.convert_result_to_dict
def update_policy(self, context, policy_id, policy):
policy = policy_object.QosPolicy(context, **policy['policy'])
policy.id = policy_id
policy.update()
self.notification_driver_manager.update_policy(policy)
return policy.to_dict()
return policy
def delete_policy(self, context, policy_id):
policy = policy_object.QosPolicy(context)
@ -68,21 +70,23 @@ class QoSPlugin(qos.QoSPluginBase):
return obj
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
def get_policy(self, context, policy_id, fields=None):
return self._get_policy_obj(context, policy_id).to_dict()
return self._get_policy_obj(context, policy_id)
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
def get_policies(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
#TODO(QoS): Support all the optional parameters
return [policy_obj.to_dict() for policy_obj in
policy_object.QosPolicy.get_objects(context)]
return policy_object.QosPolicy.get_objects(context)
#TODO(QoS): Consider adding a proxy catch-all for rules, so
# we capture the API function call, and just pass
# the rule type as a parameter removing lots of
# future code duplication when we have more rules.
@db_base_plugin_common.convert_result_to_dict
def create_policy_bandwidth_limit_rule(self, context, policy_id,
bandwidth_limit_rule):
#TODO(QoS): avoid creation of severan bandwidth limit rules
@ -96,8 +100,9 @@ class QoSPlugin(qos.QoSPluginBase):
**bandwidth_limit_rule['bandwidth_limit_rule'])
rule.create()
self.notification_driver_manager.update_policy(policy)
return rule.to_dict()
return rule
@db_base_plugin_common.convert_result_to_dict
def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
bandwidth_limit_rule):
# validate that we have access to the policy
@ -107,7 +112,7 @@ class QoSPlugin(qos.QoSPluginBase):
rule.id = rule_id
rule.update()
self.notification_driver_manager.update_policy(policy)
return rule.to_dict()
return rule
def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
# validate that we have access to the policy
@ -118,14 +123,16 @@ class QoSPlugin(qos.QoSPluginBase):
self.notification_driver_manager.update_policy(policy)
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
def get_policy_bandwidth_limit_rule(self, context, rule_id,
policy_id, fields=None):
# validate that we have access to the policy
self._get_policy_obj(context, policy_id)
return rule_object.QosBandwidthLimitRule.get_by_id(context,
rule_id).to_dict()
rule_id)
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
def get_policy_bandwidth_limit_rules(self, context, policy_id,
filters=None, fields=None,
sorts=None, limit=None,
@ -133,12 +140,11 @@ class QoSPlugin(qos.QoSPluginBase):
#TODO(QoS): Support all the optional parameters
# validate that we have access to the policy
self._get_policy_obj(context, policy_id)
return [rule_obj.to_dict() for rule_obj in
rule_object.QosBandwidthLimitRule.get_objects(context)]
return rule_object.QosBandwidthLimitRule.get_objects(context)
@db_base_plugin_common.filter_fields
@db_base_plugin_common.convert_result_to_dict
def get_rule_types(self, context, filters=None, fields=None,
sorts=None, limit=None,
marker=None, page_reverse=False):
return [rule_type_obj.to_dict() for rule_type_obj in
rule_type_object.QosRuleType.get_objects()]
return rule_type_object.QosRuleType.get_objects()

View File

@ -17,6 +17,35 @@ from neutron.db import db_base_plugin_common
from neutron.tests import base
class DummyObject(object):
def __init__(self, **kwargs):
self.kwargs = kwargs
def to_dict(self):
return self.kwargs
class ConvertToDictTestCase(base.BaseTestCase):
@db_base_plugin_common.convert_result_to_dict
def method_dict(self, fields=None):
return DummyObject(one=1, two=2, three=3)
@db_base_plugin_common.convert_result_to_dict
def method_list(self):
return [DummyObject(one=1, two=2, three=3)] * 3
def test_simple_object(self):
expected = {'one': 1, 'two': 2, 'three': 3}
observed = self.method_dict()
self.assertEqual(expected, observed)
def test_list_of_objects(self):
expected = [{'one': 1, 'two': 2, 'three': 3}] * 3
observed = self.method_list()
self.assertEqual(expected, observed)
class FilterFieldsTestCase(base.BaseTestCase):
@db_base_plugin_common.filter_fields