Refactor _remove_unused_security_group_info

_remove_unused_security_group_info is refactored into smaller
functions, to make this block easier to understand.

Implements blueprint refactor-iptables-firewall-driver

Change-Id: I4107f1a702d059337e7b2d701a5d0372ee2cfe11
This commit is contained in:
Miguel Angel Ajo 2015-02-03 13:35:40 +00:00
parent f1fe1fe912
commit 6bc82841c5
2 changed files with 189 additions and 74 deletions

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
import netaddr
from oslo.config import cfg
@ -64,7 +65,8 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
self.sg_rules = {}
self.pre_sg_rules = None
# List of security group member ips for ports residing on this host
self.sg_members = {}
self.sg_members = collections.defaultdict(
lambda: collections.defaultdict(list))
self.pre_sg_members = None
self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
@ -78,7 +80,7 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
def update_security_group_members(self, sg_id, sg_members):
LOG.debug("Update members of security group (%s)", sg_id)
self.sg_members[sg_id] = sg_members
self.sg_members[sg_id] = collections.defaultdict(list, sg_members)
def prepare_port_filter(self, port):
LOG.debug("Preparing device (%s) filter", port['device'])
@ -323,12 +325,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
else:
yield rule
def _get_remote_sg_ids(self, port, direction):
def _get_remote_sg_ids(self, port, direction=None):
sg_ids = port.get('security_groups', [])
remote_sg_ids = {constants.IPv4: [], constants.IPv6: []}
for sg_id in sg_ids:
for rule in self.sg_rules.get(sg_id, []):
if rule['direction'] == direction:
if not direction or rule['direction'] == direction:
remote_sg_id = rule.get('remote_group_id')
ether_type = rule.get('ethertype')
if remote_sg_id and ether_type:
@ -374,15 +376,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
ipv6_iptables_rules)
self._drop_dhcp_rule(ipv4_iptables_rules, ipv6_iptables_rules)
def _get_current_sg_member_ips(self, sg_id, ethertype):
return self.sg_members.get(sg_id, {}).get(ethertype, [])
def _update_ipset_members(self, security_group_ids):
for ethertype, sg_ids in security_group_ids.items():
for ip_version, sg_ids in security_group_ids.items():
for sg_id in sg_ids:
current_ips = self._get_current_sg_member_ips(sg_id, ethertype)
current_ips = self.sg_members[sg_id][ip_version]
if current_ips:
self.ipset.set_members(sg_id, ethertype, current_ips)
self.ipset.set_members(sg_id, ip_version, current_ips)
def _generate_ipset_rule_args(self, sg_rule, remote_gid):
ethertype = sg_rule.get('ethertype')
@ -505,47 +504,88 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
self._defer_apply = True
def _remove_unused_security_group_info(self):
need_removed_ipsets = {constants.IPv4: set(),
constants.IPv6: set()}
need_removed_security_groups = set()
remote_group_ids = {constants.IPv4: set(),
constants.IPv6: set()}
current_group_ids = set()
for port in self.filtered_ports.values():
for direction in INGRESS_DIRECTION, EGRESS_DIRECTION:
for ethertype, sg_ids in self._get_remote_sg_ids(
port, direction).items():
remote_group_ids[ethertype].update(sg_ids)
groups = port.get('security_groups', [])
current_group_ids.update(groups)
"""Remove any unnecesary local security group info or unused ipsets.
for ethertype in [constants.IPv4, constants.IPv6]:
need_removed_ipsets[ethertype].update(
[x for x in self.pre_sg_members if x not in remote_group_ids[
ethertype]])
need_removed_security_groups.update(
[x for x in self.pre_sg_rules if x not in current_group_ids])
This function has to be called after applying the last iptables
rules, so we're in a point where no iptable rule depends
on an ipset we're going to delete.
"""
filtered_ports = self.filtered_ports.values()
# Remove unused ip sets (sg_members and kernel ipset if we
# are using ipset)
for ethertype, remove_set_ids in need_removed_ipsets.items():
for remove_set_id in remove_set_ids:
if self.sg_members.get(remove_set_id, {}).get(ethertype, []):
self.sg_members[remove_set_id][ethertype] = []
if self.enable_ipset:
self.ipset.destroy(remove_set_id, ethertype)
remote_sgs_to_remove = self._determine_remote_sgs_to_remove(
filtered_ports)
# Remove unused remote security group member ips
sg_ids = self.sg_members.keys()
for sg_id in sg_ids:
if not (self.sg_members[sg_id].get(constants.IPv4, [])
or self.sg_members[sg_id].get(constants.IPv6, [])):
self.sg_members.pop(sg_id, None)
for ip_version, remote_sg_ids in remote_sgs_to_remove.iteritems():
self._clear_sg_members(ip_version, remote_sg_ids)
if self.enable_ipset:
self._remove_ipsets_for_remote_sgs(ip_version, remote_sg_ids)
self._remove_unused_sg_members()
# Remove unused security group rules
for remove_group_id in need_removed_security_groups:
if remove_group_id in self.sg_rules:
self.sg_rules.pop(remove_group_id, None)
for remove_group_id in self._determine_sg_rules_to_remove(
filtered_ports):
self.sg_rules.pop(remove_group_id, None)
def _determine_remote_sgs_to_remove(self, filtered_ports):
"""Calculate which remote security groups we don't need anymore.
We do the calculation for each ip_version.
"""
sgs_to_remove_per_ipversion = {constants.IPv4: set(),
constants.IPv6: set()}
remote_group_id_sets = self._get_remote_sg_ids_sets_by_ipversion(
filtered_ports)
for ip_version, remote_group_id_set in (
remote_group_id_sets.iteritems()):
sgs_to_remove_per_ipversion[ip_version].update(
set(self.pre_sg_members) - remote_group_id_set)
return sgs_to_remove_per_ipversion
def _get_remote_sg_ids_sets_by_ipversion(self, filtered_ports):
"""Given a port, calculates the remote sg references by ip_version."""
remote_group_id_sets = {constants.IPv4: set(),
constants.IPv6: set()}
for port in filtered_ports:
for ip_version, sg_ids in self._get_remote_sg_ids(
port).iteritems():
remote_group_id_sets[ip_version].update(sg_ids)
return remote_group_id_sets
def _determine_sg_rules_to_remove(self, filtered_ports):
"""Calculate which security groups need to be removed.
We find out by substracting our previous sg group ids,
with the security groups associated to a set of ports.
"""
port_group_ids = self._get_sg_ids_set_for_ports(filtered_ports)
return set(self.pre_sg_rules) - port_group_ids
def _get_sg_ids_set_for_ports(self, filtered_ports):
"""Get the port security group ids as a set."""
port_group_ids = set()
for port in filtered_ports:
port_group_ids.update(port.get('security_groups', []))
return port_group_ids
def _clear_sg_members(self, ip_version, remote_sg_ids):
"""Clear our internal cache of sg members matching the parameters."""
for remote_sg_id in remote_sg_ids:
if self.sg_members[remote_sg_id][ip_version]:
self.sg_members[remote_sg_id][ip_version] = []
def _remove_ipsets_for_remote_sgs(self, ip_version, remote_sg_ids):
"""Remove system ipsets matching the provided parameters."""
for remote_sg_id in remote_sg_ids:
self.ipset.destroy(remote_sg_id, ip_version)
def _remove_unused_sg_members(self):
"""Remove sg_member entries where no IPv4 or IPv6 is associated."""
for sg_id in self.sg_members.keys():
sg_has_members = (self.sg_members[sg_id][constants.IPv4] or
self.sg_members[sg_id][constants.IPv6])
if not sg_has_members:
del self.sg_members[sg_id]
def filter_defer_apply_off(self):
if self._defer_apply:

View File

@ -34,8 +34,11 @@ FAKE_PREFIX = {'IPv4': '10.0.0.0/24',
'IPv6': 'fe80::/48'}
FAKE_IP = {'IPv4': '10.0.0.1',
'IPv6': 'fe80::1'}
#TODO(mangelajo): replace all 'fake_sgid' strings for the constant
#TODO(mangelajo): replace all '*_sgid' strings for the constants
FAKE_SGID = 'fake_sgid'
OTHER_SGID = 'other_sgid'
_IPv6 = constants.IPv6
_IPv4 = constants.IPv4
class BaseIptablesFirewallTestCase(base.BaseTestCase):
@ -1420,16 +1423,25 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
'security_groups': [sg_id],
'security_group_source_groups': [sg_id]}
def _fake_sg_rule_for_ethertype(self, ethertype):
return {'direction': 'ingress', 'remote_group_id': 'fake_sgid',
def _fake_sg_rule_for_ethertype(self, ethertype, remote_group):
return {'direction': 'ingress', 'remote_group_id': remote_group,
'ethertype': ethertype}
def _fake_sg_rule(self):
return {'fake_sgid': [self._fake_sg_rule_for_ethertype('IPv4'),
self._fake_sg_rule_for_ethertype('IPv6')]}
def _fake_sg_rules(self, sg_id=FAKE_SGID, remote_groups=None):
remote_groups = remote_groups or {_IPv4: [FAKE_SGID],
_IPv6: [FAKE_SGID]}
rules = []
for ip_version, remote_group_list in remote_groups.iteritems():
for remote_group in remote_group_list:
rules.append(self._fake_sg_rule_for_ethertype(ip_version,
remote_group))
return {sg_id: rules}
def _fake_sg_members(self, sg_ids=None):
return {sg_id: copy.copy(FAKE_IP) for sg_id in (sg_ids or [FAKE_SGID])}
def test_prepare_port_filter_with_new_members(self):
self.firewall.sg_rules = self._fake_sg_rule()
self.firewall.sg_rules = self._fake_sg_rules()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
self.firewall.pre_sg_members = {}
@ -1444,34 +1456,97 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
self.firewall.ipset.assert_has_calls(calls)
def _setup_fake_firewall_members_and_rules(self, firewall):
firewall.sg_rules = self._fake_sg_rule()
firewall.pre_sg_rules = self._fake_sg_rule()
firewall.sg_members = {'fake_sgid': {
'IPv4': ['10.0.0.1'],
'IPv6': ['fe80::1']}}
firewall.sg_rules = self._fake_sg_rules()
firewall.pre_sg_rules = self._fake_sg_rules()
firewall.sg_members = self._fake_sg_members()
firewall.pre_sg_members = firewall.sg_members
def _prepare_rules_and_members_for_removal(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
self.firewall.pre_sg_members[OTHER_SGID] = (
self.firewall.pre_sg_members[FAKE_SGID])
def test_determine_remote_sgs_to_remove(self):
self._prepare_rules_and_members_for_removal()
ports = [self._fake_port()]
self.assertEqual(
{_IPv4: set([OTHER_SGID]), _IPv6: set([OTHER_SGID])},
self.firewall._determine_remote_sgs_to_remove(ports))
def test_determine_remote_sgs_to_remove_ipv6_unreferenced(self):
self._prepare_rules_and_members_for_removal()
ports = [self._fake_port()]
self.firewall.sg_rules = self._fake_sg_rules(
remote_groups={_IPv4: [OTHER_SGID, FAKE_SGID],
_IPv6: [FAKE_SGID]})
self.assertEqual(
{_IPv4: set(), _IPv6: set([OTHER_SGID])},
self.firewall._determine_remote_sgs_to_remove(ports))
def test_get_remote_sg_ids_by_ipversion(self):
self.firewall.sg_rules = self._fake_sg_rules(
remote_groups={_IPv4: [FAKE_SGID], _IPv6: [OTHER_SGID]})
ports = [self._fake_port()]
self.assertEqual(
{_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])},
self.firewall._get_remote_sg_ids_sets_by_ipversion(ports))
def test_determine_sg_rules_to_remove(self):
self.firewall.pre_sg_rules = self._fake_sg_rules(sg_id=OTHER_SGID)
ports = [self._fake_port()]
self.assertEqual(set([OTHER_SGID]),
self.firewall._determine_sg_rules_to_remove(ports))
def test_get_sg_ids_set_for_ports(self):
sg_ids = set([FAKE_SGID, OTHER_SGID])
ports = [self._fake_port(sg_id) for sg_id in sg_ids]
self.assertEqual(sg_ids,
self.firewall._get_sg_ids_set_for_ports(ports))
def test_clear_sg_members(self):
self.firewall.sg_members = self._fake_sg_members(
sg_ids=[FAKE_SGID, OTHER_SGID])
self.firewall._clear_sg_members(_IPv4, [OTHER_SGID])
self.assertEqual(0, len(self.firewall.sg_members[OTHER_SGID][_IPv4]))
def test_remove_unused_sg_members(self):
self.firewall.sg_members = self._fake_sg_members([FAKE_SGID,
OTHER_SGID])
self.firewall.sg_members[FAKE_SGID][_IPv4] = []
self.firewall.sg_members[FAKE_SGID][_IPv6] = []
self.firewall.sg_members[OTHER_SGID][_IPv6] = []
self.firewall._remove_unused_sg_members()
self.assertIn(OTHER_SGID, self.firewall.sg_members)
self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
def test_remove_unused_security_group_info_clears_unused_rules(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
self.firewall.prepare_port_filter(self._fake_port())
# create another SG which won't be referenced by any filtered port
fake_sg_rules = self.firewall.sg_rules['fake_sgid']
self.firewall.pre_sg_rules['other_sgid'] = fake_sg_rules
self.firewall.sg_rules['other_sgid'] = fake_sg_rules
self.firewall.pre_sg_rules[OTHER_SGID] = fake_sg_rules
self.firewall.sg_rules[OTHER_SGID] = fake_sg_rules
# call the cleanup function, and check the unused sg_rules are out
self.firewall._remove_unused_security_group_info()
self.assertNotIn('other_sgid', self.firewall.sg_rules)
self.assertNotIn(OTHER_SGID, self.firewall.sg_rules)
def test_remove_unused_sg_members(self):
def test_remove_unused_security_group_info(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
# no filtered ports in 'fake_sgid', so all rules and members
# are not needed and we expect them to be cleaned up
self.firewall.prepare_port_filter(self._fake_port('other_sgid'))
self.firewall.prepare_port_filter(self._fake_port(OTHER_SGID))
self.firewall._remove_unused_security_group_info()
self.assertNotIn('fake_sgid', self.firewall.sg_members)
self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
def test_remove_all_unused_info(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
@ -1481,8 +1556,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
self.assertFalse(self.firewall.sg_rules)
def test_prepare_port_filter_with_deleted_member(self):
self.firewall.sg_rules = self._fake_sg_rule()
self.firewall.pre_sg_rules = self._fake_sg_rule()
self.firewall.sg_rules = self._fake_sg_rules()
self.firewall.pre_sg_rules = self._fake_sg_rules()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': [
'10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'],
@ -1500,7 +1575,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
self.firewall.ipset.assert_has_calls(calls, True)
def test_remove_port_filter_with_destroy_ipset_chain(self):
self.firewall.sg_rules = self._fake_sg_rule()
self.firewall.sg_rules = self._fake_sg_rules()
port = self._fake_port()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': ['10.0.0.1'],
@ -1531,13 +1606,13 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
self.firewall.ipset.assert_has_calls(calls)
def test_prepare_port_filter_with_sg_no_member(self):
self.firewall.sg_rules = self._fake_sg_rule()
self.firewall.sg_rules['fake_sgid'].append(
self.firewall.sg_rules = self._fake_sg_rules()
self.firewall.sg_rules[FAKE_SGID].append(
{'direction': 'ingress', 'remote_group_id': 'fake_sgid2',
'ethertype': 'IPv4'})
self.firewall.sg_rules.update()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
self.firewall.sg_members['fake_sgid'] = {
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}
self.firewall.pre_sg_members = {}
port = self._fake_port()
port['security_group_source_groups'].append('fake_sgid2')
@ -1549,8 +1624,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
self.firewall.ipset.assert_has_calls(calls)
def test_filter_defer_apply_off_with_sg_only_ipv6_rule(self):
self.firewall.sg_rules = self._fake_sg_rule()
self.firewall.pre_sg_rules = self._fake_sg_rule()
self.firewall.sg_rules = self._fake_sg_rules()
self.firewall.pre_sg_rules = self._fake_sg_rules()
self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2'],
'IPv6fake_sgid': ['fe80::1']}
self.firewall.sg_members = {'fake_sgid': {
@ -1579,7 +1654,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
'IPv6': [FAKE_IP['IPv6']]}}
port = self._fake_port()
rule = self._fake_sg_rule_for_ethertype('IPv4')
rule = self._fake_sg_rule_for_ethertype(_IPv4, FAKE_SGID)
rules = self.firewall._expand_sg_rule_with_remote_ips(
rule, port, 'ingress')
self.assertEqual(list(rules),