From 12ff4d6b5890c2fd1e0a3e58f974be3e1f1465ca Mon Sep 17 00:00:00 2001 From: John Schwarz Date: Mon, 27 Jul 2015 12:09:10 +0300 Subject: [PATCH] Replace to_dict() calls with a function decorator Up until now, API server functions would need to return simple iterable objects, such as dicts and lists of dicts. This patch introduces a decorator which allows such functions to return non-simple objects (as long as the returned object implements the 'to_dict()' method, or is a list of such objects) and converts them on its own, simplifying the user's code and removing code duplication. Change-Id: Ib30a9213b86b33826291197cf01f00bc1dd3db52 --- neutron/db/db_base_plugin_common.py | 18 +++++++++-- neutron/services/qos/qos_plugin.py | 30 +++++++++++-------- .../unit/db/test_db_base_plugin_common.py | 29 ++++++++++++++++++ 3 files changed, 63 insertions(+), 14 deletions(-) diff --git a/neutron/db/db_base_plugin_common.py b/neutron/db/db_base_plugin_common.py index 4ce5daab7b6..c2fbff20107 100644 --- a/neutron/db/db_base_plugin_common.py +++ b/neutron/db/db_base_plugin_common.py @@ -29,16 +29,30 @@ from neutron.db import models_v2 LOG = logging.getLogger(__name__) +def convert_result_to_dict(f): + @functools.wraps(f) + def inner(*args, **kwargs): + result = f(*args, **kwargs) + + if result is None: + return None + elif isinstance(result, list): + return [r.to_dict() for r in result] + else: + return result.to_dict() + return inner + + def filter_fields(f): @functools.wraps(f) def inner_filter(*args, **kwargs): result = f(*args, **kwargs) fields = kwargs.get('fields') if not fields: - pos = f.func_code.co_varnames.index('fields') try: + pos = f.func_code.co_varnames.index('fields') fields = args[pos] - except IndexError: + except (IndexError, ValueError): return result do_filter = lambda d: {k: v for k, v in d.items() if k in fields} diff --git a/neutron/services/qos/qos_plugin.py b/neutron/services/qos/qos_plugin.py index 23135bf82be..d66acc2685c 100644 --- a/neutron/services/qos/qos_plugin.py +++ b/neutron/services/qos/qos_plugin.py @@ -42,18 +42,20 @@ class QoSPlugin(qos.QoSPluginBase): self.notification_driver_manager = ( driver_mgr.QosServiceNotificationDriverManager()) + @db_base_plugin_common.convert_result_to_dict def create_policy(self, context, policy): policy = policy_object.QosPolicy(context, **policy['policy']) policy.create() self.notification_driver_manager.create_policy(policy) - return policy.to_dict() + return policy + @db_base_plugin_common.convert_result_to_dict def update_policy(self, context, policy_id, policy): policy = policy_object.QosPolicy(context, **policy['policy']) policy.id = policy_id policy.update() self.notification_driver_manager.update_policy(policy) - return policy.to_dict() + return policy def delete_policy(self, context, policy_id): policy = policy_object.QosPolicy(context) @@ -68,21 +70,23 @@ class QoSPlugin(qos.QoSPluginBase): return obj @db_base_plugin_common.filter_fields + @db_base_plugin_common.convert_result_to_dict def get_policy(self, context, policy_id, fields=None): - return self._get_policy_obj(context, policy_id).to_dict() + return self._get_policy_obj(context, policy_id) @db_base_plugin_common.filter_fields + @db_base_plugin_common.convert_result_to_dict def get_policies(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False): #TODO(QoS): Support all the optional parameters - return [policy_obj.to_dict() for policy_obj in - policy_object.QosPolicy.get_objects(context)] + return policy_object.QosPolicy.get_objects(context) #TODO(QoS): Consider adding a proxy catch-all for rules, so # we capture the API function call, and just pass # the rule type as a parameter removing lots of # future code duplication when we have more rules. + @db_base_plugin_common.convert_result_to_dict def create_policy_bandwidth_limit_rule(self, context, policy_id, bandwidth_limit_rule): #TODO(QoS): avoid creation of severan bandwidth limit rules @@ -96,8 +100,9 @@ class QoSPlugin(qos.QoSPluginBase): **bandwidth_limit_rule['bandwidth_limit_rule']) rule.create() self.notification_driver_manager.update_policy(policy) - return rule.to_dict() + return rule + @db_base_plugin_common.convert_result_to_dict def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id, bandwidth_limit_rule): # validate that we have access to the policy @@ -107,7 +112,7 @@ class QoSPlugin(qos.QoSPluginBase): rule.id = rule_id rule.update() self.notification_driver_manager.update_policy(policy) - return rule.to_dict() + return rule def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id): # validate that we have access to the policy @@ -118,14 +123,16 @@ class QoSPlugin(qos.QoSPluginBase): self.notification_driver_manager.update_policy(policy) @db_base_plugin_common.filter_fields + @db_base_plugin_common.convert_result_to_dict def get_policy_bandwidth_limit_rule(self, context, rule_id, policy_id, fields=None): # validate that we have access to the policy self._get_policy_obj(context, policy_id) return rule_object.QosBandwidthLimitRule.get_by_id(context, - rule_id).to_dict() + rule_id) @db_base_plugin_common.filter_fields + @db_base_plugin_common.convert_result_to_dict def get_policy_bandwidth_limit_rules(self, context, policy_id, filters=None, fields=None, sorts=None, limit=None, @@ -133,12 +140,11 @@ class QoSPlugin(qos.QoSPluginBase): #TODO(QoS): Support all the optional parameters # validate that we have access to the policy self._get_policy_obj(context, policy_id) - return [rule_obj.to_dict() for rule_obj in - rule_object.QosBandwidthLimitRule.get_objects(context)] + return rule_object.QosBandwidthLimitRule.get_objects(context) @db_base_plugin_common.filter_fields + @db_base_plugin_common.convert_result_to_dict def get_rule_types(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False): - return [rule_type_obj.to_dict() for rule_type_obj in - rule_type_object.QosRuleType.get_objects()] + return rule_type_object.QosRuleType.get_objects() diff --git a/neutron/tests/unit/db/test_db_base_plugin_common.py b/neutron/tests/unit/db/test_db_base_plugin_common.py index 9074bf6183c..21866522ad7 100644 --- a/neutron/tests/unit/db/test_db_base_plugin_common.py +++ b/neutron/tests/unit/db/test_db_base_plugin_common.py @@ -17,6 +17,35 @@ from neutron.db import db_base_plugin_common from neutron.tests import base +class DummyObject(object): + def __init__(self, **kwargs): + self.kwargs = kwargs + + def to_dict(self): + return self.kwargs + + +class ConvertToDictTestCase(base.BaseTestCase): + + @db_base_plugin_common.convert_result_to_dict + def method_dict(self, fields=None): + return DummyObject(one=1, two=2, three=3) + + @db_base_plugin_common.convert_result_to_dict + def method_list(self): + return [DummyObject(one=1, two=2, three=3)] * 3 + + def test_simple_object(self): + expected = {'one': 1, 'two': 2, 'three': 3} + observed = self.method_dict() + self.assertEqual(expected, observed) + + def test_list_of_objects(self): + expected = [{'one': 1, 'two': 2, 'three': 3}] * 3 + observed = self.method_list() + self.assertEqual(expected, observed) + + class FilterFieldsTestCase(base.BaseTestCase): @db_base_plugin_common.filter_fields