@ -438,6 +438,9 @@ class TestOVSFirewallDriver(base.BaseTestCase):
ovs_lib , ' OVSBridge ' , autospec = True ) . start ( )
securitygroups_rpc . register_securitygroups_opts ( )
self . firewall = ovsfw . OVSFirewallDriver ( mock_bridge )
self . delete_invalid_conntrack_entries_mock = mock . patch . object (
self . firewall . ipconntrack ,
" delete_conntrack_state_by_remote_ips " ) . start ( )
self . mock_bridge = self . firewall . int_br
self . mock_bridge . reset_mock ( )
self . fake_ovs_port = FakeOVSPort ( ' port ' , 1 , ' 00:00:00:00:00:00 ' )
@ -462,6 +465,16 @@ class TestOVSFirewallDriver(base.BaseTestCase):
' direction ' : constants . EGRESS_DIRECTION } ]
self . firewall . update_security_group_rules ( 2 , security_group_rules )
def _assert_invalid_conntrack_entries_deleted ( self , port_dict ) :
port_dict [ ' of_port ' ] = mock . Mock ( vlan_tag = 10 )
self . delete_invalid_conntrack_entries_mock . assert_has_calls ( [
mock . call (
[ port_dict ] , constants . IPv4 , set ( ) ,
mark = ovsfw_consts . CT_MARK_INVALID ) ,
mock . call (
[ port_dict ] , constants . IPv6 , set ( ) ,
mark = ovsfw_consts . CT_MARK_INVALID ) ] )
@property
def port_ofport ( self ) :
return self . mock_bridge . br . get_vif_port_by_id . return_value . ofport
@ -619,6 +632,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
calls = self . mock_bridge . br . add_flow . call_args_list
for call in exp_ingress_classifier , exp_egress_classifier , filter_rule :
self . assertIn ( call , calls )
self . _assert_invalid_conntrack_entries_deleted ( port_dict )
def test_prepare_port_filter_port_security_disabled ( self ) :
port_dict = { ' device ' : ' port-id ' ,
@ -629,6 +643,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self . firewall , ' initialize_port_flows ' ) as m_init_flows :
self . firewall . prepare_port_filter ( port_dict )
self . assertFalse ( m_init_flows . called )
self . delete_invalid_conntrack_entries_mock . assert_not_called ( )
def _test_initialize_port_flows_dvr_conntrack_direct ( self , network_type ) :
port_dict = {
@ -800,6 +815,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self . assertFalse ( self . mock_bridge . br . delete_flows . called )
self . firewall . prepare_port_filter ( port_dict )
self . assertTrue ( self . mock_bridge . br . delete_flows . called )
self . _assert_invalid_conntrack_entries_deleted ( port_dict )
def test_update_port_filter ( self ) :
port_dict = { ' device ' : ' port-id ' ,
@ -831,6 +847,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
table = ovs_consts . RULES_EGRESS_TABLE ) ]
self . mock_bridge . br . add_flow . assert_has_calls (
filter_rules , any_order = True )
self . _assert_invalid_conntrack_entries_deleted ( port_dict )
def test_update_port_filter_create_new_port_if_not_present ( self ) :
port_dict = { ' device ' : ' port-id ' ,
@ -850,15 +867,18 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self . assertFalse ( self . mock_bridge . br . delete_flows . called )
self . assertTrue ( initialize_port_flows_mock . called )
self . assertTrue ( add_flows_from_rules_mock . called )
self . _assert_invalid_conntrack_entries_deleted ( port_dict )
def test_update_port_filter_port_security_disabled ( self ) :
port_dict = { ' device ' : ' port-id ' ,
' security_groups ' : [ 1 ] }
self . _prepare_security_group ( )
self . firewall . prepare_port_filter ( port_dict )
self . delete_invalid_conntrack_entries_mock . reset_mock ( )
port_dict [ ' port_security_enabled ' ] = False
self . firewall . update_port_filter ( port_dict )
self . assertTrue ( self . mock_bridge . br . delete_flows . called )
self . delete_invalid_conntrack_entries_mock . assert_not_called ( )
def test_update_port_filter_applies_added_flows ( self ) :
""" Check flows are applied right after _set_flows is called. """
@ -879,6 +899,7 @@ class TestOVSFirewallDriver(base.BaseTestCase):
self . mock_bridge . br . get_vif_port_by_id . return_value = None
self . firewall . update_port_filter ( port_dict )
self . assertTrue ( self . mock_bridge . br . delete_flows . called )
self . _assert_invalid_conntrack_entries_deleted ( port_dict )
def test_remove_port_filter ( self ) :
port_dict = { ' device ' : ' port-id ' ,