Merge "[OVS FW] Clean conntrack entries with mark == CT_MARK_INVALID" into stable/stein

This commit is contained in:
Zuul 2021-03-12 14:28:04 +00:00 committed by Gerrit Code Review
commit eee271e45a
4 changed files with 79 additions and 1 deletions

View File

@ -116,10 +116,13 @@ class IpConntrackManager(object):
ethertype = rule.get('ethertype') ethertype = rule.get('ethertype')
protocol = rule.get('protocol') protocol = rule.get('protocol')
direction = rule.get('direction') direction = rule.get('direction')
mark = rule.get('mark')
cmd = ['conntrack', '-D'] cmd = ['conntrack', '-D']
if protocol: if protocol:
cmd.extend(['-p', str(protocol)]) cmd.extend(['-p', str(protocol)])
cmd.extend(['-f', str(ethertype).lower()]) cmd.extend(['-f', str(ethertype).lower()])
if mark is not None:
cmd.extend(['-m', str(mark)])
cmd.append('-d' if direction == 'ingress' else '-s') cmd.append('-d' if direction == 'ingress' else '-s')
cmd_ns = [] cmd_ns = []
if namespace: if namespace:
@ -170,10 +173,12 @@ class IpConntrackManager(object):
self._process(device_info_list, rule) self._process(device_info_list, rule)
def delete_conntrack_state_by_remote_ips(self, device_info_list, def delete_conntrack_state_by_remote_ips(self, device_info_list,
ethertype, remote_ips): ethertype, remote_ips, mark=None):
for direction in ['ingress', 'egress']: for direction in ['ingress', 'egress']:
rule = {'ethertype': str(ethertype).lower(), rule = {'ethertype': str(ethertype).lower(),
'direction': direction} 'direction': direction}
if mark:
rule['mark'] = mark
self._process(device_info_list, rule, remote_ips) self._process(device_info_list, rule, remote_ips)
def _populate_initial_zone_map(self): def _populate_initial_zone_map(self):
@ -251,3 +256,21 @@ class IpConntrackManager(object):
return index + ZONE_START return index + ZONE_START
# conntrack zones exhausted :( :( # conntrack zones exhausted :( :(
raise exceptions.CTZoneExhaustedError() raise exceptions.CTZoneExhaustedError()
class OvsIpConntrackManager(IpConntrackManager):
def __init__(self, execute=None):
super(OvsIpConntrackManager, self).__init__(
get_rules_for_table_func=None,
filtered_ports={}, unfiltered_ports={},
execute=execute, namespace=None, zone_per_port=False)
def _populate_initial_zone_map(self):
self._device_zone_map = {}
def get_device_zone(self, port, create=False):
of_port = port.get('of_port')
if of_port is None:
return
return of_port.vlan_tag

View File

@ -32,6 +32,7 @@ from oslo_utils import netutils
from neutron._i18n import _ from neutron._i18n import _
from neutron.agent.common import ovs_lib from neutron.agent.common import ovs_lib
from neutron.agent import firewall from neutron.agent import firewall
from neutron.agent.linux import ip_conntrack
from neutron.agent.linux.openvswitch_firewall import constants as ovsfw_consts from neutron.agent.linux.openvswitch_firewall import constants as ovsfw_consts
from neutron.agent.linux.openvswitch_firewall import exceptions from neutron.agent.linux.openvswitch_firewall import exceptions
from neutron.agent.linux.openvswitch_firewall import iptables from neutron.agent.linux.openvswitch_firewall import iptables
@ -482,6 +483,7 @@ class OVSFirewallDriver(firewall.FirewallDriver):
self._deferred = False self._deferred = False
self.iptables_helper = iptables.Helper(self.int_br.br) self.iptables_helper = iptables.Helper(self.int_br.br)
self.iptables_helper.load_driver_if_needed() self.iptables_helper.load_driver_if_needed()
self.ipconntrack = ip_conntrack.OvsIpConntrackManager()
self._initialize_firewall() self._initialize_firewall()
callbacks_registry.subscribe( callbacks_registry.subscribe(
@ -612,6 +614,12 @@ class OVSFirewallDriver(firewall.FirewallDriver):
return get_physical_network_from_other_config( return get_physical_network_from_other_config(
self.int_br.br, port_name) self.int_br.br, port_name)
def _delete_invalid_conntrack_entries_for_port(self, port, of_port):
port['of_port'] = of_port
for ethertype in [lib_const.IPv4, lib_const.IPv6]:
self.ipconntrack.delete_conntrack_state_by_remote_ips(
[port], ethertype, set(), mark=ovsfw_consts.CT_MARK_INVALID)
def get_ofport(self, port): def get_ofport(self, port):
port_id = port['device'] port_id = port['device']
return self.sg_port_map.ports.get(port_id) return self.sg_port_map.ports.get(port_id)
@ -666,6 +674,7 @@ class OVSFirewallDriver(firewall.FirewallDriver):
self._update_flows_for_port(of_port, old_of_port) self._update_flows_for_port(of_port, old_of_port)
else: else:
self._set_port_filters(of_port) self._set_port_filters(of_port)
self._delete_invalid_conntrack_entries_for_port(port, of_port)
except exceptions.OVSFWPortNotFound as not_found_error: except exceptions.OVSFWPortNotFound as not_found_error:
LOG.info("port %(port_id)s does not exist in ovsdb: %(err)s.", LOG.info("port %(port_id)s does not exist in ovsdb: %(err)s.",
{'port_id': port['device'], {'port_id': port['device'],
@ -705,6 +714,8 @@ class OVSFirewallDriver(firewall.FirewallDriver):
else: else:
self._set_port_filters(of_port) self._set_port_filters(of_port)
self._delete_invalid_conntrack_entries_for_port(port, of_port)
except exceptions.OVSFWPortNotFound as not_found_error: except exceptions.OVSFWPortNotFound as not_found_error:
LOG.info("port %(port_id)s does not exist in ovsdb: %(err)s.", LOG.info("port %(port_id)s does not exist in ovsdb: %(err)s.",
{'port_id': port['device'], {'port_id': port['device'],

View File

@ -436,6 +436,9 @@ class TestOVSFirewallDriver(base.BaseTestCase):
mock_bridge = mock.patch.object( mock_bridge = mock.patch.object(
ovs_lib, 'OVSBridge', autospec=True).start() ovs_lib, 'OVSBridge', autospec=True).start()
self.firewall = ovsfw.OVSFirewallDriver(mock_bridge) self.firewall = ovsfw.OVSFirewallDriver(mock_bridge)
self.delete_invalid_conntrack_entries_mock = mock.patch.object(
self.firewall.ipconntrack,
"delete_conntrack_state_by_remote_ips").start()
self.mock_bridge = self.firewall.int_br self.mock_bridge = self.firewall.int_br
self.mock_bridge.reset_mock() self.mock_bridge.reset_mock()
self.fake_ovs_port = FakeOVSPort('port', 1, '00:00:00:00:00:00') self.fake_ovs_port = FakeOVSPort('port', 1, '00:00:00:00:00:00')
@ -460,6 +463,16 @@ class TestOVSFirewallDriver(base.BaseTestCase):
'direction': constants.EGRESS_DIRECTION}] 'direction': constants.EGRESS_DIRECTION}]
self.firewall.update_security_group_rules(2, security_group_rules) self.firewall.update_security_group_rules(2, security_group_rules)
def _assert_invalid_conntrack_entries_deleted(self, port_dict):
port_dict['of_port'] = mock.Mock(vlan_tag=10)
self.delete_invalid_conntrack_entries_mock.assert_has_calls([
mock.call(
[port_dict], constants.IPv4, set(),
mark=ovsfw_consts.CT_MARK_INVALID),
mock.call(
[port_dict], constants.IPv6, set(),
mark=ovsfw_consts.CT_MARK_INVALID)])
@property @property
def port_ofport(self): def port_ofport(self):
return self.mock_bridge.br.get_vif_port_by_id.return_value.ofport return self.mock_bridge.br.get_vif_port_by_id.return_value.ofport
@ -617,6 +630,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
calls = self.mock_bridge.br.add_flow.call_args_list calls = self.mock_bridge.br.add_flow.call_args_list
for call in exp_ingress_classifier, exp_egress_classifier, filter_rule: for call in exp_ingress_classifier, exp_egress_classifier, filter_rule:
self.assertIn(call, calls) self.assertIn(call, calls)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_prepare_port_filter_port_security_disabled(self): def test_prepare_port_filter_port_security_disabled(self):
port_dict = {'device': 'port-id', port_dict = {'device': 'port-id',
@ -627,6 +641,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.firewall, 'initialize_port_flows') as m_init_flows: self.firewall, 'initialize_port_flows') as m_init_flows:
self.firewall.prepare_port_filter(port_dict) self.firewall.prepare_port_filter(port_dict)
self.assertFalse(m_init_flows.called) self.assertFalse(m_init_flows.called)
self.delete_invalid_conntrack_entries_mock.assert_not_called()
def test_initialize_port_flows_vlan_dvr_conntrack_direct(self): def test_initialize_port_flows_vlan_dvr_conntrack_direct(self):
port_dict = { port_dict = {
@ -765,6 +780,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.assertFalse(self.mock_bridge.br.delete_flows.called) self.assertFalse(self.mock_bridge.br.delete_flows.called)
self.firewall.prepare_port_filter(port_dict) self.firewall.prepare_port_filter(port_dict)
self.assertTrue(self.mock_bridge.br.delete_flows.called) self.assertTrue(self.mock_bridge.br.delete_flows.called)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_update_port_filter(self): def test_update_port_filter(self):
port_dict = {'device': 'port-id', port_dict = {'device': 'port-id',
@ -796,6 +812,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
table=ovs_consts.RULES_EGRESS_TABLE)] table=ovs_consts.RULES_EGRESS_TABLE)]
self.mock_bridge.br.add_flow.assert_has_calls( self.mock_bridge.br.add_flow.assert_has_calls(
filter_rules, any_order=True) filter_rules, any_order=True)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_update_port_filter_create_new_port_if_not_present(self): def test_update_port_filter_create_new_port_if_not_present(self):
port_dict = {'device': 'port-id', port_dict = {'device': 'port-id',
@ -815,15 +832,18 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.assertFalse(self.mock_bridge.br.delete_flows.called) self.assertFalse(self.mock_bridge.br.delete_flows.called)
self.assertTrue(initialize_port_flows_mock.called) self.assertTrue(initialize_port_flows_mock.called)
self.assertTrue(add_flows_from_rules_mock.called) self.assertTrue(add_flows_from_rules_mock.called)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_update_port_filter_port_security_disabled(self): def test_update_port_filter_port_security_disabled(self):
port_dict = {'device': 'port-id', port_dict = {'device': 'port-id',
'security_groups': [1]} 'security_groups': [1]}
self._prepare_security_group() self._prepare_security_group()
self.firewall.prepare_port_filter(port_dict) self.firewall.prepare_port_filter(port_dict)
self.delete_invalid_conntrack_entries_mock.reset_mock()
port_dict['port_security_enabled'] = False port_dict['port_security_enabled'] = False
self.firewall.update_port_filter(port_dict) self.firewall.update_port_filter(port_dict)
self.assertTrue(self.mock_bridge.br.delete_flows.called) self.assertTrue(self.mock_bridge.br.delete_flows.called)
self.delete_invalid_conntrack_entries_mock.assert_not_called()
def test_update_port_filter_applies_added_flows(self): def test_update_port_filter_applies_added_flows(self):
"""Check flows are applied right after _set_flows is called.""" """Check flows are applied right after _set_flows is called."""
@ -844,6 +864,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.mock_bridge.br.get_vif_port_by_id.return_value = None self.mock_bridge.br.get_vif_port_by_id.return_value = None
self.firewall.update_port_filter(port_dict) self.firewall.update_port_filter(port_dict)
self.assertTrue(self.mock_bridge.br.delete_flows.called) self.assertTrue(self.mock_bridge.br.delete_flows.called)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_remove_port_filter(self): def test_remove_port_filter(self):
port_dict = {'device': 'port-id', port_dict = {'device': 'port-id',

View File

@ -39,3 +39,26 @@ class IPConntrackTestCase(base.BaseTestCase):
dev_info_list = [dev_info for _ in range(10)] dev_info_list = [dev_info for _ in range(10)]
self.mgr._delete_conntrack_state(dev_info_list, rule) self.mgr._delete_conntrack_state(dev_info_list, rule)
self.assertEqual(1, len(self.execute.mock_calls)) self.assertEqual(1, len(self.execute.mock_calls))
class OvsIPConntrackTestCase(IPConntrackTestCase):
def setUp(self):
super(IPConntrackTestCase, self).setUp()
self.execute = mock.Mock()
self.mgr = ip_conntrack.OvsIpConntrackManager(self.execute)
def test_delete_conntrack_state_dedupes(self):
rule = {'ethertype': 'IPv4', 'direction': 'ingress'}
dev_info = {
'device': 'tapdevice',
'fixed_ips': ['1.2.3.4'],
'of_port': mock.Mock(of_port=10)}
dev_info_list = [dev_info for _ in range(10)]
self.mgr._delete_conntrack_state(dev_info_list, rule)
self.assertEqual(1, len(self.execute.mock_calls))
def test_get_device_zone(self):
of_port = mock.Mock(vlan_tag=10)
port = {'id': 'port-id', 'of_port': of_port}
self.assertEqual(10, self.mgr.get_device_zone(port))