ovsfw: fix troublesome port_rule_masking

In several cases port masking algorithm borrowed
from networking_ovs_dpdk didn't behave correctly.
This caused non-restricted ports to be open due to
wrong tp_src field value in resulting ovs rules.

This was fixed by alternative port masking
implementation.

Functional and unit tests to cover the bug added as well.

Co-Authored-By: Jakub Libosvar <libosvar@redhat.com>
Co-Authored-By: IWAMOTO Toshihiro <iwamoto@valinux.co.jp>

Closes-Bug: #1611991
Change-Id: Idfc0e9c52e0dd08852c91c17e12edb034606a361
This commit is contained in:
Inessa Vasilevskaya 2016-08-11 02:21:29 +03:00
parent a915f2b690
commit 0494f212aa
3 changed files with 303 additions and 119 deletions

View File

@ -23,7 +23,6 @@ import decimal
import errno
import functools
import importlib
import math
import multiprocessing
import os
import os.path
@ -55,8 +54,6 @@ from neutron.db import api as db_api
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
LOG = logging.getLogger(__name__)
SYNCHRONIZED_PREFIX = 'neutron-'
# Unsigned 16 bit MAX.
MAX_UINT16 = 0xffff
synchronized = lockutils.synchronized_with_prefix(SYNCHRONIZED_PREFIX)
@ -422,101 +419,242 @@ def safe_decode_utf8(s):
return s
#TODO(jlibosva): Move this to neutron-lib and reuse in networking-ovs-dpdk
def _create_mask(lsb_mask):
return (MAX_UINT16 << int(math.floor(math.log(lsb_mask, 2)))) \
& MAX_UINT16
def _hex_format(port, mask=0):
def hex_str(num):
return format(num, '#06x')
if mask > 0:
return "%s/%s" % (hex_str(port), hex_str(0xffff & ~mask))
return hex_str(port)
def _reduce_mask(mask, step=1):
mask <<= step
return mask & MAX_UINT16
def _gen_rules_port_min(port_min, top_bit):
"""
Encode a port range range(port_min, (port_min | (top_bit - 1)) + 1) into
a set of bit value/masks.
"""
# Processing starts with setting up mask and top_bit variables to their
# maximum. Top_bit has the form (1000000) with '1' pointing to the register
# being processed, while mask has the form (0111111) with '1' showing
# possible range to be covered.
# With each rule generation cycle, mask and top_bit are bit shifted to the
# right. When top_bit reaches 0 it means that last register was processed.
def _increase_mask(mask, step=1):
for index in range(step):
# Let port_min be n bits long, top_bit = 1 << k, 0<=k<=n-1.
# Each cycle step checks the following conditions:
# 1). port & mask == 0
# This means that remaining bits k..1 are equal to '0' and can be
# covered by a single port/mask rule.
# If condition 1 doesn't fit, then both top_bit and mask are bit
# shifted to the right and condition 2 is checked:
# 2). port & top_bit == 0
# This means that kth port bit is equal to '0'. By setting it to '1'
# and masking other (k-1) bits all ports in range
# [P, P + 2^(k-1)-1] are guaranteed to be covered.
# Let p_k be equal to port first (n-k) bits with rest set to 0.
# Then P = p_k | top_bit.
# Correctness proof:
# The remaining range to be encoded in a cycle is calculated as follows:
# R = [port_min, port_min | mask].
# If condition 1 holds, then a rule that covers R is generated and the job
# is done.
# If condition 2 holds, then the rule emitted will cover 2^(k-1) values
# from the range. Remaining range R will shrink by 2^(k-1).
# If condition 2 doesn't hold, then even after top_bit/mask shift in next
# iteration the value of R won't change.
# Full cycle example for range [40, 64):
# port=0101000, top_bit=1000000, k=6
# * step 1, k=6, R=[40, 63]
# top_bit=1000000, mask=0111111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0100000, mask=0011111 -> condition 2 doesn't hold
# * step 2, k=5, R=[40, 63]
# top_bit=0100000, mask=0011111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0010000, mask=0001111 -> condition 2 holds -> 011xxxx or
# 0x0030/fff0
# * step 3, k=4, R=[40, 47]
# top_bit=0010000, mask=0001111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0001000, mask=0000111 -> condition 2 doesn't hold
# * step 4, k=3, R=[40, 47]
# top_bit=0001000, mask=0000111 -> condition 1 holds -> 0101xxx or
# 0x0028/fff8
# rules=[0x0030/fff0, 0x0028/fff8]
rules = []
mask = top_bit - 1
while True:
if (port_min & mask) == 0:
# greedy matched a streak of '0' in port_min
rules.append(_hex_format(port_min, mask))
break
top_bit >>= 1
mask >>= 1
mask |= 0x8000
return mask
if (port_min & top_bit) == 0:
# matched next '0' in port_min to substitute for '1' in resulting
# rule
rules.append(_hex_format(port_min & ~mask | top_bit, mask))
return rules
def _hex_format(number):
return format(number, '#06x')
def _gen_rules_port_max(port_max, top_bit):
"""
Encode a port range range(port_max & ~(top_bit - 1), port_max + 1) into
a set of bit value/masks.
"""
# Processing starts with setting up mask and top_bit variables to their
# maximum. Top_bit has the form (1000000) with '1' pointing to the register
# being processed, while mask has the form (0111111) with '1' showing
# possible range to be covered.
# With each rule generation cycle, mask and top_bit are bit shifted to the
# right. When top_bit reaches 0 it means that last register was processed.
# Let port_max be n bits long, top_bit = 1 << k, 0<=k<=n-1.
# Each cycle step checks the following conditions:
# 1). port & mask == mask
# This means that remaining bits k..1 are equal to '1' and can be
# covered by a single port/mask rule.
# If condition 1 doesn't fit, then both top_bit and mask are bit
# shifted to the right and condition 2 is checked:
# 2). port & top_bit == top_bit
# This means that kth port bit is equal to '1'. By setting it to '0'
# and masking other (k-1) bits all ports in range
# [P, P + 2^(k-1)-1] are guaranteed to be covered.
# Let p_k be equal to port first (n-k) bits with rest set to 0.
# Then P = p_k | ~top_bit.
# Correctness proof:
# The remaining range to be encoded in a cycle is calculated as follows:
# R = [port_max & ~mask, port_max].
# If condition 1 holds, then a rule that covers R is generated and the job
# is done.
# If condition 2 holds, then the rule emitted will cover 2^(k-1) values
# from the range. Remaining range R will shrink by 2^(k-1).
# If condition 2 doesn't hold, then even after top_bit/mask shift in next
# iteration the value of R won't change.
# Full cycle example for range [64, 105]:
# port=1101001, top_bit=1000000, k=6
# * step 1, k=6, R=[64, 105]
# top_bit=1000000, mask=0111111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0100000, mask=0011111 -> condition 2 holds -> 10xxxxx or
# 0x0040/ffe0
# * step 2, k=5, R=[96, 105]
# top_bit=0100000, mask=0011111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0010000, mask=0001111 -> condition 2 doesn't hold
# * step 3, k=4, R=[96, 105]
# top_bit=0010000, mask=0001111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0001000, mask=0000111 -> condition 2 holds -> 1100xxx or
# 0x0060/fff8
# * step 4, k=3, R=[104, 105]
# top_bit=0001000, mask=0000111 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0000100, mask=0000011 -> condition 2 doesn't hold
# * step 5, k=2, R=[104, 105]
# top_bit=0000100, mask=0000011 -> condition 1 doesn't hold, shifting
# mask/top_bit
# top_bit=0000010, mask=0000001 -> condition 2 doesn't hold
# * step 6, k=1, R=[104, 105]
# top_bit=0000010, mask=0000001 -> condition 1 holds -> 1101001 or
# 0x0068
# rules=[0x0040/ffe0, 0x0060/fff8, 0x0068]
rules = []
mask = top_bit - 1
while True:
if (port_max & mask) == mask:
# greedy matched a streak of '1' in port_max
rules.append(_hex_format(port_max & ~mask, mask))
break
top_bit >>= 1
mask >>= 1
if (port_max & top_bit) == top_bit:
# matched next '1' in port_max to substitute for '0' in resulting
# rule
rules.append(_hex_format(port_max & ~mask & ~top_bit, mask))
return rules
def port_rule_masking(port_min, port_max):
"""Translate a range [port_min, port_max] into a set of bitwise matches.
Each match has the form 'port/mask'. The port and mask are 16-bit numbers
written in hexadecimal prefixed by 0x. Each 1-bit in mask requires that
the corresponding bit in port must match. Each 0-bit in mask causes the
corresponding bit to be ignored.
"""
# Let binary representation of port_min and port_max be n bits long and
# have first m bits in common, 0 <= m <= n.
# If remaining (n - m) bits of given ports define 2^(n-m) values, then
# [port_min, port_max] range is covered by a single rule.
# For example:
# n = 6
# port_min = 16 (binary 010000)
# port_max = 23 (binary 010111)
# Ports have m=3 bits in common with the remaining (n-m)=3 bits
# covering range [0, 2^3), which equals to a single 010xxx rule. The algo
# will return [0x0010/fff8].
# Else [port_min, port_max] range will be split into 2: range [port_min, T)
# and [T, port_max]. Let p_m be the common part of port_min and port_max
# with other (n-m) bits set to 0. Then T = p_m | 1 << (n-m-1).
# For example:
# n = 7
# port_min = 40 (binary 0101000)
# port_max = 105 (binary 1101001)
# Ports have m=0 bits in common, p_m=000000. Then T=1000000 and the
# initial range [40, 105] is divided into [40, 64) and [64, 105].
# Each of the ranges will be processed separately, then the generated rules
# will be merged.
# Check port_max >= port_min.
if port_max < port_min:
raise ValueError(_("'port_max' is smaller than 'port_min'"))
# Rules to be added to OVS.
bitdiff = port_min ^ port_max
if bitdiff == 0:
# port_min == port_max
return [_hex_format(port_min)]
# for python3.x, bit_length could be used here
top_bit = 1
while top_bit <= bitdiff:
top_bit <<= 1
if (port_min & (top_bit - 1) == 0 and
port_max & (top_bit - 1) == top_bit - 1):
# special case, range of 2^k ports is covered
return [_hex_format(port_min, top_bit - 1)]
top_bit >>= 1
rules = []
# Loop from the lower part. Increment port_min.
bit_right = 1
mask = MAX_UINT16
t_port_min = port_min
while True:
# Obtain last significative bit.
bit_min = port_min & bit_right
# Take care of first bit.
if bit_right == 1:
if bit_min > 0:
rules.append("%s" % (_hex_format(t_port_min)))
else:
mask = _create_mask(2)
rules.append("%s/%s" % (_hex_format(t_port_min & mask),
_hex_format(mask)))
elif bit_min == 0:
mask = _create_mask(bit_right)
t_port_min += bit_right
# If the temporal variable we are using exceeds the
# port_max value, exit the loop.
if t_port_min > port_max:
break
rules.append("%s/%s" % (_hex_format(t_port_min & mask),
_hex_format(mask)))
# If the temporal variable we are using exceeds the
# port_max value, exit the loop.
if t_port_min > port_max:
break
bit_right <<= 1
# Loop from the higher part.
bit_position = int(round(math.log(port_max, 2)))
bit_left = 1 << bit_position
mask = MAX_UINT16
mask = _reduce_mask(mask, bit_position)
# Find the most significative bit of port_max, higher
# than the most significative bit of port_min.
while mask < MAX_UINT16:
bit_max = port_max & bit_left
bit_min = port_min & bit_left
if bit_max > bit_min:
# Difference found.
break
# Rotate bit_left to the right and increase mask.
bit_left >>= 1
mask = _increase_mask(mask)
while bit_left > 1:
# Obtain next most significative bit.
bit_left >>= 1
bit_max = port_max & bit_left
if bit_left == 1:
if bit_max == 0:
rules.append("%s" % (_hex_format(port_max)))
else:
mask = _create_mask(2)
rules.append("%s/%s" % (_hex_format(port_max & mask),
_hex_format(mask)))
elif bit_max > 0:
t_port_max = port_max - bit_max
mask = _create_mask(bit_left)
rules.append("%s/%s" % (_hex_format(t_port_max),
_hex_format(mask)))
rules.extend(_gen_rules_port_min(port_min, top_bit))
rules.extend(_gen_rules_port_max(port_max, top_bit))
return rules

View File

@ -432,6 +432,26 @@ class FirewallTestCase(BaseFirewallTestCase):
def test_ingress_tcp_rule(self):
self._test_rule(self.tester.INGRESS, self.tester.TCP)
def test_next_port_closed(self):
# https://bugs.launchpad.net/neutron/+bug/1611991 was caused by wrong
# masking in rules which allow traffic to a port with even port number
port = 42
for direction in (self.tester.EGRESS, self.tester.INGRESS):
sg_rules = [{'ethertype': constants.IPv4,
'direction': direction,
'protocol': constants.PROTO_NAME_TCP,
'source_port_range_min': port,
'source_port_range_max': port}]
self._apply_security_group_rules(self.FAKE_SECURITY_GROUP_ID,
sg_rules)
self.tester.assert_connection(protocol=self.tester.TCP,
direction=direction,
src_port=port)
self.tester.assert_no_connection(protocol=self.tester.TCP,
direction=direction,
src_port=port + 1)
def test_ingress_udp_rule(self):
self._test_rule(self.tester.INGRESS, self.tester.UDP)

View File

@ -13,6 +13,7 @@
# under the License.
import errno
import inspect
import os.path
import re
import sys
@ -24,6 +25,7 @@ from neutron_lib import constants
from neutron_lib import exceptions as exc
from oslo_log import log as logging
import six
import testscenarios
import testtools
from neutron.common import exceptions as n_exc
@ -34,6 +36,8 @@ from neutron.tests import base
from neutron.tests.common import helpers
from neutron.tests.unit import tests
load_tests = testscenarios.load_tests_apply_scenarios
class TestParseMappings(base.BaseTestCase):
def parse(self, mapping_list, unique_values=True, unique_keys=True):
@ -685,43 +689,65 @@ class TestSafeDecodeUtf8(base.BaseTestCase):
class TestPortRuleMasking(base.BaseTestCase):
scenarios = [
('Test 1 (networking-ovs-dpdk)',
{'port_min': 5,
'port_max': 12,
'expected': ['0x0005', '0x0006/0xfffe', '0x0008/0xfffc', '0x000c']}
),
('Test 2 (networking-ovs-dpdk)',
{'port_min': 20,
'port_max': 130,
'expected': ['0x0014/0xfffc', '0x0018/0xfff8',
'0x0020/0xffe0', '0x0040/0xffc0', '0x0080/0xfffe',
'0x0082']}),
('Test 3 (networking-ovs-dpdk)',
{'port_min': 4501,
'port_max': 33057,
'expected': ['0x1195', '0x1196/0xfffe', '0x1198/0xfff8',
'0x11a0/0xffe0', '0x11c0/0xffc0', '0x1200/0xfe00',
'0x1400/0xfc00', '0x1800/0xf800', '0x2000/0xe000',
'0x4000/0xc000', '0x8000/0xff00', '0x8100/0xffe0',
'0x8120/0xfffe']}),
('Test port_max == 2^k-1',
{'port_min': 101,
'port_max': 127,
'expected': ['0x0065', '0x0066/0xfffe', '0x0068/0xfff8',
'0x0070/0xfff0']}),
('Test single even port',
{'port_min': 22,
'port_max': 22,
'expected': ['0x0016']}),
('Test single odd port',
{'port_min': 5001,
'port_max': 5001,
'expected': ['0x1389']}),
('Test full interval',
{'port_min': 0,
'port_max': 7,
'expected': ['0x0000/0xfff8']}),
('Test 2^k interval',
{'port_min': 8,
'port_max': 15,
'expected': ['0x0008/0xfff8']}),
('Test full port range',
{'port_min': 0,
'port_max': 65535,
'expected': ['0x0000/0x0000']}),
('Test bad values',
{'port_min': 12,
'port_max': 5,
'expected': ValueError}),
]
def test_port_rule_masking(self):
compare_rules = lambda x, y: set(x) == set(y) and len(x) == len(y)
# Test 1.
port_min = 5
port_max = 12
expected_rules = ['0x0005', '0x000c', '0x0006/0xfffe',
'0x0008/0xfffc']
rules = utils.port_rule_masking(port_min, port_max)
self.assertTrue(compare_rules(rules, expected_rules))
# Test 2.
port_min = 20
port_max = 130
expected_rules = ['0x0014/0xfffe', '0x0016/0xfffe', '0x0018/0xfff8',
'0x0020/0xffe0', '0x0040/0xffc0', '0x0080/0xfffe',
'0x0082']
rules = utils.port_rule_masking(port_min, port_max)
self.assertEqual(expected_rules, rules)
# Test 3.
port_min = 4501
port_max = 33057
expected_rules = ['0x1195', '0x1196/0xfffe', '0x1198/0xfff8',
'0x11a0/0xffe0', '0x11c0/0xffc0', '0x1200/0xfe00',
'0x1400/0xfc00', '0x1800/0xf800', '0x2000/0xe000',
'0x4000/0xc000', '0x8021/0xff00', '0x8101/0xffe0',
'0x8120/0xfffe']
rules = utils.port_rule_masking(port_min, port_max)
self.assertEqual(expected_rules, rules)
def test_port_rule_masking_min_higher_than_max(self):
port_min = 10
port_max = 5
with testtools.ExpectedException(ValueError):
utils.port_rule_masking(port_min, port_max)
if (inspect.isclass(self.expected)
and issubclass(self.expected, Exception)):
with testtools.ExpectedException(self.expected):
utils.port_rule_masking(self.port_min, self.port_max)
else:
rules = utils.port_rule_masking(self.port_min, self.port_max)
self.assertItemsEqual(self.expected, rules)
class TestAuthenticEUI(base.BaseTestCase):