From 1b5d0867ab320ea24c2f3954fb8927d0b5df2b17 Mon Sep 17 00:00:00 2001 From: Michal Kelner Mishali Date: Sun, 18 Mar 2018 11:11:37 +0200 Subject: [PATCH] Filter port-list based on security-group This patch will allow users to filter ports according to security_group supplied as a filter. Code is for V and V3. Change-Id: I20b4655cb188aae9d031fee20aea917268ebdf48 Signed-off-by: Michal Kelner Mishali --- vmware_nsx/plugins/common/plugin.py | 15 +++++++++ vmware_nsx/plugins/nsx_v/plugin.py | 16 +++++++++- vmware_nsx/plugins/nsx_v3/plugin.py | 4 ++- vmware_nsx/tests/unit/nsx_v/test_plugin.py | 34 +++++++++++++++++++++ vmware_nsx/tests/unit/nsx_v3/test_plugin.py | 34 +++++++++++++++++++++ 5 files changed, 101 insertions(+), 2 deletions(-) diff --git a/vmware_nsx/plugins/common/plugin.py b/vmware_nsx/plugins/common/plugin.py index 3eed3da5cc..a39cfe66dc 100644 --- a/vmware_nsx/plugins/common/plugin.py +++ b/vmware_nsx/plugins/common/plugin.py @@ -178,6 +178,21 @@ class NsxPluginBase(db_base_plugin_v2.NeutronDbPluginV2, device_id=device_id, device_owner=device_owner,).all() + def _update_filters_with_sec_group(self, context, filters=None): + if filters is not None: + security_groups = filters.pop("security_groups", None) + if security_groups: + bindings = ( + super(NsxPluginBase, self) + ._get_port_security_group_bindings(context, + filters={'security_group_id': security_groups})) + if 'id' in filters: + filters['id'] = [entry['port_id'] for + entry in bindings + if entry['port_id'] in filters['id']] + else: + filters['id'] = [entry['port_id'] for entry in bindings] + def _find_router_subnets(self, context, router_id): """Retrieve subnets attached to the specified router.""" ports = self._get_port_by_device_id(context, router_id, diff --git a/vmware_nsx/plugins/nsx_v/plugin.py b/vmware_nsx/plugins/nsx_v/plugin.py index 266eca65cb..3dbe025ac5 100644 --- a/vmware_nsx/plugins/nsx_v/plugin.py +++ b/vmware_nsx/plugins/nsx_v/plugin.py @@ -208,7 +208,8 @@ class NsxVPluginV2(addr_pair_db.AllowedAddressPairsMixin, "flavors", "dhcp-mtu", "mac-learning", - "housekeeper"] + "housekeeper", + "port-security-groups-filtering"] __native_bulk_support = True __native_pagination_support = True @@ -2449,6 +2450,19 @@ class NsxVPluginV2(addr_pair_db.AllowedAddressPairsMixin, self._extend_get_port_dict_qos(context, port) return db_utils.resource_fields(port, fields) + def get_ports(self, context, filters=None, fields=None, + sorts=None, limit=None, marker=None, + page_reverse=False): + filters = filters or {} + self._update_filters_with_sec_group(context, filters) + with db_api.context_manager.reader.using(context): + ports = ( + super(NsxVPluginV2, self).get_ports( + context, filters, fields, sorts, + limit, marker, page_reverse)) + return (ports if not fields else + [db_utils.resource_fields(port, fields) for port in ports]) + def delete_port(self, context, id, l3_port_check=True, nw_gw_port_check=True, force_delete_dhcp=False, allow_delete_internal=False): diff --git a/vmware_nsx/plugins/nsx_v3/plugin.py b/vmware_nsx/plugins/nsx_v3/plugin.py index dd717275d2..c16ff1fad3 100644 --- a/vmware_nsx/plugins/nsx_v3/plugin.py +++ b/vmware_nsx/plugins/nsx_v3/plugin.py @@ -207,7 +207,8 @@ class NsxV3Plugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin, "router_availability_zone", "subnet_allocation", "security-group-logging", - "provider-security-group"] + "provider-security-group", + "port-security-groups-filtering"] @resource_registry.tracked_resources( network=models_v2.Network, @@ -3309,6 +3310,7 @@ class NsxV3Plugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin, sorts=None, limit=None, marker=None, page_reverse=False): filters = filters or {} + self._update_filters_with_sec_group(context, filters) with db_api.context_manager.reader.using(context): ports = ( super(NsxV3Plugin, self).get_ports( diff --git a/vmware_nsx/tests/unit/nsx_v/test_plugin.py b/vmware_nsx/tests/unit/nsx_v/test_plugin.py index fc85c9fe86..d17dc4505b 100644 --- a/vmware_nsx/tests/unit/nsx_v/test_plugin.py +++ b/vmware_nsx/tests/unit/nsx_v/test_plugin.py @@ -19,6 +19,7 @@ import copy from eventlet import greenthread import mock import netaddr +from neutron.db import securitygroups_db as sg_db from neutron.extensions import address_scope from neutron.extensions import l3 from neutron.extensions import securitygroup as secgrp @@ -1069,6 +1070,39 @@ class TestPortsV2(NsxVPluginV2TestCase, [('admin_state_up', 'asc'), ('mac_address', 'desc')]) + def test_list_ports_filtered_by_security_groups(self): + ctx = context.get_admin_context() + with self.port() as port1, self.port() as port2: + query_params = "security_groups=%s" % ( + port1['port']['security_groups'][0]) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(set([port1['port']['id'], port2['port']['id']]), + set([port['id'] for port in ports_data['ports']])) + query_params = "security_groups=%s&id=%s" % ( + port1['port']['security_groups'][0], + port1['port']['id']) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id']) + self.assertEqual(1, len(ports_data['ports'])) + temp_sg = {'security_group': {'tenant_id': 'some_tenant', + 'name': '', 'description': 's'}} + sg_dbMixin = sg_db.SecurityGroupDbMixin() + sg = sg_dbMixin.create_security_group(ctx, temp_sg) + sg_dbMixin._delete_port_security_group_bindings( + ctx, port2['port']['id']) + sg_dbMixin._create_port_security_group_binding( + ctx, port2['port']['id'], sg['id']) + port2['port']['security_groups'][0] = sg['id'] + query_params = "security_groups=%s" % ( + port1['port']['security_groups'][0]) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id']) + self.assertEqual(1, len(ports_data['ports'])) + query_params = "security_groups=%s" % ( + (port2['port']['security_groups'][0])) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(port2['port']['id'], ports_data['ports'][0]['id']) + def test_update_port_delete_ip(self): # This test case overrides the default because the nsx plugin # implements port_security/security groups and it is not allowed diff --git a/vmware_nsx/tests/unit/nsx_v3/test_plugin.py b/vmware_nsx/tests/unit/nsx_v3/test_plugin.py index e534229e7b..474f7b0712 100644 --- a/vmware_nsx/tests/unit/nsx_v3/test_plugin.py +++ b/vmware_nsx/tests/unit/nsx_v3/test_plugin.py @@ -16,6 +16,7 @@ import mock import netaddr from neutron.db import models_v2 +from neutron.db import securitygroups_db as sg_db from neutron.extensions import address_scope from neutron.extensions import l3 from neutron.extensions import securitygroup as secgrp @@ -888,6 +889,39 @@ class TestPortsV2(test_plugin.TestPortsV2, NsxV3PluginTestCaseMixin, self._get_ports_with_fields(tenid, 'mac_address', 4) self._get_ports_with_fields(tenid, 'network_id', 4) + def test_list_ports_filtered_by_security_groups(self): + ctx = context.get_admin_context() + with self.port() as port1, self.port() as port2: + query_params = "security_groups=%s" % ( + port1['port']['security_groups'][0]) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(set([port1['port']['id'], port2['port']['id']]), + set([port['id'] for port in ports_data['ports']])) + query_params = "security_groups=%s&id=%s" % ( + port1['port']['security_groups'][0], + port1['port']['id']) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id']) + self.assertEqual(1, len(ports_data['ports'])) + temp_sg = {'security_group': {'tenant_id': 'some_tenant', + 'name': '', 'description': 's'}} + sg_dbMixin = sg_db.SecurityGroupDbMixin() + sg = sg_dbMixin.create_security_group(ctx, temp_sg) + sg_dbMixin._delete_port_security_group_bindings( + ctx, port2['port']['id']) + sg_dbMixin._create_port_security_group_binding( + ctx, port2['port']['id'], sg['id']) + port2['port']['security_groups'][0] = sg['id'] + query_params = "security_groups=%s" % ( + port1['port']['security_groups'][0]) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id']) + self.assertEqual(1, len(ports_data['ports'])) + query_params = "security_groups=%s" % ( + (port2['port']['security_groups'][0])) + ports_data = self._list('ports', query_params=query_params) + self.assertEqual(port2['port']['id'], ports_data['ports'][0]['id']) + def test_port_failure_rollback_dhcp_exception(self): cfg.CONF.set_override('native_dhcp_metadata', True, 'nsx_v3') self.plugin = directory.get_plugin()