diff --git a/neutron/common/utils.py b/neutron/common/utils.py index 021ebbb1dcf..d1e75eacbf0 100644 --- a/neutron/common/utils.py +++ b/neutron/common/utils.py @@ -23,7 +23,6 @@ import decimal import errno import functools import importlib -import math import multiprocessing import os import os.path @@ -55,8 +54,6 @@ from neutron.db import api as db_api TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" LOG = logging.getLogger(__name__) SYNCHRONIZED_PREFIX = 'neutron-' -# Unsigned 16 bit MAX. -MAX_UINT16 = 0xffff synchronized = lockutils.synchronized_with_prefix(SYNCHRONIZED_PREFIX) @@ -422,101 +419,242 @@ def safe_decode_utf8(s): return s -#TODO(jlibosva): Move this to neutron-lib and reuse in networking-ovs-dpdk -def _create_mask(lsb_mask): - return (MAX_UINT16 << int(math.floor(math.log(lsb_mask, 2)))) \ - & MAX_UINT16 +def _hex_format(port, mask=0): + + def hex_str(num): + return format(num, '#06x') + if mask > 0: + return "%s/%s" % (hex_str(port), hex_str(0xffff & ~mask)) + return hex_str(port) -def _reduce_mask(mask, step=1): - mask <<= step - return mask & MAX_UINT16 +def _gen_rules_port_min(port_min, top_bit): + """ + Encode a port range range(port_min, (port_min | (top_bit - 1)) + 1) into + a set of bit value/masks. + """ + # Processing starts with setting up mask and top_bit variables to their + # maximum. Top_bit has the form (1000000) with '1' pointing to the register + # being processed, while mask has the form (0111111) with '1' showing + # possible range to be covered. + # With each rule generation cycle, mask and top_bit are bit shifted to the + # right. When top_bit reaches 0 it means that last register was processed. -def _increase_mask(mask, step=1): - for index in range(step): + # Let port_min be n bits long, top_bit = 1 << k, 0<=k<=n-1. + + # Each cycle step checks the following conditions: + + # 1). port & mask == 0 + # This means that remaining bits k..1 are equal to '0' and can be + # covered by a single port/mask rule. + + # If condition 1 doesn't fit, then both top_bit and mask are bit + # shifted to the right and condition 2 is checked: + + # 2). port & top_bit == 0 + # This means that kth port bit is equal to '0'. By setting it to '1' + # and masking other (k-1) bits all ports in range + # [P, P + 2^(k-1)-1] are guaranteed to be covered. + # Let p_k be equal to port first (n-k) bits with rest set to 0. + # Then P = p_k | top_bit. + + # Correctness proof: + # The remaining range to be encoded in a cycle is calculated as follows: + # R = [port_min, port_min | mask]. + # If condition 1 holds, then a rule that covers R is generated and the job + # is done. + # If condition 2 holds, then the rule emitted will cover 2^(k-1) values + # from the range. Remaining range R will shrink by 2^(k-1). + # If condition 2 doesn't hold, then even after top_bit/mask shift in next + # iteration the value of R won't change. + + # Full cycle example for range [40, 64): + # port=0101000, top_bit=1000000, k=6 + # * step 1, k=6, R=[40, 63] + # top_bit=1000000, mask=0111111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0100000, mask=0011111 -> condition 2 doesn't hold + + # * step 2, k=5, R=[40, 63] + # top_bit=0100000, mask=0011111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0010000, mask=0001111 -> condition 2 holds -> 011xxxx or + # 0x0030/fff0 + # * step 3, k=4, R=[40, 47] + # top_bit=0010000, mask=0001111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0001000, mask=0000111 -> condition 2 doesn't hold + + # * step 4, k=3, R=[40, 47] + # top_bit=0001000, mask=0000111 -> condition 1 holds -> 0101xxx or + # 0x0028/fff8 + + # rules=[0x0030/fff0, 0x0028/fff8] + + rules = [] + mask = top_bit - 1 + + while True: + if (port_min & mask) == 0: + # greedy matched a streak of '0' in port_min + rules.append(_hex_format(port_min, mask)) + break + top_bit >>= 1 mask >>= 1 - mask |= 0x8000 - return mask + if (port_min & top_bit) == 0: + # matched next '0' in port_min to substitute for '1' in resulting + # rule + rules.append(_hex_format(port_min & ~mask | top_bit, mask)) + return rules -def _hex_format(number): - return format(number, '#06x') +def _gen_rules_port_max(port_max, top_bit): + """ + Encode a port range range(port_max & ~(top_bit - 1), port_max + 1) into + a set of bit value/masks. + """ + # Processing starts with setting up mask and top_bit variables to their + # maximum. Top_bit has the form (1000000) with '1' pointing to the register + # being processed, while mask has the form (0111111) with '1' showing + # possible range to be covered. + + # With each rule generation cycle, mask and top_bit are bit shifted to the + # right. When top_bit reaches 0 it means that last register was processed. + + # Let port_max be n bits long, top_bit = 1 << k, 0<=k<=n-1. + + # Each cycle step checks the following conditions: + + # 1). port & mask == mask + # This means that remaining bits k..1 are equal to '1' and can be + # covered by a single port/mask rule. + + # If condition 1 doesn't fit, then both top_bit and mask are bit + # shifted to the right and condition 2 is checked: + + # 2). port & top_bit == top_bit + # This means that kth port bit is equal to '1'. By setting it to '0' + # and masking other (k-1) bits all ports in range + # [P, P + 2^(k-1)-1] are guaranteed to be covered. + # Let p_k be equal to port first (n-k) bits with rest set to 0. + # Then P = p_k | ~top_bit. + + # Correctness proof: + # The remaining range to be encoded in a cycle is calculated as follows: + # R = [port_max & ~mask, port_max]. + # If condition 1 holds, then a rule that covers R is generated and the job + # is done. + # If condition 2 holds, then the rule emitted will cover 2^(k-1) values + # from the range. Remaining range R will shrink by 2^(k-1). + # If condition 2 doesn't hold, then even after top_bit/mask shift in next + # iteration the value of R won't change. + + # Full cycle example for range [64, 105]: + # port=1101001, top_bit=1000000, k=6 + # * step 1, k=6, R=[64, 105] + # top_bit=1000000, mask=0111111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0100000, mask=0011111 -> condition 2 holds -> 10xxxxx or + # 0x0040/ffe0 + # * step 2, k=5, R=[96, 105] + # top_bit=0100000, mask=0011111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0010000, mask=0001111 -> condition 2 doesn't hold + + # * step 3, k=4, R=[96, 105] + # top_bit=0010000, mask=0001111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0001000, mask=0000111 -> condition 2 holds -> 1100xxx or + # 0x0060/fff8 + # * step 4, k=3, R=[104, 105] + # top_bit=0001000, mask=0000111 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0000100, mask=0000011 -> condition 2 doesn't hold + + # * step 5, k=2, R=[104, 105] + # top_bit=0000100, mask=0000011 -> condition 1 doesn't hold, shifting + # mask/top_bit + # top_bit=0000010, mask=0000001 -> condition 2 doesn't hold + + # * step 6, k=1, R=[104, 105] + # top_bit=0000010, mask=0000001 -> condition 1 holds -> 1101001 or + # 0x0068 + + # rules=[0x0040/ffe0, 0x0060/fff8, 0x0068] + + rules = [] + mask = top_bit - 1 + + while True: + if (port_max & mask) == mask: + # greedy matched a streak of '1' in port_max + rules.append(_hex_format(port_max & ~mask, mask)) + break + top_bit >>= 1 + mask >>= 1 + if (port_max & top_bit) == top_bit: + # matched next '1' in port_max to substitute for '0' in resulting + # rule + rules.append(_hex_format(port_max & ~mask & ~top_bit, mask)) + return rules def port_rule_masking(port_min, port_max): + """Translate a range [port_min, port_max] into a set of bitwise matches. + + Each match has the form 'port/mask'. The port and mask are 16-bit numbers + written in hexadecimal prefixed by 0x. Each 1-bit in mask requires that + the corresponding bit in port must match. Each 0-bit in mask causes the + corresponding bit to be ignored. + """ + + # Let binary representation of port_min and port_max be n bits long and + # have first m bits in common, 0 <= m <= n. + + # If remaining (n - m) bits of given ports define 2^(n-m) values, then + # [port_min, port_max] range is covered by a single rule. + # For example: + # n = 6 + # port_min = 16 (binary 010000) + # port_max = 23 (binary 010111) + # Ports have m=3 bits in common with the remaining (n-m)=3 bits + # covering range [0, 2^3), which equals to a single 010xxx rule. The algo + # will return [0x0010/fff8]. + + # Else [port_min, port_max] range will be split into 2: range [port_min, T) + # and [T, port_max]. Let p_m be the common part of port_min and port_max + # with other (n-m) bits set to 0. Then T = p_m | 1 << (n-m-1). + # For example: + # n = 7 + # port_min = 40 (binary 0101000) + # port_max = 105 (binary 1101001) + # Ports have m=0 bits in common, p_m=000000. Then T=1000000 and the + # initial range [40, 105] is divided into [40, 64) and [64, 105]. + # Each of the ranges will be processed separately, then the generated rules + # will be merged. + # Check port_max >= port_min. if port_max < port_min: raise ValueError(_("'port_max' is smaller than 'port_min'")) - # Rules to be added to OVS. + bitdiff = port_min ^ port_max + if bitdiff == 0: + # port_min == port_max + return [_hex_format(port_min)] + # for python3.x, bit_length could be used here + top_bit = 1 + while top_bit <= bitdiff: + top_bit <<= 1 + if (port_min & (top_bit - 1) == 0 and + port_max & (top_bit - 1) == top_bit - 1): + # special case, range of 2^k ports is covered + return [_hex_format(port_min, top_bit - 1)] + + top_bit >>= 1 rules = [] - - # Loop from the lower part. Increment port_min. - bit_right = 1 - mask = MAX_UINT16 - t_port_min = port_min - while True: - # Obtain last significative bit. - bit_min = port_min & bit_right - # Take care of first bit. - if bit_right == 1: - if bit_min > 0: - rules.append("%s" % (_hex_format(t_port_min))) - else: - mask = _create_mask(2) - rules.append("%s/%s" % (_hex_format(t_port_min & mask), - _hex_format(mask))) - elif bit_min == 0: - mask = _create_mask(bit_right) - t_port_min += bit_right - # If the temporal variable we are using exceeds the - # port_max value, exit the loop. - if t_port_min > port_max: - break - rules.append("%s/%s" % (_hex_format(t_port_min & mask), - _hex_format(mask))) - - # If the temporal variable we are using exceeds the - # port_max value, exit the loop. - if t_port_min > port_max: - break - bit_right <<= 1 - - # Loop from the higher part. - bit_position = int(round(math.log(port_max, 2))) - bit_left = 1 << bit_position - mask = MAX_UINT16 - mask = _reduce_mask(mask, bit_position) - # Find the most significative bit of port_max, higher - # than the most significative bit of port_min. - while mask < MAX_UINT16: - bit_max = port_max & bit_left - bit_min = port_min & bit_left - if bit_max > bit_min: - # Difference found. - break - # Rotate bit_left to the right and increase mask. - bit_left >>= 1 - mask = _increase_mask(mask) - - while bit_left > 1: - # Obtain next most significative bit. - bit_left >>= 1 - bit_max = port_max & bit_left - if bit_left == 1: - if bit_max == 0: - rules.append("%s" % (_hex_format(port_max))) - else: - mask = _create_mask(2) - rules.append("%s/%s" % (_hex_format(port_max & mask), - _hex_format(mask))) - elif bit_max > 0: - t_port_max = port_max - bit_max - mask = _create_mask(bit_left) - rules.append("%s/%s" % (_hex_format(t_port_max), - _hex_format(mask))) - + rules.extend(_gen_rules_port_min(port_min, top_bit)) + rules.extend(_gen_rules_port_max(port_max, top_bit)) return rules diff --git a/neutron/tests/functional/agent/test_firewall.py b/neutron/tests/functional/agent/test_firewall.py index d982ae4e57d..4570b00b61a 100644 --- a/neutron/tests/functional/agent/test_firewall.py +++ b/neutron/tests/functional/agent/test_firewall.py @@ -432,6 +432,26 @@ class FirewallTestCase(BaseFirewallTestCase): def test_ingress_tcp_rule(self): self._test_rule(self.tester.INGRESS, self.tester.TCP) + def test_next_port_closed(self): + # https://bugs.launchpad.net/neutron/+bug/1611991 was caused by wrong + # masking in rules which allow traffic to a port with even port number + port = 42 + for direction in (self.tester.EGRESS, self.tester.INGRESS): + sg_rules = [{'ethertype': constants.IPv4, + 'direction': direction, + 'protocol': constants.PROTO_NAME_TCP, + 'source_port_range_min': port, + 'source_port_range_max': port}] + self._apply_security_group_rules(self.FAKE_SECURITY_GROUP_ID, + sg_rules) + + self.tester.assert_connection(protocol=self.tester.TCP, + direction=direction, + src_port=port) + self.tester.assert_no_connection(protocol=self.tester.TCP, + direction=direction, + src_port=port + 1) + def test_ingress_udp_rule(self): self._test_rule(self.tester.INGRESS, self.tester.UDP) diff --git a/neutron/tests/unit/common/test_utils.py b/neutron/tests/unit/common/test_utils.py index c8326c381ef..5326c003c1e 100644 --- a/neutron/tests/unit/common/test_utils.py +++ b/neutron/tests/unit/common/test_utils.py @@ -13,6 +13,7 @@ # under the License. import errno +import inspect import os.path import re import sys @@ -24,6 +25,7 @@ from neutron_lib import constants from neutron_lib import exceptions as exc from oslo_log import log as logging import six +import testscenarios import testtools from neutron.common import exceptions as n_exc @@ -34,6 +36,8 @@ from neutron.tests import base from neutron.tests.common import helpers from neutron.tests.unit import tests +load_tests = testscenarios.load_tests_apply_scenarios + class TestParseMappings(base.BaseTestCase): def parse(self, mapping_list, unique_values=True, unique_keys=True): @@ -685,43 +689,65 @@ class TestSafeDecodeUtf8(base.BaseTestCase): class TestPortRuleMasking(base.BaseTestCase): + scenarios = [ + ('Test 1 (networking-ovs-dpdk)', + {'port_min': 5, + 'port_max': 12, + 'expected': ['0x0005', '0x0006/0xfffe', '0x0008/0xfffc', '0x000c']} + ), + ('Test 2 (networking-ovs-dpdk)', + {'port_min': 20, + 'port_max': 130, + 'expected': ['0x0014/0xfffc', '0x0018/0xfff8', + '0x0020/0xffe0', '0x0040/0xffc0', '0x0080/0xfffe', + '0x0082']}), + ('Test 3 (networking-ovs-dpdk)', + {'port_min': 4501, + 'port_max': 33057, + 'expected': ['0x1195', '0x1196/0xfffe', '0x1198/0xfff8', + '0x11a0/0xffe0', '0x11c0/0xffc0', '0x1200/0xfe00', + '0x1400/0xfc00', '0x1800/0xf800', '0x2000/0xe000', + '0x4000/0xc000', '0x8000/0xff00', '0x8100/0xffe0', + '0x8120/0xfffe']}), + ('Test port_max == 2^k-1', + {'port_min': 101, + 'port_max': 127, + 'expected': ['0x0065', '0x0066/0xfffe', '0x0068/0xfff8', + '0x0070/0xfff0']}), + ('Test single even port', + {'port_min': 22, + 'port_max': 22, + 'expected': ['0x0016']}), + ('Test single odd port', + {'port_min': 5001, + 'port_max': 5001, + 'expected': ['0x1389']}), + ('Test full interval', + {'port_min': 0, + 'port_max': 7, + 'expected': ['0x0000/0xfff8']}), + ('Test 2^k interval', + {'port_min': 8, + 'port_max': 15, + 'expected': ['0x0008/0xfff8']}), + ('Test full port range', + {'port_min': 0, + 'port_max': 65535, + 'expected': ['0x0000/0x0000']}), + ('Test bad values', + {'port_min': 12, + 'port_max': 5, + 'expected': ValueError}), + ] + def test_port_rule_masking(self): - compare_rules = lambda x, y: set(x) == set(y) and len(x) == len(y) - - # Test 1. - port_min = 5 - port_max = 12 - expected_rules = ['0x0005', '0x000c', '0x0006/0xfffe', - '0x0008/0xfffc'] - rules = utils.port_rule_masking(port_min, port_max) - self.assertTrue(compare_rules(rules, expected_rules)) - - # Test 2. - port_min = 20 - port_max = 130 - expected_rules = ['0x0014/0xfffe', '0x0016/0xfffe', '0x0018/0xfff8', - '0x0020/0xffe0', '0x0040/0xffc0', '0x0080/0xfffe', - '0x0082'] - rules = utils.port_rule_masking(port_min, port_max) - self.assertEqual(expected_rules, rules) - - # Test 3. - port_min = 4501 - port_max = 33057 - expected_rules = ['0x1195', '0x1196/0xfffe', '0x1198/0xfff8', - '0x11a0/0xffe0', '0x11c0/0xffc0', '0x1200/0xfe00', - '0x1400/0xfc00', '0x1800/0xf800', '0x2000/0xe000', - '0x4000/0xc000', '0x8021/0xff00', '0x8101/0xffe0', - '0x8120/0xfffe'] - - rules = utils.port_rule_masking(port_min, port_max) - self.assertEqual(expected_rules, rules) - - def test_port_rule_masking_min_higher_than_max(self): - port_min = 10 - port_max = 5 - with testtools.ExpectedException(ValueError): - utils.port_rule_masking(port_min, port_max) + if (inspect.isclass(self.expected) + and issubclass(self.expected, Exception)): + with testtools.ExpectedException(self.expected): + utils.port_rule_masking(self.port_min, self.port_max) + else: + rules = utils.port_rule_masking(self.port_min, self.port_max) + self.assertItemsEqual(self.expected, rules) class TestAuthenticEUI(base.BaseTestCase):