Merge "[OVS FW] Clean conntrack entries with mark == CT_MARK_INVALID" into stable/rocky

This commit is contained in:
Zuul 2021-03-10 18:20:03 +00:00 committed by Gerrit Code Review
commit 96c75fffb1
4 changed files with 79 additions and 1 deletions

View File

@ -116,10 +116,13 @@ class IpConntrackManager(object):
ethertype = rule.get('ethertype')
protocol = rule.get('protocol')
direction = rule.get('direction')
mark = rule.get('mark')
cmd = ['conntrack', '-D']
if protocol:
cmd.extend(['-p', str(protocol)])
cmd.extend(['-f', str(ethertype).lower()])
if mark is not None:
cmd.extend(['-m', str(mark)])
cmd.append('-d' if direction == 'ingress' else '-s')
cmd_ns = []
if namespace:
@ -170,10 +173,12 @@ class IpConntrackManager(object):
self._process(device_info_list, rule)
def delete_conntrack_state_by_remote_ips(self, device_info_list,
ethertype, remote_ips):
ethertype, remote_ips, mark=None):
for direction in ['ingress', 'egress']:
rule = {'ethertype': str(ethertype).lower(),
'direction': direction}
if mark:
rule['mark'] = mark
self._process(device_info_list, rule, remote_ips)
def _populate_initial_zone_map(self):
@ -251,3 +256,21 @@ class IpConntrackManager(object):
return index + ZONE_START
# conntrack zones exhausted :( :(
raise n_exc.CTZoneExhaustedError()
class OvsIpConntrackManager(IpConntrackManager):
def __init__(self, execute=None):
super(OvsIpConntrackManager, self).__init__(
get_rules_for_table_func=None,
filtered_ports={}, unfiltered_ports={},
execute=execute, namespace=None, zone_per_port=False)
def _populate_initial_zone_map(self):
self._device_zone_map = {}
def get_device_zone(self, port, create=False):
of_port = port.get('of_port')
if of_port is None:
return
return of_port.vlan_tag

View File

@ -31,6 +31,7 @@ from oslo_utils import netutils
from neutron.agent.common import ovs_lib
from neutron.agent import firewall
from neutron.agent.linux import ip_conntrack
from neutron.agent.linux.openvswitch_firewall import constants as ovsfw_consts
from neutron.agent.linux.openvswitch_firewall import exceptions
from neutron.agent.linux.openvswitch_firewall import iptables
@ -480,6 +481,7 @@ class OVSFirewallDriver(firewall.FirewallDriver):
self._deferred = False
self.iptables_helper = iptables.Helper(self.int_br.br)
self.iptables_helper.load_driver_if_needed()
self.ipconntrack = ip_conntrack.OvsIpConntrackManager()
self._initialize_firewall()
callbacks_registry.subscribe(
@ -610,6 +612,12 @@ class OVSFirewallDriver(firewall.FirewallDriver):
return get_physical_network_from_other_config(
self.int_br.br, port_name)
def _delete_invalid_conntrack_entries_for_port(self, port, of_port):
port['of_port'] = of_port
for ethertype in [lib_const.IPv4, lib_const.IPv6]:
self.ipconntrack.delete_conntrack_state_by_remote_ips(
[port], ethertype, set(), mark=ovsfw_consts.CT_MARK_INVALID)
def get_ofport(self, port):
port_id = port['device']
return self.sg_port_map.ports.get(port_id)
@ -664,6 +672,7 @@ class OVSFirewallDriver(firewall.FirewallDriver):
self._update_flows_for_port(of_port, old_of_port)
else:
self._set_port_filters(of_port)
self._delete_invalid_conntrack_entries_for_port(port, of_port)
except exceptions.OVSFWPortNotFound as not_found_error:
LOG.info("port %(port_id)s does not exist in ovsdb: %(err)s.",
{'port_id': port['device'],
@ -703,6 +712,8 @@ class OVSFirewallDriver(firewall.FirewallDriver):
else:
self._set_port_filters(of_port)
self._delete_invalid_conntrack_entries_for_port(port, of_port)
except exceptions.OVSFWPortNotFound as not_found_error:
LOG.info("port %(port_id)s does not exist in ovsdb: %(err)s.",
{'port_id': port['device'],

View File

@ -443,6 +443,9 @@ class TestOVSFirewallDriver(base.BaseTestCase):
mock_bridge = mock.patch.object(
ovs_lib, 'OVSBridge', autospec=True).start()
self.firewall = ovsfw.OVSFirewallDriver(mock_bridge)
self.delete_invalid_conntrack_entries_mock = mock.patch.object(
self.firewall.ipconntrack,
"delete_conntrack_state_by_remote_ips").start()
self.mock_bridge = self.firewall.int_br
self.mock_bridge.reset_mock()
self.fake_ovs_port = FakeOVSPort('port', 1, '00:00:00:00:00:00')
@ -467,6 +470,16 @@ class TestOVSFirewallDriver(base.BaseTestCase):
'direction': constants.EGRESS_DIRECTION}]
self.firewall.update_security_group_rules(2, security_group_rules)
def _assert_invalid_conntrack_entries_deleted(self, port_dict):
port_dict['of_port'] = mock.Mock(vlan_tag=10)
self.delete_invalid_conntrack_entries_mock.assert_has_calls([
mock.call(
[port_dict], constants.IPv4, set(),
mark=ovsfw_consts.CT_MARK_INVALID),
mock.call(
[port_dict], constants.IPv6, set(),
mark=ovsfw_consts.CT_MARK_INVALID)])
@property
def port_ofport(self):
return self.mock_bridge.br.get_vif_port_by_id.return_value.ofport
@ -624,6 +637,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
calls = self.mock_bridge.br.add_flow.call_args_list
for call in exp_ingress_classifier, exp_egress_classifier, filter_rule:
self.assertIn(call, calls)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_prepare_port_filter_port_security_disabled(self):
port_dict = {'device': 'port-id',
@ -634,6 +648,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.firewall, 'initialize_port_flows') as m_init_flows:
self.firewall.prepare_port_filter(port_dict)
self.assertFalse(m_init_flows.called)
self.delete_invalid_conntrack_entries_mock.assert_not_called()
def test_initialize_port_flows_vlan_dvr_conntrack_direct(self):
port_dict = {
@ -772,6 +787,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.assertFalse(self.mock_bridge.br.delete_flows.called)
self.firewall.prepare_port_filter(port_dict)
self.assertTrue(self.mock_bridge.br.delete_flows.called)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_update_port_filter(self):
port_dict = {'device': 'port-id',
@ -803,6 +819,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
table=ovs_consts.RULES_EGRESS_TABLE)]
self.mock_bridge.br.add_flow.assert_has_calls(
filter_rules, any_order=True)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_update_port_filter_create_new_port_if_not_present(self):
port_dict = {'device': 'port-id',
@ -822,15 +839,18 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.assertFalse(self.mock_bridge.br.delete_flows.called)
self.assertTrue(initialize_port_flows_mock.called)
self.assertTrue(add_flows_from_rules_mock.called)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_update_port_filter_port_security_disabled(self):
port_dict = {'device': 'port-id',
'security_groups': [1]}
self._prepare_security_group()
self.firewall.prepare_port_filter(port_dict)
self.delete_invalid_conntrack_entries_mock.reset_mock()
port_dict['port_security_enabled'] = False
self.firewall.update_port_filter(port_dict)
self.assertTrue(self.mock_bridge.br.delete_flows.called)
self.delete_invalid_conntrack_entries_mock.assert_not_called()
def test_update_port_filter_applies_added_flows(self):
"""Check flows are applied right after _set_flows is called."""
@ -851,6 +871,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self.mock_bridge.br.get_vif_port_by_id.return_value = None
self.firewall.update_port_filter(port_dict)
self.assertTrue(self.mock_bridge.br.delete_flows.called)
self._assert_invalid_conntrack_entries_deleted(port_dict)
def test_remove_port_filter(self):
port_dict = {'device': 'port-id',

View File

@ -39,3 +39,26 @@ class IPConntrackTestCase(base.BaseTestCase):
dev_info_list = [dev_info for _ in range(10)]
self.mgr._delete_conntrack_state(dev_info_list, rule)
self.assertEqual(1, len(self.execute.mock_calls))
class OvsIPConntrackTestCase(IPConntrackTestCase):
def setUp(self):
super(IPConntrackTestCase, self).setUp()
self.execute = mock.Mock()
self.mgr = ip_conntrack.OvsIpConntrackManager(self.execute)
def test_delete_conntrack_state_dedupes(self):
rule = {'ethertype': 'IPv4', 'direction': 'ingress'}
dev_info = {
'device': 'tapdevice',
'fixed_ips': ['1.2.3.4'],
'of_port': mock.Mock(of_port=10)}
dev_info_list = [dev_info for _ in range(10)]
self.mgr._delete_conntrack_state(dev_info_list, rule)
self.assertEqual(1, len(self.execute.mock_calls))
def test_get_device_zone(self):
of_port = mock.Mock(vlan_tag=10)
port = {'id': 'port-id', 'of_port': of_port}
self.assertEqual(10, self.mgr.get_device_zone(port))