diff --git a/quantum/db/db_base_plugin_v2.py b/quantum/db/db_base_plugin_v2.py index cce4e07aac..fe779f26b7 100644 --- a/quantum/db/db_base_plugin_v2.py +++ b/quantum/db/db_base_plugin_v2.py @@ -206,11 +206,18 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): query = query.filter(column.in_(value)) return query - def _get_collection(self, context, model, dict_func, filters=None, - fields=None): + def _get_collection_query(self, context, model, filters=None): collection = self._model_query(context, model) collection = self._apply_filters_to_query(collection, model, filters) - return [dict_func(c, fields) for c in collection.all()] + return collection + + def _get_collection(self, context, model, dict_func, filters=None, + fields=None): + query = self._get_collection_query(context, model, filters) + return [dict_func(c, fields) for c in query.all()] + + def _get_collection_count(self, context, model, filters=None): + return self._get_collection_query(context, model, filters).count() @staticmethod def _generate_mac(context, network_id): @@ -952,6 +959,10 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): self._make_network_dict, filters=filters, fields=fields) + def get_networks_count(self, context, filters=None): + return self._get_collection_count(context, models_v2.Network, + filters=filters) + def create_subnet_bulk(self, context, subnets): return self._create_bulk('subnet', context, subnets) @@ -1143,6 +1154,10 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): self._make_subnet_dict, filters=filters, fields=fields) + def get_subnets_count(self, context, filters=None): + return self._get_collection_count(context, models_v2.Subnet, + filters=filters) + def create_port_bulk(self, context, ports): return self._create_bulk('port', context, ports) @@ -1274,7 +1289,7 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): port = self._get_port(context, id) return self._make_port_dict(port, fields) - def get_ports(self, context, filters=None, fields=None): + def _get_ports_query(self, context, filters=None): Port = models_v2.Port IPAllocation = models_v2.IPAllocation @@ -1294,4 +1309,11 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): query = query.filter(IPAllocation.subnet_id.in_(subnet_ids)) query = self._apply_filters_to_query(query, Port, filters) + return query + + def get_ports(self, context, filters=None, fields=None): + query = self._get_ports_query(context, filters) return [self._make_port_dict(c, fields) for c in query.all()] + + def get_ports_count(self, context, filters=None): + return self._get_ports_query(context, filters).count() diff --git a/quantum/db/l3_db.py b/quantum/db/l3_db.py index e192a59e32..1886ead0b7 100644 --- a/quantum/db/l3_db.py +++ b/quantum/db/l3_db.py @@ -251,6 +251,10 @@ class L3_NAT_db_mixin(l3.RouterPluginBase): self._make_router_dict, filters=filters, fields=fields) + def get_routers_count(self, context, filters=None): + return self._get_collection_count(context, Router, + filters=filters) + def _check_for_dup_router_subnet(self, context, router_id, network_id, subnet_id): try: @@ -615,6 +619,10 @@ class L3_NAT_db_mixin(l3.RouterPluginBase): self._make_floatingip_dict, filters=filters, fields=fields) + def get_floatingips_count(self, context, filters=None): + return self._get_collection_count(context, FloatingIP, + filters=filters) + def prevent_l3_port_deletion(self, context, port_id): """ Checks to make sure a port is allowed to be deleted, raising an exception if this is not the case. This should be called by diff --git a/quantum/db/securitygroups_db.py b/quantum/db/securitygroups_db.py index 2093b334bc..b61f1dcb33 100644 --- a/quantum/db/securitygroups_db.py +++ b/quantum/db/securitygroups_db.py @@ -150,6 +150,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): self._make_security_group_dict, filters=filters, fields=fields) + def get_security_groups_count(self, context, filters=None): + return self._get_collection_count(context, SecurityGroup, + filters=filters) + def get_security_group(self, context, id, fields=None, tenant_id=None): """Tenant id is given to handle the case when we are creating a security group or security group rule on behalf of @@ -384,6 +388,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): self._make_security_group_rule_dict, filters=filters, fields=fields) + def get_security_group_rules_count(self, context, filters=None): + return self._get_collection_count(context, SecurityGroupRule, + filters=filters) + def get_security_group_rule(self, context, id, fields=None): security_group_rule = self._get_security_group_rule(context, id) return self._make_security_group_rule_dict(security_group_rule, fields) diff --git a/quantum/plugins/cisco/models/virt_phy_sw_v2.py b/quantum/plugins/cisco/models/virt_phy_sw_v2.py index 1d6622e5e0..1586c0016f 100644 --- a/quantum/plugins/cisco/models/virt_phy_sw_v2.py +++ b/quantum/plugins/cisco/models/virt_phy_sw_v2.py @@ -289,6 +289,10 @@ class VirtualPhysicalSwitchModelV2(quantum_plugin_base_v2.QuantumPluginBaseV2): """For this model this method will be delegated to vswitch plugin""" pass + def get_networks_count(self, context, filters=None): + """For this model this method will be delegated to vswitch plugin""" + pass + def create_port(self, context, port): """For this model this method will be delegated to vswitch plugin""" pass @@ -301,6 +305,10 @@ class VirtualPhysicalSwitchModelV2(quantum_plugin_base_v2.QuantumPluginBaseV2): """For this model this method will be delegated to vswitch plugin""" pass + def get_ports_count(self, context, filters=None): + """For this model this method will be delegated to vswitch plugin""" + pass + def update_port(self, context, id, port): """For this model this method will be delegated to vswitch plugin""" pass @@ -328,3 +336,7 @@ class VirtualPhysicalSwitchModelV2(quantum_plugin_base_v2.QuantumPluginBaseV2): def get_subnets(self, context, filters=None, fields=None): """For this model this method will be delegated to vswitch plugin""" pass + + def get_subnets_count(self, context, filters=None): + """For this model this method will be delegated to vswitch plugin""" + pass diff --git a/quantum/plugins/cisco/network_plugin.py b/quantum/plugins/cisco/network_plugin.py index 64ce6796b2..05edc4c537 100644 --- a/quantum/plugins/cisco/network_plugin.py +++ b/quantum/plugins/cisco/network_plugin.py @@ -92,6 +92,11 @@ class PluginV2(db_base_plugin_v2.QuantumDbPluginV2): """ if hasattr(self._model, name): return getattr(self._model, name) + else: + # Must make sure we re-raise the error that led us here, since + # otherwise getattr() and even hasattr() doesn't work corretly. + raise AttributeError("'%s' object has no attribute '%s'" % + (self._model, name)) """ Core API implementation diff --git a/quantum/quantum_plugin_base_v2.py b/quantum/quantum_plugin_base_v2.py index 2dae8911d2..61b8fa437a 100644 --- a/quantum/quantum_plugin_base_v2.py +++ b/quantum/quantum_plugin_base_v2.py @@ -70,7 +70,7 @@ class QuantumPluginBaseV2(object): def get_subnets(self, context, filters=None, fields=None): """ Retrieve a list of subnets. The contents of the list depends on - the identify of the user making the request (as indicated by the + the identity of the user making the request (as indicated by the context) as well as any filters. : param context: quantum api request context : param filters: a dictionary with keys that are valid keys for @@ -87,6 +87,23 @@ class QuantumPluginBaseV2(object): """ pass + @abstractmethod + def get_subnets_count(self, context, filters=None): + """ + Return the number of subnets. The result depends on the identity of + the user making the request (as indicated by the context) as well as + any filters. + : param context: quantum api request context + : param filters: a dictionary with keys that are valid keys for + a network as listed in the RESOURCE_ATTRIBUTE_MAP object + in quantum/api/v2/attributes.py. Values in this dictiontary + are an iterable containing values that will be used for an exact + match comparison for that value. Each result returned by this + function will have matched one of the values for each key in + filters. + """ + pass + @abstractmethod def delete_subnet(self, context, id): """ @@ -138,7 +155,7 @@ class QuantumPluginBaseV2(object): def get_networks(self, context, filters=None, fields=None): """ Retrieve a list of networks. The contents of the list depends on - the identify of the user making the request (as indicated by the + the identity of the user making the request (as indicated by the context) as well as any filters. : param context: quantum api request context : param filters: a dictionary with keys that are valid keys for @@ -155,6 +172,23 @@ class QuantumPluginBaseV2(object): """ pass + @abstractmethod + def get_networks_count(self, context, filters=None): + """ + Return the number of networks. The result depends on the identity + of the user making the request (as indicated by the context) as well + as any filters. + : param context: quantum api request context + : param filters: a dictionary with keys that are valid keys for + a network as listed in the RESOURCE_ATTRIBUTE_MAP object + in quantum/api/v2/attributes.py. Values in this dictiontary + are an iterable containing values that will be used for an exact + match comparison for that value. Each result returned by this + function will have matched one of the values for each key in + filters. + """ + pass + @abstractmethod def delete_network(self, context, id): """ @@ -206,7 +240,7 @@ class QuantumPluginBaseV2(object): def get_ports(self, context, filters=None, fields=None): """ Retrieve a list of ports. The contents of the list depends on - the identify of the user making the request (as indicated by the + the identity of the user making the request (as indicated by the context) as well as any filters. : param context: quantum api request context : param filters: a dictionary with keys that are valid keys for @@ -223,6 +257,23 @@ class QuantumPluginBaseV2(object): """ pass + @abstractmethod + def get_ports_count(self, context, filters=None): + """ + Return the number of ports. The result depends on the identity of + the user making the request (as indicated by the context) as well as + any filters. + : param context: quantum api request context + : param filters: a dictionary with keys that are valid keys for + a network as listed in the RESOURCE_ATTRIBUTE_MAP object + in quantum/api/v2/attributes.py. Values in this dictiontary + are an iterable containing values that will be used for an exact + match comparison for that value. Each result returned by this + function will have matched one of the values for each key in + filters. + """ + pass + @abstractmethod def delete_port(self, context, id): """ diff --git a/quantum/quota.py b/quantum/quota.py index bcb498fc16..f0b3f8fc65 100644 --- a/quantum/quota.py +++ b/quantum/quota.py @@ -266,9 +266,19 @@ QUOTAS = QuotaEngine() def _count_resource(context, plugin, resources, tenant_id): - obj_getter = getattr(plugin, "get_%s" % resources) - obj_list = obj_getter(context, filters={'tenant_id': [tenant_id]}) - return len(obj_list) if obj_list else 0 + count_getter_name = "get_%s_count" % resources + + # Some plugins support a count method for particular resources, + # using a DB's optimized counting features. We try to use that one + # if present. Otherwise just use regular getter to retrieve all objects + # and count in python, allowing older plugins to still be supported + if hasattr(plugin, count_getter_name): + obj_count_getter = getattr(plugin, count_getter_name) + return obj_count_getter(context, filters={'tenant_id': [tenant_id]}) + else: + obj_getter = getattr(plugin, "get_%s" % resources) + obj_list = obj_getter(context, filters={'tenant_id': [tenant_id]}) + return len(obj_list) if obj_list else 0 resources = [] diff --git a/quantum/tests/unit/test_api_v2.py b/quantum/tests/unit/test_api_v2.py index efa46e5e27..619af9f45d 100644 --- a/quantum/tests/unit/test_api_v2.py +++ b/quantum/tests/unit/test_api_v2.py @@ -369,6 +369,7 @@ class JSONV2TestCase(APIv2TestBase): instance = self.plugin.return_value instance.create_network.return_value = return_value + instance.get_networks_count.return_value = 0 res = self.api.post_json(_get_path('networks'), data) @@ -390,6 +391,7 @@ class JSONV2TestCase(APIv2TestBase): instance = self.plugin.return_value instance.create_network.return_value = return_value + instance.get_networks_count.return_value = 0 res = self.api.post_json(_get_path('networks'), initial_input) @@ -423,6 +425,7 @@ class JSONV2TestCase(APIv2TestBase): instance = self.plugin.return_value instance.create_network.return_value = return_value + instance.get_networks_count.return_value = 0 res = self.api.post_json(_get_path('networks'), initial_input, extra_environ=env) @@ -479,6 +482,7 @@ class JSONV2TestCase(APIv2TestBase): instance = self.plugin.return_value instance.create_network.side_effect = side_effect + instance.get_networks_count.return_value = 0 res = self.api.post_json(_get_path('networks'), data) self.assertEqual(res.status_int, exc.HTTPCreated.code) @@ -525,6 +529,7 @@ class JSONV2TestCase(APIv2TestBase): instance = self.plugin.return_value instance.get_network.return_value = {'tenant_id': unicode(tenant_id)} + instance.get_ports_count.return_value = 1 instance.create_port.return_value = return_value res = self.api.post_json(_get_path('ports'), initial_input) @@ -545,6 +550,7 @@ class JSONV2TestCase(APIv2TestBase): instance = self.plugin.return_value instance.create_network.return_value = return_value + instance.get_networks_count.return_value = 0 res = self.api.post_json(_get_path('networks'), data) @@ -699,6 +705,7 @@ class NotificationTest(APIv2TestBase): initial_input = {resource: {'name': 'myname'}} instance = self.plugin.return_value instance.get_networks.return_value = initial_input + instance.get_networks_count.return_value = 0 expected_code = exc.HTTPCreated.code with mock.patch.object(notifer_api, 'notify') as mynotifier: if opname == 'create': @@ -742,37 +749,24 @@ class NotificationTest(APIv2TestBase): class QuotaTest(APIv2TestBase): def test_create_network_quota(self): cfg.CONF.set_override('quota_network', 1, group='QUOTAS') - net_id = _uuid() initial_input = {'network': {'name': 'net1', 'tenant_id': _uuid()}} full_input = {'network': {'admin_state_up': True, 'subnets': []}} full_input['network'].update(initial_input['network']) - return_value = {'id': net_id, 'status': "ACTIVE"} - return_value.update(full_input['network']) - return_networks = {'networks': [return_value]} instance = self.plugin.return_value - instance.get_networks.return_value = return_networks + instance.get_networks_count.return_value = 1 res = self.api.post_json( _get_path('networks'), initial_input, expect_errors=True) - instance.get_networks.assert_called_with(mock.ANY, - filters=mock.ANY) + instance.get_networks_count.assert_called_with(mock.ANY, + filters=mock.ANY) self.assertTrue("Quota exceeded for resources" in res.json['QuantumError']) def test_create_network_quota_without_limit(self): cfg.CONF.set_override('quota_network', -1, group='QUOTAS') - net_id = _uuid() initial_input = {'network': {'name': 'net1', 'tenant_id': _uuid()}} - full_input = {'network': {'admin_state_up': True, 'subnets': []}} - full_input['network'].update(initial_input['network']) - return_networks = [] - for i in xrange(0, 3): - return_value = {'id': net_id + str(i), 'status': "ACTIVE"} - return_value.update(full_input['network']) - return_networks.append(return_value) - self.assertEquals(3, len(return_networks)) instance = self.plugin.return_value - instance.get_networks.return_value = return_networks + instance.get_networks_count.return_value = 3 res = self.api.post_json( _get_path('networks'), initial_input) self.assertEqual(res.status_int, exc.HTTPCreated.code) @@ -836,6 +830,7 @@ class ExtensionTestCase(unittest.TestCase): instance = self.plugin.return_value instance.create_network.return_value = return_value + instance.get_networks_count.return_value = 0 res = self.api.post_json(_get_path('networks'), initial_input) diff --git a/quantum/tests/unit/test_db_plugin.py b/quantum/tests/unit/test_db_plugin.py index 281fee836b..4c73a2760a 100644 --- a/quantum/tests/unit/test_db_plugin.py +++ b/quantum/tests/unit/test_db_plugin.py @@ -376,6 +376,7 @@ class QuantumDbPluginV2TestCase(unittest2.TestCase): def network(self, name='net1', admin_status_up=True, fmt='json', + do_delete=True, **kwargs): res = self._create_network(fmt, name, @@ -389,7 +390,13 @@ class QuantumDbPluginV2TestCase(unittest2.TestCase): if res.status_int >= 400: raise webob.exc.HTTPClientError(code=res.status_int) yield network - self._delete('networks', network['network']['id']) + + if do_delete: + # The do_delete parameter allows you to control whether the + # created network is immediately deleted again. Therefore, this + # function is also usable in tests, which require the creation + # of many networks. + self._delete('networks', network['network']['id']) @contextlib.contextmanager def subnet(self, network=None,