Refactor _remove_unused_security_group_info
_remove_unused_security_group_info is refactored into smaller functions, to make this block easier to understand. Implements blueprint refactor-iptables-firewall-driver Change-Id: I4107f1a702d059337e7b2d701a5d0372ee2cfe11
This commit is contained in:
parent
f1fe1fe912
commit
6bc82841c5
|
@ -13,6 +13,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import netaddr
|
||||
from oslo.config import cfg
|
||||
|
||||
|
@ -64,7 +65,8 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
|
|||
self.sg_rules = {}
|
||||
self.pre_sg_rules = None
|
||||
# List of security group member ips for ports residing on this host
|
||||
self.sg_members = {}
|
||||
self.sg_members = collections.defaultdict(
|
||||
lambda: collections.defaultdict(list))
|
||||
self.pre_sg_members = None
|
||||
self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
|
||||
|
||||
|
@ -78,7 +80,7 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
|
|||
|
||||
def update_security_group_members(self, sg_id, sg_members):
|
||||
LOG.debug("Update members of security group (%s)", sg_id)
|
||||
self.sg_members[sg_id] = sg_members
|
||||
self.sg_members[sg_id] = collections.defaultdict(list, sg_members)
|
||||
|
||||
def prepare_port_filter(self, port):
|
||||
LOG.debug("Preparing device (%s) filter", port['device'])
|
||||
|
@ -323,12 +325,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
|
|||
else:
|
||||
yield rule
|
||||
|
||||
def _get_remote_sg_ids(self, port, direction):
|
||||
def _get_remote_sg_ids(self, port, direction=None):
|
||||
sg_ids = port.get('security_groups', [])
|
||||
remote_sg_ids = {constants.IPv4: [], constants.IPv6: []}
|
||||
for sg_id in sg_ids:
|
||||
for rule in self.sg_rules.get(sg_id, []):
|
||||
if rule['direction'] == direction:
|
||||
if not direction or rule['direction'] == direction:
|
||||
remote_sg_id = rule.get('remote_group_id')
|
||||
ether_type = rule.get('ethertype')
|
||||
if remote_sg_id and ether_type:
|
||||
|
@ -374,15 +376,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
|
|||
ipv6_iptables_rules)
|
||||
self._drop_dhcp_rule(ipv4_iptables_rules, ipv6_iptables_rules)
|
||||
|
||||
def _get_current_sg_member_ips(self, sg_id, ethertype):
|
||||
return self.sg_members.get(sg_id, {}).get(ethertype, [])
|
||||
|
||||
def _update_ipset_members(self, security_group_ids):
|
||||
for ethertype, sg_ids in security_group_ids.items():
|
||||
for ip_version, sg_ids in security_group_ids.items():
|
||||
for sg_id in sg_ids:
|
||||
current_ips = self._get_current_sg_member_ips(sg_id, ethertype)
|
||||
current_ips = self.sg_members[sg_id][ip_version]
|
||||
if current_ips:
|
||||
self.ipset.set_members(sg_id, ethertype, current_ips)
|
||||
self.ipset.set_members(sg_id, ip_version, current_ips)
|
||||
|
||||
def _generate_ipset_rule_args(self, sg_rule, remote_gid):
|
||||
ethertype = sg_rule.get('ethertype')
|
||||
|
@ -505,47 +504,88 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
|
|||
self._defer_apply = True
|
||||
|
||||
def _remove_unused_security_group_info(self):
|
||||
need_removed_ipsets = {constants.IPv4: set(),
|
||||
constants.IPv6: set()}
|
||||
need_removed_security_groups = set()
|
||||
remote_group_ids = {constants.IPv4: set(),
|
||||
constants.IPv6: set()}
|
||||
current_group_ids = set()
|
||||
for port in self.filtered_ports.values():
|
||||
for direction in INGRESS_DIRECTION, EGRESS_DIRECTION:
|
||||
for ethertype, sg_ids in self._get_remote_sg_ids(
|
||||
port, direction).items():
|
||||
remote_group_ids[ethertype].update(sg_ids)
|
||||
groups = port.get('security_groups', [])
|
||||
current_group_ids.update(groups)
|
||||
"""Remove any unnecesary local security group info or unused ipsets.
|
||||
|
||||
for ethertype in [constants.IPv4, constants.IPv6]:
|
||||
need_removed_ipsets[ethertype].update(
|
||||
[x for x in self.pre_sg_members if x not in remote_group_ids[
|
||||
ethertype]])
|
||||
need_removed_security_groups.update(
|
||||
[x for x in self.pre_sg_rules if x not in current_group_ids])
|
||||
This function has to be called after applying the last iptables
|
||||
rules, so we're in a point where no iptable rule depends
|
||||
on an ipset we're going to delete.
|
||||
"""
|
||||
filtered_ports = self.filtered_ports.values()
|
||||
|
||||
# Remove unused ip sets (sg_members and kernel ipset if we
|
||||
# are using ipset)
|
||||
for ethertype, remove_set_ids in need_removed_ipsets.items():
|
||||
for remove_set_id in remove_set_ids:
|
||||
if self.sg_members.get(remove_set_id, {}).get(ethertype, []):
|
||||
self.sg_members[remove_set_id][ethertype] = []
|
||||
if self.enable_ipset:
|
||||
self.ipset.destroy(remove_set_id, ethertype)
|
||||
remote_sgs_to_remove = self._determine_remote_sgs_to_remove(
|
||||
filtered_ports)
|
||||
|
||||
# Remove unused remote security group member ips
|
||||
sg_ids = self.sg_members.keys()
|
||||
for sg_id in sg_ids:
|
||||
if not (self.sg_members[sg_id].get(constants.IPv4, [])
|
||||
or self.sg_members[sg_id].get(constants.IPv6, [])):
|
||||
self.sg_members.pop(sg_id, None)
|
||||
for ip_version, remote_sg_ids in remote_sgs_to_remove.iteritems():
|
||||
self._clear_sg_members(ip_version, remote_sg_ids)
|
||||
if self.enable_ipset:
|
||||
self._remove_ipsets_for_remote_sgs(ip_version, remote_sg_ids)
|
||||
|
||||
self._remove_unused_sg_members()
|
||||
|
||||
# Remove unused security group rules
|
||||
for remove_group_id in need_removed_security_groups:
|
||||
if remove_group_id in self.sg_rules:
|
||||
self.sg_rules.pop(remove_group_id, None)
|
||||
for remove_group_id in self._determine_sg_rules_to_remove(
|
||||
filtered_ports):
|
||||
self.sg_rules.pop(remove_group_id, None)
|
||||
|
||||
def _determine_remote_sgs_to_remove(self, filtered_ports):
|
||||
"""Calculate which remote security groups we don't need anymore.
|
||||
|
||||
We do the calculation for each ip_version.
|
||||
"""
|
||||
sgs_to_remove_per_ipversion = {constants.IPv4: set(),
|
||||
constants.IPv6: set()}
|
||||
remote_group_id_sets = self._get_remote_sg_ids_sets_by_ipversion(
|
||||
filtered_ports)
|
||||
for ip_version, remote_group_id_set in (
|
||||
remote_group_id_sets.iteritems()):
|
||||
sgs_to_remove_per_ipversion[ip_version].update(
|
||||
set(self.pre_sg_members) - remote_group_id_set)
|
||||
return sgs_to_remove_per_ipversion
|
||||
|
||||
def _get_remote_sg_ids_sets_by_ipversion(self, filtered_ports):
|
||||
"""Given a port, calculates the remote sg references by ip_version."""
|
||||
remote_group_id_sets = {constants.IPv4: set(),
|
||||
constants.IPv6: set()}
|
||||
for port in filtered_ports:
|
||||
for ip_version, sg_ids in self._get_remote_sg_ids(
|
||||
port).iteritems():
|
||||
remote_group_id_sets[ip_version].update(sg_ids)
|
||||
return remote_group_id_sets
|
||||
|
||||
def _determine_sg_rules_to_remove(self, filtered_ports):
|
||||
"""Calculate which security groups need to be removed.
|
||||
|
||||
We find out by substracting our previous sg group ids,
|
||||
with the security groups associated to a set of ports.
|
||||
"""
|
||||
port_group_ids = self._get_sg_ids_set_for_ports(filtered_ports)
|
||||
return set(self.pre_sg_rules) - port_group_ids
|
||||
|
||||
def _get_sg_ids_set_for_ports(self, filtered_ports):
|
||||
"""Get the port security group ids as a set."""
|
||||
port_group_ids = set()
|
||||
for port in filtered_ports:
|
||||
port_group_ids.update(port.get('security_groups', []))
|
||||
return port_group_ids
|
||||
|
||||
def _clear_sg_members(self, ip_version, remote_sg_ids):
|
||||
"""Clear our internal cache of sg members matching the parameters."""
|
||||
for remote_sg_id in remote_sg_ids:
|
||||
if self.sg_members[remote_sg_id][ip_version]:
|
||||
self.sg_members[remote_sg_id][ip_version] = []
|
||||
|
||||
def _remove_ipsets_for_remote_sgs(self, ip_version, remote_sg_ids):
|
||||
"""Remove system ipsets matching the provided parameters."""
|
||||
for remote_sg_id in remote_sg_ids:
|
||||
self.ipset.destroy(remote_sg_id, ip_version)
|
||||
|
||||
def _remove_unused_sg_members(self):
|
||||
"""Remove sg_member entries where no IPv4 or IPv6 is associated."""
|
||||
for sg_id in self.sg_members.keys():
|
||||
sg_has_members = (self.sg_members[sg_id][constants.IPv4] or
|
||||
self.sg_members[sg_id][constants.IPv6])
|
||||
if not sg_has_members:
|
||||
del self.sg_members[sg_id]
|
||||
|
||||
def filter_defer_apply_off(self):
|
||||
if self._defer_apply:
|
||||
|
|
|
@ -34,8 +34,11 @@ FAKE_PREFIX = {'IPv4': '10.0.0.0/24',
|
|||
'IPv6': 'fe80::/48'}
|
||||
FAKE_IP = {'IPv4': '10.0.0.1',
|
||||
'IPv6': 'fe80::1'}
|
||||
#TODO(mangelajo): replace all 'fake_sgid' strings for the constant
|
||||
#TODO(mangelajo): replace all '*_sgid' strings for the constants
|
||||
FAKE_SGID = 'fake_sgid'
|
||||
OTHER_SGID = 'other_sgid'
|
||||
_IPv6 = constants.IPv6
|
||||
_IPv4 = constants.IPv4
|
||||
|
||||
|
||||
class BaseIptablesFirewallTestCase(base.BaseTestCase):
|
||||
|
@ -1420,16 +1423,25 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
'security_groups': [sg_id],
|
||||
'security_group_source_groups': [sg_id]}
|
||||
|
||||
def _fake_sg_rule_for_ethertype(self, ethertype):
|
||||
return {'direction': 'ingress', 'remote_group_id': 'fake_sgid',
|
||||
def _fake_sg_rule_for_ethertype(self, ethertype, remote_group):
|
||||
return {'direction': 'ingress', 'remote_group_id': remote_group,
|
||||
'ethertype': ethertype}
|
||||
|
||||
def _fake_sg_rule(self):
|
||||
return {'fake_sgid': [self._fake_sg_rule_for_ethertype('IPv4'),
|
||||
self._fake_sg_rule_for_ethertype('IPv6')]}
|
||||
def _fake_sg_rules(self, sg_id=FAKE_SGID, remote_groups=None):
|
||||
remote_groups = remote_groups or {_IPv4: [FAKE_SGID],
|
||||
_IPv6: [FAKE_SGID]}
|
||||
rules = []
|
||||
for ip_version, remote_group_list in remote_groups.iteritems():
|
||||
for remote_group in remote_group_list:
|
||||
rules.append(self._fake_sg_rule_for_ethertype(ip_version,
|
||||
remote_group))
|
||||
return {sg_id: rules}
|
||||
|
||||
def _fake_sg_members(self, sg_ids=None):
|
||||
return {sg_id: copy.copy(FAKE_IP) for sg_id in (sg_ids or [FAKE_SGID])}
|
||||
|
||||
def test_prepare_port_filter_with_new_members(self):
|
||||
self.firewall.sg_rules = self._fake_sg_rule()
|
||||
self.firewall.sg_rules = self._fake_sg_rules()
|
||||
self.firewall.sg_members = {'fake_sgid': {
|
||||
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
|
||||
self.firewall.pre_sg_members = {}
|
||||
|
@ -1444,34 +1456,97 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
self.firewall.ipset.assert_has_calls(calls)
|
||||
|
||||
def _setup_fake_firewall_members_and_rules(self, firewall):
|
||||
firewall.sg_rules = self._fake_sg_rule()
|
||||
firewall.pre_sg_rules = self._fake_sg_rule()
|
||||
firewall.sg_members = {'fake_sgid': {
|
||||
'IPv4': ['10.0.0.1'],
|
||||
'IPv6': ['fe80::1']}}
|
||||
firewall.sg_rules = self._fake_sg_rules()
|
||||
firewall.pre_sg_rules = self._fake_sg_rules()
|
||||
firewall.sg_members = self._fake_sg_members()
|
||||
firewall.pre_sg_members = firewall.sg_members
|
||||
|
||||
def _prepare_rules_and_members_for_removal(self):
|
||||
self._setup_fake_firewall_members_and_rules(self.firewall)
|
||||
self.firewall.pre_sg_members[OTHER_SGID] = (
|
||||
self.firewall.pre_sg_members[FAKE_SGID])
|
||||
|
||||
def test_determine_remote_sgs_to_remove(self):
|
||||
self._prepare_rules_and_members_for_removal()
|
||||
ports = [self._fake_port()]
|
||||
|
||||
self.assertEqual(
|
||||
{_IPv4: set([OTHER_SGID]), _IPv6: set([OTHER_SGID])},
|
||||
self.firewall._determine_remote_sgs_to_remove(ports))
|
||||
|
||||
def test_determine_remote_sgs_to_remove_ipv6_unreferenced(self):
|
||||
self._prepare_rules_and_members_for_removal()
|
||||
ports = [self._fake_port()]
|
||||
self.firewall.sg_rules = self._fake_sg_rules(
|
||||
remote_groups={_IPv4: [OTHER_SGID, FAKE_SGID],
|
||||
_IPv6: [FAKE_SGID]})
|
||||
self.assertEqual(
|
||||
{_IPv4: set(), _IPv6: set([OTHER_SGID])},
|
||||
self.firewall._determine_remote_sgs_to_remove(ports))
|
||||
|
||||
def test_get_remote_sg_ids_by_ipversion(self):
|
||||
self.firewall.sg_rules = self._fake_sg_rules(
|
||||
remote_groups={_IPv4: [FAKE_SGID], _IPv6: [OTHER_SGID]})
|
||||
|
||||
ports = [self._fake_port()]
|
||||
|
||||
self.assertEqual(
|
||||
{_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])},
|
||||
self.firewall._get_remote_sg_ids_sets_by_ipversion(ports))
|
||||
|
||||
def test_determine_sg_rules_to_remove(self):
|
||||
self.firewall.pre_sg_rules = self._fake_sg_rules(sg_id=OTHER_SGID)
|
||||
ports = [self._fake_port()]
|
||||
|
||||
self.assertEqual(set([OTHER_SGID]),
|
||||
self.firewall._determine_sg_rules_to_remove(ports))
|
||||
|
||||
def test_get_sg_ids_set_for_ports(self):
|
||||
sg_ids = set([FAKE_SGID, OTHER_SGID])
|
||||
ports = [self._fake_port(sg_id) for sg_id in sg_ids]
|
||||
|
||||
self.assertEqual(sg_ids,
|
||||
self.firewall._get_sg_ids_set_for_ports(ports))
|
||||
|
||||
def test_clear_sg_members(self):
|
||||
self.firewall.sg_members = self._fake_sg_members(
|
||||
sg_ids=[FAKE_SGID, OTHER_SGID])
|
||||
self.firewall._clear_sg_members(_IPv4, [OTHER_SGID])
|
||||
|
||||
self.assertEqual(0, len(self.firewall.sg_members[OTHER_SGID][_IPv4]))
|
||||
|
||||
def test_remove_unused_sg_members(self):
|
||||
self.firewall.sg_members = self._fake_sg_members([FAKE_SGID,
|
||||
OTHER_SGID])
|
||||
self.firewall.sg_members[FAKE_SGID][_IPv4] = []
|
||||
self.firewall.sg_members[FAKE_SGID][_IPv6] = []
|
||||
self.firewall.sg_members[OTHER_SGID][_IPv6] = []
|
||||
self.firewall._remove_unused_sg_members()
|
||||
|
||||
self.assertIn(OTHER_SGID, self.firewall.sg_members)
|
||||
self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
|
||||
|
||||
def test_remove_unused_security_group_info_clears_unused_rules(self):
|
||||
self._setup_fake_firewall_members_and_rules(self.firewall)
|
||||
self.firewall.prepare_port_filter(self._fake_port())
|
||||
|
||||
# create another SG which won't be referenced by any filtered port
|
||||
fake_sg_rules = self.firewall.sg_rules['fake_sgid']
|
||||
self.firewall.pre_sg_rules['other_sgid'] = fake_sg_rules
|
||||
self.firewall.sg_rules['other_sgid'] = fake_sg_rules
|
||||
self.firewall.pre_sg_rules[OTHER_SGID] = fake_sg_rules
|
||||
self.firewall.sg_rules[OTHER_SGID] = fake_sg_rules
|
||||
|
||||
# call the cleanup function, and check the unused sg_rules are out
|
||||
self.firewall._remove_unused_security_group_info()
|
||||
self.assertNotIn('other_sgid', self.firewall.sg_rules)
|
||||
self.assertNotIn(OTHER_SGID, self.firewall.sg_rules)
|
||||
|
||||
def test_remove_unused_sg_members(self):
|
||||
def test_remove_unused_security_group_info(self):
|
||||
self._setup_fake_firewall_members_and_rules(self.firewall)
|
||||
# no filtered ports in 'fake_sgid', so all rules and members
|
||||
# are not needed and we expect them to be cleaned up
|
||||
self.firewall.prepare_port_filter(self._fake_port('other_sgid'))
|
||||
self.firewall.prepare_port_filter(self._fake_port(OTHER_SGID))
|
||||
self.firewall._remove_unused_security_group_info()
|
||||
|
||||
self.assertNotIn('fake_sgid', self.firewall.sg_members)
|
||||
self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
|
||||
|
||||
def test_remove_all_unused_info(self):
|
||||
self._setup_fake_firewall_members_and_rules(self.firewall)
|
||||
|
@ -1481,8 +1556,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
self.assertFalse(self.firewall.sg_rules)
|
||||
|
||||
def test_prepare_port_filter_with_deleted_member(self):
|
||||
self.firewall.sg_rules = self._fake_sg_rule()
|
||||
self.firewall.pre_sg_rules = self._fake_sg_rule()
|
||||
self.firewall.sg_rules = self._fake_sg_rules()
|
||||
self.firewall.pre_sg_rules = self._fake_sg_rules()
|
||||
self.firewall.sg_members = {'fake_sgid': {
|
||||
'IPv4': [
|
||||
'10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'],
|
||||
|
@ -1500,7 +1575,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
self.firewall.ipset.assert_has_calls(calls, True)
|
||||
|
||||
def test_remove_port_filter_with_destroy_ipset_chain(self):
|
||||
self.firewall.sg_rules = self._fake_sg_rule()
|
||||
self.firewall.sg_rules = self._fake_sg_rules()
|
||||
port = self._fake_port()
|
||||
self.firewall.sg_members = {'fake_sgid': {
|
||||
'IPv4': ['10.0.0.1'],
|
||||
|
@ -1531,13 +1606,13 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
self.firewall.ipset.assert_has_calls(calls)
|
||||
|
||||
def test_prepare_port_filter_with_sg_no_member(self):
|
||||
self.firewall.sg_rules = self._fake_sg_rule()
|
||||
self.firewall.sg_rules['fake_sgid'].append(
|
||||
self.firewall.sg_rules = self._fake_sg_rules()
|
||||
self.firewall.sg_rules[FAKE_SGID].append(
|
||||
{'direction': 'ingress', 'remote_group_id': 'fake_sgid2',
|
||||
'ethertype': 'IPv4'})
|
||||
self.firewall.sg_rules.update()
|
||||
self.firewall.sg_members = {'fake_sgid': {
|
||||
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
|
||||
self.firewall.sg_members['fake_sgid'] = {
|
||||
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}
|
||||
self.firewall.pre_sg_members = {}
|
||||
port = self._fake_port()
|
||||
port['security_group_source_groups'].append('fake_sgid2')
|
||||
|
@ -1549,8 +1624,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
self.firewall.ipset.assert_has_calls(calls)
|
||||
|
||||
def test_filter_defer_apply_off_with_sg_only_ipv6_rule(self):
|
||||
self.firewall.sg_rules = self._fake_sg_rule()
|
||||
self.firewall.pre_sg_rules = self._fake_sg_rule()
|
||||
self.firewall.sg_rules = self._fake_sg_rules()
|
||||
self.firewall.pre_sg_rules = self._fake_sg_rules()
|
||||
self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2'],
|
||||
'IPv6fake_sgid': ['fe80::1']}
|
||||
self.firewall.sg_members = {'fake_sgid': {
|
||||
|
@ -1579,7 +1654,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
|
|||
'IPv6': [FAKE_IP['IPv6']]}}
|
||||
|
||||
port = self._fake_port()
|
||||
rule = self._fake_sg_rule_for_ethertype('IPv4')
|
||||
rule = self._fake_sg_rule_for_ethertype(_IPv4, FAKE_SGID)
|
||||
rules = self.firewall._expand_sg_rule_with_remote_ips(
|
||||
rule, port, 'ingress')
|
||||
self.assertEqual(list(rules),
|
||||
|
|
Loading…
Reference in New Issue