Merge "Fix KeyError in OVS firewall" into stable/pike

This commit is contained in:
Zuul 2019-04-05 01:19:18 +00:00 committed by Gerrit Code Review
commit 368b53cc9a
2 changed files with 22 additions and 7 deletions

View File

@ -124,26 +124,32 @@ def merge_port_ranges(rule_conj_list):
# item means a removal. # item means a removal.
result = [] result = []
rule_tmpl = rule_conj_list[0][0] rule_tmpl = rule_conj_list[0][0]
cur_conj = set() cur_conj = {}
cur_range_min = None cur_range_min = None
for port, m, conj_id in port_ranges: for port, m, conj_id in port_ranges:
if m == 'min': if m == 'min':
if conj_id in cur_conj:
cur_conj[conj_id] += 1
continue
if cur_conj and cur_range_min != port: if cur_conj and cur_range_min != port:
rule = rule_tmpl.copy() rule = rule_tmpl.copy()
rule['port_range_min'] = cur_range_min rule['port_range_min'] = cur_range_min
rule['port_range_max'] = port - 1 rule['port_range_max'] = port - 1
result.append((rule, list(cur_conj))) result.append((rule, list(cur_conj.keys())))
cur_range_min = port cur_range_min = port
cur_conj.add(conj_id) cur_conj[conj_id] = 1
else: else:
if cur_conj[conj_id] > 1:
cur_conj[conj_id] -= 1
continue
if cur_range_min <= port: if cur_range_min <= port:
rule = rule_tmpl.copy() rule = rule_tmpl.copy()
rule['port_range_min'] = cur_range_min rule['port_range_min'] = cur_range_min
rule['port_range_max'] = port rule['port_range_max'] = port
result.append((rule, list(cur_conj))) result.append((rule, list(cur_conj.keys())))
# The next port range without 'port' starts from (port + 1) # The next port range without 'port' starts from (port + 1)
cur_range_min = port + 1 cur_range_min = port + 1
cur_conj.remove(conj_id) del cur_conj[conj_id]
if (len(result) == 1 and result[0][0]['port_range_min'] == 1 and if (len(result) == 1 and result[0][0]['port_range_min'] == 1 and
result[0][0]['port_range_max'] == 65535): result[0][0]['port_range_max'] == 65535):

View File

@ -419,8 +419,8 @@ class TestMergeRules(base.BaseTestCase):
self.assertEqual(len(expected), len(result)) self.assertEqual(len(expected), len(result))
for (range_min, range_max, conj_ids), result1 in zip( for (range_min, range_max, conj_ids), result1 in zip(
expected, result): expected, result):
self.assertEqual(range_min, result1[0]['port_range_min']) self.assertEqual(range_min, result1[0].get('port_range_min'))
self.assertEqual(range_max, result1[0]['port_range_max']) self.assertEqual(range_max, result1[0].get('port_range_max'))
self.assertEqual(conj_ids, set(result1[1])) self.assertEqual(conj_ids, set(result1[1]))
def test__assert_mergeable_rules(self): def test__assert_mergeable_rules(self):
@ -488,6 +488,15 @@ class TestMergeRules(base.BaseTestCase):
(30, 40, {10, 12, 4}), (30, 40, {10, 12, 4}),
(41, 65535, {10, 12})], result) (41, 65535, {10, 12})], result)
def test_merge_port_ranges_no_port_ranges_same_conj_id(self):
result = rules.merge_port_ranges(
[(dict(self.rule_tmpl), 10),
(dict(self.rule_tmpl), 12),
(dict([('port_range_min', 30), ('port_range_max', 30)] +
self.rule_tmpl), 10)])
self._test_merge_port_ranges_helper([
(None, None, {10, 12})], result)
def test_merge_port_ranges_nonoverlapping(self): def test_merge_port_ranges_nonoverlapping(self):
result = rules.merge_port_ranges( result = rules.merge_port_ranges(
[(dict([('port_range_min', 30), ('port_range_max', 40)] + [(dict([('port_range_min', 30), ('port_range_max', 40)] +