Enable nftables rules for SR-IOV VIPs

This patch enables setting the nftables rules in Amphora using SR-IOV VIPs.

Change-Id: I554aac422371abafb4bb04e2d0df3fce3fa169d4
This commit is contained in:
Michael Johnson 2024-02-24 23:43:59 +00:00
parent d83999f4ed
commit fc37d8303d
32 changed files with 1121 additions and 111 deletions

View File

@ -0,0 +1,52 @@
# Copyright 2024 Red Hat, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from octavia_lib.common import constants as lib_consts
from octavia.common import constants as consts
# This is a JSON schema validation dictionary
# https://json-schema.org/latest/json-schema-validation.html
SUPPORTED_RULES_SCHEMA = {
'$schema': 'http://json-schema.org/draft-07/schema#',
'title': 'Octavia Amphora NFTables Rules Schema',
'description': 'This schema is used to validate an nftables rules JSON '
'document sent from a controller.',
'type': 'array',
'items': {
'additionalProperties': False,
'properties': {
consts.PROTOCOL: {
'type': 'string',
'description': 'The protocol for the rule. One of: '
'TCP, UDP, VRRP, SCTP',
'enum': list((lib_consts.PROTOCOL_SCTP,
lib_consts.PROTOCOL_TCP,
lib_consts.PROTOCOL_UDP,
consts.VRRP))
},
consts.CIDR: {
'type': ['string', 'null'],
'description': 'The allowed source CIDR.'
},
consts.PORT: {
'type': 'number',
'description': 'The protocol port number.',
'minimum': 1,
'maximum': 65535
}
},
'required': [consts.PROTOCOL, consts.CIDR, consts.PORT]
}
}

View File

@ -16,6 +16,7 @@ import os
import stat import stat
import flask import flask
from jsonschema import validate
from oslo_config import cfg from oslo_config import cfg
from oslo_log import log as logging from oslo_log import log as logging
import webob import webob
@ -29,7 +30,9 @@ from octavia.amphorae.backends.agent.api_server import keepalivedlvs
from octavia.amphorae.backends.agent.api_server import loadbalancer from octavia.amphorae.backends.agent.api_server import loadbalancer
from octavia.amphorae.backends.agent.api_server import osutils from octavia.amphorae.backends.agent.api_server import osutils
from octavia.amphorae.backends.agent.api_server import plug from octavia.amphorae.backends.agent.api_server import plug
from octavia.amphorae.backends.agent.api_server import rules_schema
from octavia.amphorae.backends.agent.api_server import util from octavia.amphorae.backends.agent.api_server import util
from octavia.amphorae.backends.utils import nftable_utils
from octavia.common import constants as consts from octavia.common import constants as consts
@ -137,6 +140,9 @@ class Server(object):
self.app.add_url_rule(rule=PATH_PREFIX + '/interface/<ip_addr>', self.app.add_url_rule(rule=PATH_PREFIX + '/interface/<ip_addr>',
view_func=self.get_interface, view_func=self.get_interface,
methods=['GET']) methods=['GET'])
self.app.add_url_rule(rule=PATH_PREFIX + '/interface/<ip_addr>/rules',
view_func=self.set_interface_rules,
methods=['PUT'])
def upload_haproxy_config(self, amphora_id, lb_id): def upload_haproxy_config(self, amphora_id, lb_id):
return self._loadbalancer.upload_haproxy_config(amphora_id, lb_id) return self._loadbalancer.upload_haproxy_config(amphora_id, lb_id)
@ -257,3 +263,23 @@ class Server(object):
def version_discovery(self): def version_discovery(self):
return webob.Response(json={'api_version': api_server.VERSION}) return webob.Response(json={'api_version': api_server.VERSION})
def set_interface_rules(self, ip_addr):
interface_webob = self._amphora_info.get_interface(ip_addr)
if interface_webob.status_code != 200:
return interface_webob
interface = interface_webob.json['interface']
try:
rules_info = flask.request.get_json()
validate(rules_info, rules_schema.SUPPORTED_RULES_SCHEMA)
except Exception as e:
raise exceptions.BadRequest(
description='Invalid rules information') from e
nftable_utils.write_nftable_vip_rules_file(interface, rules_info)
nftable_utils.load_nftables_file()
return webob.Response(json={'message': 'OK'}, status=200)

View File

@ -210,15 +210,7 @@ class InterfaceController(object):
nftable_utils.write_nftable_vip_rules_file(interface.name, []) nftable_utils.write_nftable_vip_rules_file(interface.name, [])
cmd = [consts.NFT_CMD, '-o', '-f', consts.NFT_VIP_RULES_FILE] nftable_utils.load_nftables_file()
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except Exception as e:
if hasattr(e, 'output'):
LOG.error(e.output)
else:
LOG.error(e)
raise
def up(self, interface): def up(self, interface):
LOG.info("Setting interface %s up", interface.name) LOG.info("Setting interface %s up", interface.name)

View File

@ -13,8 +13,17 @@
# under the License. # under the License.
import os import os
import stat import stat
import subprocess
from octavia_lib.common import constants as lib_consts
from oslo_log import log as logging
from webob import exc
from octavia.amphorae.backends.utils import network_namespace
from octavia.common import constants as consts from octavia.common import constants as consts
from octavia.common import utils
LOG = logging.getLogger(__name__)
def write_nftable_vip_rules_file(interface_name, rules): def write_nftable_vip_rules_file(interface_name, rules):
@ -28,7 +37,17 @@ def write_nftable_vip_rules_file(interface_name, rules):
hook_string = (f' type filter hook ingress device {interface_name} ' hook_string = (f' type filter hook ingress device {interface_name} '
f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n') f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n')
# Check if an existing rules file exists or we if need to create an # Allow ICMP destination unreachable for PMTUD
icmp_string = ' icmp type destination-unreachable accept\n'
# Allow the required neighbor solicitation/discovery PMTUD ICMPV6
icmpv6_string = (' icmpv6 type { nd-neighbor-solicit, '
'nd-router-advert, nd-neighbor-advert, packet-too-big, '
'destination-unreachable } accept\n')
# Allow DHCP responses
dhcp_string = ' udp sport 67 udp dport 68 accept\n'
dhcpv6_string = ' udp sport 547 udp dport 546 accept\n'
# Check if an existing rules file exists or we be need to create an
# "drop all" file with no rules except for VRRP. If it exists, we should # "drop all" file with no rules except for VRRP. If it exists, we should
# not overwrite it here as it could be a reboot unless we were passed new # not overwrite it here as it could be a reboot unless we were passed new
# rules. # rules.
@ -40,15 +59,21 @@ def write_nftable_vip_rules_file(interface_name, rules):
# Clear the existing rules in the kernel # Clear the existing rules in the kernel
# Note: The "nft -f" method is atomic, so clearing the rules will # Note: The "nft -f" method is atomic, so clearing the rules will
# not leave the amphora exposed. # not leave the amphora exposed.
file.write(f'flush chain {consts.NFT_FAMILY} ' # Create and delete the table to not get errors if the table does
f'{consts.NFT_VIP_TABLE} {consts.NFT_VIP_CHAIN}\n') # not exist yet.
file.write(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} '
'{}\n')
file.write(f'delete table {consts.NFT_FAMILY} '
f'{consts.NFT_VIP_TABLE}\n')
file.write(table_string) file.write(table_string)
file.write(chain_string) file.write(chain_string)
file.write(hook_string) file.write(hook_string)
# TODO(johnsom) Add peer ports here consts.HAPROXY_BASE_PEER_PORT file.write(icmp_string)
# and ip protocol 112 for VRRP. Need the peer address file.write(icmpv6_string)
file.write(dhcp_string)
file.write(dhcpv6_string)
for rule in rules: for rule in rules:
file.write(f' {rule}\n') file.write(f' {_build_rule_cmd(rule)}\n')
file.write(' }\n') # close the chain file.write(' }\n') # close the chain
file.write('}\n') # close the table file.write('}\n') # close the table
else: # No existing rules, create the "drop all" base rules else: # No existing rules, create the "drop all" base rules
@ -57,7 +82,44 @@ def write_nftable_vip_rules_file(interface_name, rules):
file.write(table_string) file.write(table_string)
file.write(chain_string) file.write(chain_string)
file.write(hook_string) file.write(hook_string)
# TODO(johnsom) Add peer ports here consts.HAPROXY_BASE_PEER_PORT file.write(icmp_string)
# and ip protocol 112 for VRRP. Need the peer address file.write(icmpv6_string)
file.write(dhcp_string)
file.write(dhcpv6_string)
file.write(' }\n') # close the chain file.write(' }\n') # close the chain
file.write('}\n') # close the table file.write('}\n') # close the table
def _build_rule_cmd(rule):
prefix_saddr = ''
if rule[consts.CIDR] and rule[consts.CIDR] != '0.0.0.0/0':
cidr_ip_version = utils.ip_version(rule[consts.CIDR].split('/')[0])
if cidr_ip_version == 4:
prefix_saddr = f'ip saddr {rule[consts.CIDR]} '
elif cidr_ip_version == 6:
prefix_saddr = f'ip6 saddr {rule[consts.CIDR]} '
else:
raise exc.HTTPBadRequest(explanation='Unknown ip version')
if rule[consts.PROTOCOL] == lib_consts.PROTOCOL_SCTP:
return f'{prefix_saddr}sctp dport {rule[consts.PORT]} accept'
if rule[consts.PROTOCOL] == lib_consts.PROTOCOL_TCP:
return f'{prefix_saddr}tcp dport {rule[consts.PORT]} accept'
if rule[consts.PROTOCOL] == lib_consts.PROTOCOL_UDP:
return f'{prefix_saddr}udp dport {rule[consts.PORT]} accept'
if rule[consts.PROTOCOL] == consts.VRRP:
return f'{prefix_saddr}ip protocol 112 accept'
raise exc.HTTPBadRequest(explanation='Unknown protocol used in rules')
def load_nftables_file():
cmd = [consts.NFT_CMD, '-o', '-f', consts.NFT_VIP_RULES_FILE]
try:
with network_namespace.NetworkNamespace(consts.AMPHORA_NAMESPACE):
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except Exception as e:
if hasattr(e, 'output'):
LOG.error(e.output)
else:
LOG.error(e)
raise

View File

@ -252,6 +252,17 @@ class AmphoraLoadBalancerDriver(object, metaclass=abc.ABCMeta):
:raises TimeOutException: The amphora didn't reply :raises TimeOutException: The amphora didn't reply
""" """
@abc.abstractmethod
def set_interface_rules(self, amphora: db_models.Amphora, ip_address,
rules):
"""Sets interface firewall rules in the amphora
:param amphora: The amphora to query.
:param ip_address: The IP address assigned to the interface the rules
will be applied on.
:param rules: The l1st of allow rules to apply.
"""
class VRRPDriverMixin(object, metaclass=abc.ABCMeta): class VRRPDriverMixin(object, metaclass=abc.ABCMeta):
"""Abstract mixin class for VRRP support in loadbalancer amphorae """Abstract mixin class for VRRP support in loadbalancer amphorae

View File

@ -598,6 +598,24 @@ class HaproxyAmphoraLoadBalancerDriver(
amphora, ip_address, timeout_dict, log_error=False) amphora, ip_address, timeout_dict, log_error=False)
return response_json.get('interface', None) return response_json.get('interface', None)
def set_interface_rules(self, amphora: db_models.Amphora,
ip_address, rules):
"""Sets interface firewall rules in the amphora
:param amphora: The amphora to query.
:param ip_address: The IP address assigned to the interface the rules
will be applied on.
:param rules: The l1st of allow rules to apply.
"""
try:
self._populate_amphora_api_version(amphora)
self.clients[amphora.api_version].set_interface_rules(
amphora, ip_address, rules)
except exc.NotFound as e:
LOG.debug('Amphora %s does not support the set_interface_rules '
'API.', amphora.id)
raise driver_except.AmpDriverNotImplementedError() from e
# Check a custom hostname # Check a custom hostname
class CustomHostNameCheckingAdapter(requests.adapters.HTTPAdapter): class CustomHostNameCheckingAdapter(requests.adapters.HTTPAdapter):
@ -867,3 +885,7 @@ class AmphoraAPIClient1_0(AmphoraAPIClientBase):
def update_agent_config(self, amp, agent_config, timeout_dict=None): def update_agent_config(self, amp, agent_config, timeout_dict=None):
r = self.put(amp, 'config', timeout_dict, data=agent_config) r = self.put(amp, 'config', timeout_dict, data=agent_config)
return exc.check_exception(r) return exc.check_exception(r)
def set_interface_rules(self, amp, ip_address, rules):
r = self.put(amp, f'interface/{ip_address}/rules', json=rules)
return exc.check_exception(r)

View File

@ -218,3 +218,6 @@ class NoopAmphoraLoadBalancerDriver(
def check(self, amphora, timeout_dict=None): def check(self, amphora, timeout_dict=None):
pass pass
def set_interface_rules(self, amphora, ip_address, rules):
pass

View File

@ -308,6 +308,7 @@ AMP_DATA = 'amp_data'
AMP_VRRP_INT = 'amp_vrrp_int' AMP_VRRP_INT = 'amp_vrrp_int'
AMPHORA = 'amphora' AMPHORA = 'amphora'
AMPHORA_DICT = 'amphora_dict' AMPHORA_DICT = 'amphora_dict'
AMPHORA_FIREWALL_RULES = 'amphora_firewall_rules'
AMPHORA_ID = 'amphora_id' AMPHORA_ID = 'amphora_id'
AMPHORA_INDEX = 'amphora_index' AMPHORA_INDEX = 'amphora_index'
AMPHORA_NETWORK_CONFIG = 'amphora_network_config' AMPHORA_NETWORK_CONFIG = 'amphora_network_config'
@ -460,6 +461,7 @@ VIP_VNIC_TYPE = 'vip_vnic_type'
VNIC_TYPE = 'vnic_type' VNIC_TYPE = 'vnic_type'
VNIC_TYPE_DIRECT = 'direct' VNIC_TYPE_DIRECT = 'direct'
VNIC_TYPE_NORMAL = 'normal' VNIC_TYPE_NORMAL = 'normal'
VRRP = 'vrrp'
VRRP_ID = 'vrrp_id' VRRP_ID = 'vrrp_id'
VRRP_IP = 'vrrp_ip' VRRP_IP = 'vrrp_ip'
VRRP_GROUP = 'vrrp_group' VRRP_GROUP = 'vrrp_group'
@ -468,6 +470,7 @@ VRRP_PORT_ID = 'vrrp_port_id'
VRRP_PRIORITY = 'vrrp_priority' VRRP_PRIORITY = 'vrrp_priority'
# Taskflow flow and task names # Taskflow flow and task names
AMP_UPDATE_FW_SUBFLOW = 'amphora-update-firewall-subflow'
CERT_ROTATE_AMPHORA_FLOW = 'octavia-cert-rotate-amphora-flow' CERT_ROTATE_AMPHORA_FLOW = 'octavia-cert-rotate-amphora-flow'
CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow' CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow'
CREATE_AMPHORA_RETRY_SUBFLOW = 'octavia-create-amphora-retry-subflow' CREATE_AMPHORA_RETRY_SUBFLOW = 'octavia-create-amphora-retry-subflow'
@ -496,6 +499,7 @@ DELETE_L7RULE_FLOW = 'octavia-delete-l7policy-flow'
FAILOVER_AMPHORA_FLOW = 'octavia-failover-amphora-flow' FAILOVER_AMPHORA_FLOW = 'octavia-failover-amphora-flow'
FAILOVER_LOADBALANCER_FLOW = 'octavia-failover-loadbalancer-flow' FAILOVER_LOADBALANCER_FLOW = 'octavia-failover-loadbalancer-flow'
FINALIZE_AMPHORA_FLOW = 'octavia-finalize-amphora-flow' FINALIZE_AMPHORA_FLOW = 'octavia-finalize-amphora-flow'
FIREWALL_RULES_SUBFLOW = 'firewall-rules-subflow'
LOADBALANCER_NETWORKING_SUBFLOW = 'octavia-new-loadbalancer-net-subflow' LOADBALANCER_NETWORKING_SUBFLOW = 'octavia-new-loadbalancer-net-subflow'
UPDATE_HEALTH_MONITOR_FLOW = 'octavia-update-health-monitor-flow' UPDATE_HEALTH_MONITOR_FLOW = 'octavia-update-health-monitor-flow'
UPDATE_LISTENER_FLOW = 'octavia-update-listener-flow' UPDATE_LISTENER_FLOW = 'octavia-update-listener-flow'
@ -583,6 +587,7 @@ CREATE_VIP_BASE_PORT = 'create-vip-base-port'
DELETE_AMPHORA = 'delete-amphora' DELETE_AMPHORA = 'delete-amphora'
DELETE_PORT = 'delete-port' DELETE_PORT = 'delete-port'
DISABLE_AMP_HEALTH_MONITORING = 'disable-amphora-health-monitoring' DISABLE_AMP_HEALTH_MONITORING = 'disable-amphora-health-monitoring'
GET_AMPHORA_FIREWALL_RULES = 'get-amphora-firewall-rules'
GET_AMPHORA_NETWORK_CONFIGS_BY_ID = 'get-amphora-network-configs-by-id' GET_AMPHORA_NETWORK_CONFIGS_BY_ID = 'get-amphora-network-configs-by-id'
GET_AMPHORAE_FROM_LB = 'get-amphorae-from-lb' GET_AMPHORAE_FROM_LB = 'get-amphorae-from-lb'
GET_SUBNET_FROM_VIP = 'get-subnet-from-vip' GET_SUBNET_FROM_VIP = 'get-subnet-from-vip'
@ -595,6 +600,7 @@ RELOAD_LB_AFTER_AMP_ASSOC = 'reload-lb-after-amp-assoc'
RELOAD_LB_AFTER_AMP_ASSOC_FULL_GRAPH = 'reload-lb-after-amp-assoc-full-graph' RELOAD_LB_AFTER_AMP_ASSOC_FULL_GRAPH = 'reload-lb-after-amp-assoc-full-graph'
RELOAD_LB_AFTER_PLUG_VIP = 'reload-lb-after-plug-vip' RELOAD_LB_AFTER_PLUG_VIP = 'reload-lb-after-plug-vip'
RELOAD_LB_BEFOR_ALLOCATE_VIP = 'reload-lb-before-allocate-vip' RELOAD_LB_BEFOR_ALLOCATE_VIP = 'reload-lb-before-allocate-vip'
SET_AMPHORA_FIREWALL_RULES = 'set-amphora-firewall-rules'
UPDATE_AMP_FAILOVER_DETAILS = 'update-amp-failover-details' UPDATE_AMP_FAILOVER_DETAILS = 'update-amp-failover-details'
@ -974,6 +980,7 @@ NFT_ADD = 'add'
NFT_CMD = '/usr/sbin/nft' NFT_CMD = '/usr/sbin/nft'
NFT_FAMILY = 'inet' NFT_FAMILY = 'inet'
NFT_VIP_RULES_FILE = '/var/lib/octavia/nftables-vip.rules' NFT_VIP_RULES_FILE = '/var/lib/octavia/nftables-vip.rules'
NFT_VIP_TABLE = 'amphora-vip' NFT_VIP_TABLE = 'amphora_vip'
NFT_VIP_CHAIN = 'amphora-vip-chain' NFT_VIP_CHAIN = 'amphora_vip_chain'
NFT_SRIOV_PRIORITY = '-310' NFT_SRIOV_PRIORITY = '-310'
PROTOCOL = 'protocol'

View File

@ -191,3 +191,9 @@ class exception_logger(object):
self.logger(e) self.logger(e)
return None return None
return call return call
def map_protocol_to_nftable_protocol(rule_dict):
rule_dict[constants.PROTOCOL] = (
constants.L4_PROTOCOL_MAP[rule_dict[constants.PROTOCOL]])
return rule_dict

View File

@ -270,6 +270,14 @@ class ControllerWorker(object):
raise db_exceptions.NoResultFound raise db_exceptions.NoResultFound
load_balancer = db_listener.load_balancer load_balancer = db_listener.load_balancer
flavor_dict = {}
if load_balancer.flavor_id:
with session.begin():
flavor_dict = (
self._flavor_repo.get_flavor_metadata_dict(
session, load_balancer.flavor_id))
flavor_dict[constants.LOADBALANCER_TOPOLOGY] = load_balancer.topology
provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer( provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer(
load_balancer).to_dict(recurse=True) load_balancer).to_dict(recurse=True)
@ -279,7 +287,7 @@ class ControllerWorker(object):
self.run_flow( self.run_flow(
flow_utils.get_create_listener_flow, flow_utils.get_create_listener_flow,
store=store) flavor_dict=flavor_dict, store=store)
def delete_listener(self, listener): def delete_listener(self, listener):
"""Deletes a listener. """Deletes a listener.
@ -288,12 +296,32 @@ class ControllerWorker(object):
:returns: None :returns: None
:raises ListenerNotFound: The referenced listener was not found :raises ListenerNotFound: The referenced listener was not found
""" """
try:
db_lb = self._get_db_obj_until_pending_update(
self._lb_repo, listener[constants.LOADBALANCER_ID])
except tenacity.RetryError as e:
LOG.warning('Loadbalancer did not go into %s in 60 seconds. '
'This either due to an in-progress Octavia upgrade '
'or an overloaded and failing database. Assuming '
'an upgrade is in progress and continuing.',
constants.PENDING_UPDATE)
db_lb = e.last_attempt.result()
flavor_dict = {}
if db_lb.flavor_id:
session = db_apis.get_session()
with session.begin():
flavor_dict = (
self._flavor_repo.get_flavor_metadata_dict(
session, db_lb.flavor_id))
flavor_dict[constants.LOADBALANCER_TOPOLOGY] = db_lb.topology
store = {constants.LISTENER: listener, store = {constants.LISTENER: listener,
constants.LOADBALANCER_ID: constants.LOADBALANCER_ID:
listener[constants.LOADBALANCER_ID], listener[constants.LOADBALANCER_ID],
constants.PROJECT_ID: listener[constants.PROJECT_ID]} constants.PROJECT_ID: listener[constants.PROJECT_ID]}
self.run_flow( self.run_flow(
flow_utils.get_delete_listener_flow, flow_utils.get_delete_listener_flow, flavor_dict=flavor_dict,
store=store) store=store)
def update_listener(self, listener, listener_updates): def update_listener(self, listener, listener_updates):
@ -315,12 +343,21 @@ class ControllerWorker(object):
constants.PENDING_UPDATE) constants.PENDING_UPDATE)
db_lb = e.last_attempt.result() db_lb = e.last_attempt.result()
session = db_apis.get_session()
flavor_dict = {}
if db_lb.flavor_id:
with session.begin():
flavor_dict = (
self._flavor_repo.get_flavor_metadata_dict(
session, db_lb.flavor_id))
flavor_dict[constants.LOADBALANCER_TOPOLOGY] = db_lb.topology
store = {constants.LISTENER: listener, store = {constants.LISTENER: listener,
constants.UPDATE_DICT: listener_updates, constants.UPDATE_DICT: listener_updates,
constants.LOADBALANCER_ID: db_lb.id, constants.LOADBALANCER_ID: db_lb.id,
constants.LISTENERS: [listener]} constants.LISTENERS: [listener]}
self.run_flow( self.run_flow(
flow_utils.get_update_listener_flow, flow_utils.get_update_listener_flow, flavor_dict=flavor_dict,
store=store) store=store)
@tenacity.retry( @tenacity.retry(
@ -998,6 +1035,7 @@ class ControllerWorker(object):
lb_id = loadbalancer.id lb_id = loadbalancer.id
# Even if the LB doesn't have a flavor, create one and # Even if the LB doesn't have a flavor, create one and
# pass through the topology. # pass through the topology.
flavor_dict = {}
if loadbalancer.flavor_id: if loadbalancer.flavor_id:
with session.begin(): with session.begin():
flavor_dict = ( flavor_dict = (
@ -1005,9 +1043,6 @@ class ControllerWorker(object):
session, loadbalancer.flavor_id)) session, loadbalancer.flavor_id))
flavor_dict[constants.LOADBALANCER_TOPOLOGY] = ( flavor_dict[constants.LOADBALANCER_TOPOLOGY] = (
loadbalancer.topology) loadbalancer.topology)
else:
flavor_dict = {constants.LOADBALANCER_TOPOLOGY:
loadbalancer.topology}
if loadbalancer.availability_zone: if loadbalancer.availability_zone:
with session.begin(): with session.begin():
az_metadata = ( az_metadata = (
@ -1162,13 +1197,12 @@ class ControllerWorker(object):
# We must provide a topology in the flavor definition # We must provide a topology in the flavor definition
# here for the amphora to be created with the correct # here for the amphora to be created with the correct
# configuration. # configuration.
flavor = {}
if lb.flavor_id: if lb.flavor_id:
with session.begin(): with session.begin():
flavor = self._flavor_repo.get_flavor_metadata_dict( flavor = self._flavor_repo.get_flavor_metadata_dict(
session, lb.flavor_id) session, lb.flavor_id)
flavor[constants.LOADBALANCER_TOPOLOGY] = lb.topology flavor[constants.LOADBALANCER_TOPOLOGY] = lb.topology
else:
flavor = {constants.LOADBALANCER_TOPOLOGY: lb.topology}
if lb: if lb:
provider_lb_dict = ( provider_lb_dict = (

View File

@ -28,6 +28,7 @@ from octavia.controller.worker.v2.tasks import database_tasks
from octavia.controller.worker.v2.tasks import lifecycle_tasks from octavia.controller.worker.v2.tasks import lifecycle_tasks
from octavia.controller.worker.v2.tasks import network_tasks from octavia.controller.worker.v2.tasks import network_tasks
from octavia.controller.worker.v2.tasks import retry_tasks from octavia.controller.worker.v2.tasks import retry_tasks
from octavia.controller.worker.v2.tasks import shim_tasks
CONF = cfg.CONF CONF = cfg.CONF
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -227,7 +228,7 @@ class AmphoraFlows(object):
def get_vrrp_subflow(self, prefix, timeout_dict=None, def get_vrrp_subflow(self, prefix, timeout_dict=None,
create_vrrp_group=True, create_vrrp_group=True,
get_amphorae_status=True): get_amphorae_status=True, flavor_dict=None):
sf_name = prefix + '-' + constants.GET_VRRP_SUBFLOW sf_name = prefix + '-' + constants.GET_VRRP_SUBFLOW
vrrp_subflow = linear_flow.Flow(sf_name) vrrp_subflow = linear_flow.Flow(sf_name)
@ -259,7 +260,7 @@ class AmphoraFlows(object):
# unordered subflow. # unordered subflow.
update_amps_subflow = unordered_flow.Flow('VRRP-update-subflow') update_amps_subflow = unordered_flow.Flow('VRRP-update-subflow')
# We have three tasks to run in order, per amphora # We have tasks to run in order, per amphora
amp_0_subflow = linear_flow.Flow('VRRP-amp-0-update-subflow') amp_0_subflow = linear_flow.Flow('VRRP-amp-0-update-subflow')
amp_0_subflow.add(amphora_driver_tasks.AmphoraIndexUpdateVRRPInterface( amp_0_subflow.add(amphora_driver_tasks.AmphoraIndexUpdateVRRPInterface(
@ -279,6 +280,20 @@ class AmphoraFlows(object):
inject={constants.AMPHORA_INDEX: 0, inject={constants.AMPHORA_INDEX: 0,
constants.TIMEOUT_DICT: timeout_dict})) constants.TIMEOUT_DICT: timeout_dict}))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
amp_0_subflow.add(database_tasks.GetAmphoraFirewallRules(
name=sf_name + '-0-' + constants.GET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORAE_NETWORK_CONFIG),
provides=constants.AMPHORA_FIREWALL_RULES,
inject={constants.AMPHORA_INDEX: 0}))
amp_0_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules(
name=sf_name + '-0-' + constants.SET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORA_FIREWALL_RULES),
inject={constants.AMPHORA_INDEX: 0}))
amp_0_subflow.add(amphora_driver_tasks.AmphoraIndexVRRPStart( amp_0_subflow.add(amphora_driver_tasks.AmphoraIndexVRRPStart(
name=sf_name + '-0-' + constants.AMP_VRRP_START, name=sf_name + '-0-' + constants.AMP_VRRP_START,
requires=(constants.AMPHORAE, constants.AMPHORAE_STATUS), requires=(constants.AMPHORAE, constants.AMPHORAE_STATUS),
@ -304,6 +319,21 @@ class AmphoraFlows(object):
rebind={constants.NEW_AMPHORA_ID: constants.AMPHORA_ID}, rebind={constants.NEW_AMPHORA_ID: constants.AMPHORA_ID},
inject={constants.AMPHORA_INDEX: 1, inject={constants.AMPHORA_INDEX: 1,
constants.TIMEOUT_DICT: timeout_dict})) constants.TIMEOUT_DICT: timeout_dict}))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
amp_1_subflow.add(database_tasks.GetAmphoraFirewallRules(
name=sf_name + '-1-' + constants.GET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORAE_NETWORK_CONFIG),
provides=constants.AMPHORA_FIREWALL_RULES,
inject={constants.AMPHORA_INDEX: 1}))
amp_1_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules(
name=sf_name + '-1-' + constants.SET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORA_FIREWALL_RULES),
inject={constants.AMPHORA_INDEX: 1}))
amp_1_subflow.add(amphora_driver_tasks.AmphoraIndexVRRPStart( amp_1_subflow.add(amphora_driver_tasks.AmphoraIndexVRRPStart(
name=sf_name + '-1-' + constants.AMP_VRRP_START, name=sf_name + '-1-' + constants.AMP_VRRP_START,
requires=(constants.AMPHORAE, constants.AMPHORAE_STATUS), requires=(constants.AMPHORAE, constants.AMPHORAE_STATUS),
@ -443,6 +473,27 @@ class AmphoraFlows(object):
requires=(constants.AMPHORA, constants.LOADBALANCER, requires=(constants.AMPHORA, constants.LOADBALANCER,
constants.AMPHORAE_NETWORK_CONFIG))) constants.AMPHORAE_NETWORK_CONFIG)))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
amp_for_failover_flow.add(
shim_tasks.AmphoraToAmphoraeWithVRRPIP(
name=prefix + '-' + constants.AMPHORA_TO_AMPHORAE_VRRP_IP,
requires=(constants.AMPHORA, constants.BASE_PORT),
provides=constants.NEW_AMPHORAE))
amp_for_failover_flow.add(database_tasks.GetAmphoraFirewallRules(
name=prefix + '-' + constants.GET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORAE_NETWORK_CONFIG),
rebind={constants.AMPHORAE: constants.NEW_AMPHORAE},
provides=constants.AMPHORA_FIREWALL_RULES,
inject={constants.AMPHORA_INDEX: 0}))
amp_for_failover_flow.add(
amphora_driver_tasks.SetAmphoraFirewallRules(
name=prefix + '-' + constants.SET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORA_FIREWALL_RULES),
rebind={constants.AMPHORAE: constants.NEW_AMPHORAE},
inject={constants.AMPHORA_INDEX: 0}))
# Plug member ports # Plug member ports
amp_for_failover_flow.add(network_tasks.CalculateAmphoraDelta( amp_for_failover_flow.add(network_tasks.CalculateAmphoraDelta(
name=prefix + '-' + constants.CALCULATE_AMPHORA_DELTA, name=prefix + '-' + constants.CALCULATE_AMPHORA_DELTA,
@ -601,7 +652,8 @@ class AmphoraFlows(object):
failover_amp_flow.add( failover_amp_flow.add(
self.get_vrrp_subflow(constants.GET_VRRP_SUBFLOW, self.get_vrrp_subflow(constants.GET_VRRP_SUBFLOW,
timeout_dict, create_vrrp_group=False, timeout_dict, create_vrrp_group=False,
get_amphorae_status=False)) get_amphorae_status=False,
flavor_dict=flavor_dict))
# Reload the listener. This needs to be done here because # Reload the listener. This needs to be done here because
# it will create the required haproxy check scripts for # it will create the required haproxy check scripts for

View File

@ -139,20 +139,21 @@ def get_update_l7rule_flow():
return L7_RULES_FLOWS.get_update_l7rule_flow() return L7_RULES_FLOWS.get_update_l7rule_flow()
def get_create_listener_flow(): def get_create_listener_flow(flavor_dict=None):
return LISTENER_FLOWS.get_create_listener_flow() return LISTENER_FLOWS.get_create_listener_flow(flavor_dict=flavor_dict)
def get_create_all_listeners_flow(): def get_create_all_listeners_flow(flavor_dict=None):
return LISTENER_FLOWS.get_create_all_listeners_flow() return LISTENER_FLOWS.get_create_all_listeners_flow(
flavor_dict=flavor_dict)
def get_delete_listener_flow(): def get_delete_listener_flow(flavor_dict=None):
return LISTENER_FLOWS.get_delete_listener_flow() return LISTENER_FLOWS.get_delete_listener_flow(flavor_dict=flavor_dict)
def get_update_listener_flow(): def get_update_listener_flow(flavor_dict=None):
return LISTENER_FLOWS.get_update_listener_flow() return LISTENER_FLOWS.get_update_listener_flow(flavor_dict=flavor_dict)
def get_create_member_flow(): def get_create_member_flow():

View File

@ -14,6 +14,7 @@
# #
from taskflow.patterns import linear_flow from taskflow.patterns import linear_flow
from taskflow.patterns import unordered_flow
from octavia.common import constants from octavia.common import constants
from octavia.controller.worker.v2.tasks import amphora_driver_tasks from octavia.controller.worker.v2.tasks import amphora_driver_tasks
@ -24,7 +25,7 @@ from octavia.controller.worker.v2.tasks import network_tasks
class ListenerFlows(object): class ListenerFlows(object):
def get_create_listener_flow(self): def get_create_listener_flow(self, flavor_dict=None):
"""Create a flow to create a listener """Create a flow to create a listener
:returns: The flow for creating a listener :returns: The flow for creating a listener
@ -36,13 +37,18 @@ class ListenerFlows(object):
requires=constants.LOADBALANCER_ID)) requires=constants.LOADBALANCER_ID))
create_listener_flow.add(network_tasks.UpdateVIP( create_listener_flow.add(network_tasks.UpdateVIP(
requires=constants.LISTENERS)) requires=constants.LISTENERS))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
create_listener_flow.add(*self._get_firewall_rules_subflow(
flavor_dict))
create_listener_flow.add(database_tasks. create_listener_flow.add(database_tasks.
MarkLBAndListenersActiveInDB( MarkLBAndListenersActiveInDB(
requires=(constants.LOADBALANCER_ID, requires=(constants.LOADBALANCER_ID,
constants.LISTENERS))) constants.LISTENERS)))
return create_listener_flow return create_listener_flow
def get_create_all_listeners_flow(self): def get_create_all_listeners_flow(self, flavor_dict=None):
"""Create a flow to create all listeners """Create a flow to create all listeners
:returns: The flow for creating all listeners :returns: The flow for creating all listeners
@ -60,12 +66,17 @@ class ListenerFlows(object):
requires=constants.LOADBALANCER_ID)) requires=constants.LOADBALANCER_ID))
create_all_listeners_flow.add(network_tasks.UpdateVIP( create_all_listeners_flow.add(network_tasks.UpdateVIP(
requires=constants.LISTENERS)) requires=constants.LISTENERS))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
create_all_listeners_flow.add(*self._get_firewall_rules_subflow(
flavor_dict))
create_all_listeners_flow.add( create_all_listeners_flow.add(
database_tasks.MarkHealthMonitorsOnlineInDB( database_tasks.MarkHealthMonitorsOnlineInDB(
requires=constants.LOADBALANCER)) requires=constants.LOADBALANCER))
return create_all_listeners_flow return create_all_listeners_flow
def get_delete_listener_flow(self): def get_delete_listener_flow(self, flavor_dict=None):
"""Create a flow to delete a listener """Create a flow to delete a listener
:returns: The flow for deleting a listener :returns: The flow for deleting a listener
@ -79,6 +90,11 @@ class ListenerFlows(object):
requires=constants.LOADBALANCER_ID)) requires=constants.LOADBALANCER_ID))
delete_listener_flow.add(database_tasks.DeleteListenerInDB( delete_listener_flow.add(database_tasks.DeleteListenerInDB(
requires=constants.LISTENER)) requires=constants.LISTENER))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
delete_listener_flow.add(*self._get_firewall_rules_subflow(
flavor_dict))
delete_listener_flow.add(database_tasks.DecrementListenerQuota( delete_listener_flow.add(database_tasks.DecrementListenerQuota(
requires=constants.PROJECT_ID)) requires=constants.PROJECT_ID))
delete_listener_flow.add(database_tasks.MarkLBActiveInDBByListener( delete_listener_flow.add(database_tasks.MarkLBActiveInDBByListener(
@ -86,7 +102,7 @@ class ListenerFlows(object):
return delete_listener_flow return delete_listener_flow
def get_delete_listener_internal_flow(self, listener): def get_delete_listener_internal_flow(self, listener, flavor_dict=None):
"""Create a flow to delete a listener and l7policies internally """Create a flow to delete a listener and l7policies internally
(will skip deletion on the amp and marking LB active) (will skip deletion on the amp and marking LB active)
@ -104,13 +120,22 @@ class ListenerFlows(object):
name='delete_listener_in_db_' + listener_id, name='delete_listener_in_db_' + listener_id,
requires=constants.LISTENER, requires=constants.LISTENER,
inject={constants.LISTENER: listener})) inject={constants.LISTENER: listener}))
# Currently the flavor_dict will always be None since there is
# no point updating the firewall rules when deleting the LB.
# However, this may be used for additional flows in the future, so
# adding this code for completeness.
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
delete_listener_flow.add(*self._get_firewall_rules_subflow(
flavor_dict))
delete_listener_flow.add(database_tasks.DecrementListenerQuota( delete_listener_flow.add(database_tasks.DecrementListenerQuota(
name='decrement_listener_quota_' + listener_id, name='decrement_listener_quota_' + listener_id,
requires=constants.PROJECT_ID)) requires=constants.PROJECT_ID))
return delete_listener_flow return delete_listener_flow
def get_update_listener_flow(self): def get_update_listener_flow(self, flavor_dict=None):
"""Create a flow to update a listener """Create a flow to update a listener
:returns: The flow for updating a listener :returns: The flow for updating a listener
@ -122,6 +147,11 @@ class ListenerFlows(object):
requires=constants.LOADBALANCER_ID)) requires=constants.LOADBALANCER_ID))
update_listener_flow.add(network_tasks.UpdateVIP( update_listener_flow.add(network_tasks.UpdateVIP(
requires=constants.LISTENERS)) requires=constants.LISTENERS))
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
update_listener_flow.add(*self._get_firewall_rules_subflow(
flavor_dict))
update_listener_flow.add(database_tasks.UpdateListenerInDB( update_listener_flow.add(database_tasks.UpdateListenerInDB(
requires=[constants.LISTENER, constants.UPDATE_DICT])) requires=[constants.LISTENER, constants.UPDATE_DICT]))
update_listener_flow.add(database_tasks. update_listener_flow.add(database_tasks.
@ -130,3 +160,63 @@ class ListenerFlows(object):
constants.LISTENERS))) constants.LISTENERS)))
return update_listener_flow return update_listener_flow
def _get_firewall_rules_subflow(self, flavor_dict):
"""Creates a subflow that updates the firewall rules in the amphorae.
:returns: The subflow for updating firewall rules in the amphorae.
"""
sf_name = constants.FIREWALL_RULES_SUBFLOW
fw_rules_subflow = linear_flow.Flow(sf_name)
fw_rules_subflow.add(database_tasks.GetAmphoraeFromLoadbalancer(
name=sf_name + '-' + constants.GET_AMPHORAE_FROM_LB,
requires=constants.LOADBALANCER_ID,
provides=constants.AMPHORAE))
fw_rules_subflow.add(network_tasks.GetAmphoraeNetworkConfigs(
name=sf_name + '-' + constants.GET_AMP_NETWORK_CONFIG,
requires=constants.LOADBALANCER_ID,
provides=constants.AMPHORAE_NETWORK_CONFIG))
update_amps_subflow = unordered_flow.Flow(
constants.AMP_UPDATE_FW_SUBFLOW)
amp_0_subflow = linear_flow.Flow('amp-0-fw-update')
amp_0_subflow.add(database_tasks.GetAmphoraFirewallRules(
name=sf_name + '-0-' + constants.GET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE, constants.AMPHORAE_NETWORK_CONFIG),
provides=constants.AMPHORA_FIREWALL_RULES,
inject={constants.AMPHORA_INDEX: 0}))
amp_0_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules(
name=sf_name + '-0-' + constants.SET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE, constants.AMPHORA_FIREWALL_RULES),
inject={constants.AMPHORA_INDEX: 0}))
update_amps_subflow.add(amp_0_subflow)
if (flavor_dict[constants.LOADBALANCER_TOPOLOGY] ==
constants.TOPOLOGY_ACTIVE_STANDBY):
amp_1_subflow = linear_flow.Flow('amp-1-fw-update')
amp_1_subflow.add(database_tasks.GetAmphoraFirewallRules(
name=sf_name + '-1-' + constants.GET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORAE_NETWORK_CONFIG),
provides=constants.AMPHORA_FIREWALL_RULES,
inject={constants.AMPHORA_INDEX: 1}))
amp_1_subflow.add(amphora_driver_tasks.SetAmphoraFirewallRules(
name=sf_name + '-1-' + constants.SET_AMPHORA_FIREWALL_RULES,
requires=(constants.AMPHORAE,
constants.AMPHORA_FIREWALL_RULES),
inject={constants.AMPHORA_INDEX: 1}))
update_amps_subflow.add(amp_1_subflow)
fw_rules_subflow.add(update_amps_subflow)
return fw_rules_subflow

View File

@ -93,10 +93,12 @@ class LoadBalancerFlows(object):
post_amp_prefix = constants.POST_LB_AMP_ASSOCIATION_SUBFLOW post_amp_prefix = constants.POST_LB_AMP_ASSOCIATION_SUBFLOW
lb_create_flow.add( lb_create_flow.add(
self.get_post_lb_amp_association_flow(post_amp_prefix, topology)) self.get_post_lb_amp_association_flow(post_amp_prefix, topology,
flavor_dict=flavor_dict))
if listeners: if listeners:
lb_create_flow.add(*self._create_listeners_flow()) lb_create_flow.add(
*self._create_listeners_flow(flavor_dict=flavor_dict))
lb_create_flow.add( lb_create_flow.add(
database_tasks.MarkLBActiveInDB( database_tasks.MarkLBActiveInDB(
@ -177,6 +179,7 @@ class LoadBalancerFlows(object):
def _get_amp_net_subflow(self, sf_name, flavor_dict=None): def _get_amp_net_subflow(self, sf_name, flavor_dict=None):
flows = [] flows = []
# If we have an SRIOV VIP, we need to setup a firewall in the amp
if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False): if flavor_dict and flavor_dict.get(constants.SRIOV_VIP, False):
flows.append(network_tasks.CreateSRIOVBasePort( flows.append(network_tasks.CreateSRIOVBasePort(
name=sf_name + '-' + constants.PLUG_VIP_AMPHORA, name=sf_name + '-' + constants.PLUG_VIP_AMPHORA,
@ -192,7 +195,25 @@ class LoadBalancerFlows(object):
requires=(constants.LOADBALANCER, constants.AMPHORA, requires=(constants.LOADBALANCER, constants.AMPHORA,
constants.PORT_DATA), constants.PORT_DATA),
provides=constants.AMP_DATA)) provides=constants.AMP_DATA))
# TODO(johnsom) nftables need to be handled here in the SG patch flows.append(network_tasks.ApplyQosAmphora(
name=sf_name + '-' + constants.APPLY_QOS_AMP,
requires=(constants.LOADBALANCER, constants.AMP_DATA,
constants.UPDATE_DICT)))
flows.append(database_tasks.UpdateAmphoraVIPData(
name=sf_name + '-' + constants.UPDATE_AMPHORA_VIP_DATA,
requires=constants.AMP_DATA))
flows.append(network_tasks.GetAmphoraNetworkConfigs(
name=sf_name + '-' + constants.GET_AMP_NETWORK_CONFIG,
requires=(constants.LOADBALANCER, constants.AMPHORA),
provides=constants.AMPHORA_NETWORK_CONFIG))
# SR-IOV firewall rules are handled in AmphoraPostVIPPlug
# interface.py up
flows.append(amphora_driver_tasks.AmphoraPostVIPPlug(
name=sf_name + '-' + constants.AMP_POST_VIP_PLUG,
rebind={constants.AMPHORAE_NETWORK_CONFIG:
constants.AMPHORA_NETWORK_CONFIG},
requires=(constants.LOADBALANCER,
constants.AMPHORAE_NETWORK_CONFIG)))
else: else:
flows.append(network_tasks.PlugVIPAmphora( flows.append(network_tasks.PlugVIPAmphora(
name=sf_name + '-' + constants.PLUG_VIP_AMPHORA, name=sf_name + '-' + constants.PLUG_VIP_AMPHORA,
@ -219,7 +240,7 @@ class LoadBalancerFlows(object):
constants.AMPHORAE_NETWORK_CONFIG))) constants.AMPHORAE_NETWORK_CONFIG)))
return flows return flows
def _create_listeners_flow(self): def _create_listeners_flow(self, flavor_dict=None):
flows = [] flows = []
flows.append( flows.append(
database_tasks.ReloadLoadBalancer( database_tasks.ReloadLoadBalancer(
@ -252,11 +273,13 @@ class LoadBalancerFlows(object):
) )
) )
flows.append( flows.append(
self.listener_flows.get_create_all_listeners_flow() self.listener_flows.get_create_all_listeners_flow(
flavor_dict=flavor_dict)
) )
return flows return flows
def get_post_lb_amp_association_flow(self, prefix, topology): def get_post_lb_amp_association_flow(self, prefix, topology,
flavor_dict=None):
"""Reload the loadbalancer and create networking subflows for """Reload the loadbalancer and create networking subflows for
created/allocated amphorae. created/allocated amphorae.
@ -274,14 +297,15 @@ class LoadBalancerFlows(object):
post_create_LB_flow.add(database_tasks.GetAmphoraeFromLoadbalancer( post_create_LB_flow.add(database_tasks.GetAmphoraeFromLoadbalancer(
requires=constants.LOADBALANCER_ID, requires=constants.LOADBALANCER_ID,
provides=constants.AMPHORAE)) provides=constants.AMPHORAE))
vrrp_subflow = self.amp_flows.get_vrrp_subflow(prefix) vrrp_subflow = self.amp_flows.get_vrrp_subflow(
prefix, flavor_dict=flavor_dict)
post_create_LB_flow.add(vrrp_subflow) post_create_LB_flow.add(vrrp_subflow)
post_create_LB_flow.add(database_tasks.UpdateLoadbalancerInDB( post_create_LB_flow.add(database_tasks.UpdateLoadbalancerInDB(
requires=[constants.LOADBALANCER, constants.UPDATE_DICT])) requires=[constants.LOADBALANCER, constants.UPDATE_DICT]))
return post_create_LB_flow return post_create_LB_flow
def _get_delete_listeners_flow(self, listeners): def _get_delete_listeners_flow(self, listeners, flavor_dict=None):
"""Sets up an internal delete flow """Sets up an internal delete flow
:param listeners: A list of listener dicts :param listeners: A list of listener dicts
@ -291,7 +315,7 @@ class LoadBalancerFlows(object):
for listener in listeners: for listener in listeners:
listeners_delete_flow.add( listeners_delete_flow.add(
self.listener_flows.get_delete_listener_internal_flow( self.listener_flows.get_delete_listener_internal_flow(
listener)) listener, flavor_dict=flavor_dict))
return listeners_delete_flow return listeners_delete_flow
def get_delete_load_balancer_flow(self, lb): def get_delete_load_balancer_flow(self, lb):
@ -705,7 +729,7 @@ class LoadBalancerFlows(object):
failover_LB_flow.add(self.amp_flows.get_vrrp_subflow( failover_LB_flow.add(self.amp_flows.get_vrrp_subflow(
new_amp_role + '-' + constants.GET_VRRP_SUBFLOW, new_amp_role + '-' + constants.GET_VRRP_SUBFLOW,
timeout_dict, create_vrrp_group=False, timeout_dict, create_vrrp_group=False,
get_amphorae_status=False)) get_amphorae_status=False, flavor_dict=lb[constants.FLAVOR]))
# #### End of standby #### # #### End of standby ####

View File

@ -760,3 +760,26 @@ class AmphoraeGetConnectivityStatus(BaseAmphoraTask):
amphorae_status[amphora_id][constants.UNREACHABLE] = False amphorae_status[amphora_id][constants.UNREACHABLE] = False
return amphorae_status return amphorae_status
class SetAmphoraFirewallRules(BaseAmphoraTask):
"""Task to push updated firewall ruls to an amphora."""
def execute(self, amphorae: List[dict], amphora_index: int,
amphora_firewall_rules: List[dict]):
if (amphora_firewall_rules and
amphora_firewall_rules[0].get('non-sriov-vip', False)):
# Not an SRIOV VIP, so skip setting firewall rules.
# This is already logged in GetAmphoraFirewallRules.
return
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(
session, id=amphorae[amphora_index][constants.ID])
self.amphora_driver.set_interface_rules(
db_amp,
amphorae[amphora_index][constants.VRRP_IP],
amphora_firewall_rules)

View File

@ -14,6 +14,7 @@
# #
from cryptography import fernet from cryptography import fernet
from octavia_lib.common import constants as lib_consts
from oslo_config import cfg from oslo_config import cfg
from oslo_db import exception as odb_exceptions from oslo_db import exception as odb_exceptions
from oslo_log import log as logging from oslo_log import log as logging
@ -27,6 +28,7 @@ from taskflow.types import failure
from octavia.api.drivers import utils as provider_utils from octavia.api.drivers import utils as provider_utils
from octavia.common import constants from octavia.common import constants
from octavia.common import data_models from octavia.common import data_models
from octavia.common import exceptions
from octavia.common.tls_utils import cert_parser from octavia.common.tls_utils import cert_parser
from octavia.common import utils from octavia.common import utils
from octavia.controller.worker import task_utils as task_utilities from octavia.controller.worker import task_utils as task_utilities
@ -3073,3 +3075,47 @@ class UpdatePoolMembersOperatingStatusInDB(BaseDatabaseTask):
with db_apis.session().begin() as session: with db_apis.session().begin() as session:
self.member_repo.update_pool_members( self.member_repo.update_pool_members(
session, pool_id, operating_status=operating_status) session, pool_id, operating_status=operating_status)
class GetAmphoraFirewallRules(BaseDatabaseTask):
"""Task to build firewall rules for the amphora."""
def execute(self, amphorae, amphora_index, amphorae_network_config):
this_amp_id = amphorae[amphora_index][constants.ID]
amp_net_config = amphorae_network_config[this_amp_id]
lb_dict = amp_net_config[constants.AMPHORA]['load_balancer']
vip_dict = lb_dict[constants.VIP]
if vip_dict[constants.VNIC_TYPE] != constants.VNIC_TYPE_DIRECT:
LOG.debug('Load balancer VIP port is not SR-IOV enabled. Skipping '
'firewall rules update.')
return [{'non-sriov-vip': True}]
session = db_apis.get_session()
with session.begin():
rules = self.listener_repo.get_port_protocol_cidr_for_lb(
session,
amp_net_config[constants.AMPHORA][constants.LOAD_BALANCER_ID])
# If we are act/stdby, inject the VRRP firewall rule(s)
if lb_dict[constants.TOPOLOGY] == constants.TOPOLOGY_ACTIVE_STANDBY:
for amp_cfg in lb_dict[constants.AMPHORAE]:
if (amp_cfg[constants.ID] != this_amp_id and
amp_cfg[constants.STATUS] ==
lib_consts.AMPHORA_ALLOCATED):
vrrp_ip = amp_cfg[constants.VRRP_IP]
vrrp_ip_ver = utils.ip_version(vrrp_ip)
if vrrp_ip_ver == 4:
vrrp_ip_cidr = f'{vrrp_ip}/32'
elif vrrp_ip_ver == 6:
vrrp_ip_cidr = f'{vrrp_ip}/128'
else:
raise exceptions.InvalidIPAddress(ip_addr=vrrp_ip)
rules.append({constants.PROTOCOL: constants.VRRP,
constants.CIDR: vrrp_ip_cidr,
constants.PORT: 112})
LOG.debug('Amphora %s SR-IOV firewall rules: %s', this_amp_id, rules)
return rules

View File

@ -0,0 +1,28 @@
# Copyright 2024 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from taskflow import task
from octavia.common import constants
class AmphoraToAmphoraeWithVRRPIP(task.Task):
"""A shim class to convert a single Amphora instance to a list."""
def execute(self, amphora: dict, base_port: dict):
# The VRRP_IP has not been stamped on the Amphora at this point in the
# flow, so inject it from our port create call in a previous task.
amphora[constants.VRRP_IP] = (
base_port[constants.FIXED_IPS][0][constants.IP_ADDRESS])
return [amphora]

View File

@ -39,6 +39,7 @@ from sqlalchemy import update
from octavia.common import constants as consts from octavia.common import constants as consts
from octavia.common import data_models from octavia.common import data_models
from octavia.common import exceptions from octavia.common import exceptions
from octavia.common import utils
from octavia.common import validate from octavia.common import validate
from octavia.db import api as db_api from octavia.db import api as db_api
from octavia.db import models from octavia.db import models
@ -1085,6 +1086,23 @@ class ListenerRepository(BaseRepository):
update({self.model_class.provisioning_status: consts.ACTIVE}, update({self.model_class.provisioning_status: consts.ACTIVE},
synchronize_session='fetch')) synchronize_session='fetch'))
def get_port_protocol_cidr_for_lb(self, session, loadbalancer_id):
# readability variables
Listener = self.model_class
ListenerCidr = models.ListenerCidr
stmt = (select(Listener.protocol,
ListenerCidr.cidr,
Listener.protocol_port.label(consts.PORT))
.select_from(Listener)
.join(models.ListenerCidr,
Listener.id == ListenerCidr.listener_id, isouter=True)
.where(Listener.load_balancer_id == loadbalancer_id))
rows = session.execute(stmt)
return [utils.map_protocol_to_nftable_protocol(u._asdict()) for u
in rows.all()]
class ListenerStatisticsRepository(BaseRepository): class ListenerStatisticsRepository(BaseRepository):
model_class = models.ListenerStatistics model_class = models.ListenerStatistics

View File

@ -24,6 +24,7 @@ from oslo_config import fixture as oslo_fixture
from oslo_serialization import jsonutils from oslo_serialization import jsonutils
from oslo_utils.secretutils import md5 from oslo_utils.secretutils import md5
from oslo_utils import uuidutils from oslo_utils import uuidutils
import webob
from octavia.amphorae.backends.agent import api_server from octavia.amphorae.backends.agent import api_server
from octavia.amphorae.backends.agent.api_server import certificate_update from octavia.amphorae.backends.agent.api_server import certificate_update
@ -3055,3 +3056,39 @@ class TestServerTestCase(base.TestCase):
self.assertEqual(200, rv.status_code) self.assertEqual(200, rv.status_code)
self.assertEqual(expected_dict, self.assertEqual(expected_dict,
jsonutils.loads(rv.data.decode('utf-8'))) jsonutils.loads(rv.data.decode('utf-8')))
@mock.patch('octavia.amphorae.backends.utils.nftable_utils.'
'load_nftables_file')
@mock.patch('octavia.amphorae.backends.utils.nftable_utils.'
'write_nftable_vip_rules_file')
@mock.patch('octavia.amphorae.backends.agent.api_server.amphora_info.'
'AmphoraInfo.get_interface')
def test_set_interface_rules(self, mock_get_int, mock_write_rules,
mock_load_rules):
mock_get_int.side_effect = [
webob.Response(status=400),
webob.Response(status=200, json={'interface': 'fake1'}),
webob.Response(status=200, json={'interface': 'fake1'})]
# Test can't find interface
rv = self.ubuntu_app.put('/' + api_server.VERSION +
'/interface/192.0.2.10/rules', data='fake')
self.assertEqual(400, rv.status_code)
mock_write_rules.assert_not_called()
# Test schema validation failure
rv = self.ubuntu_app.put('/' + api_server.VERSION +
'/interface/192.0.2.10/rules', data='fake')
self.assertEqual('400 Bad Request', rv.status)
# Test successful path
rules_json = ('[{"protocol":"TCP","cidr":"192.0.2.0/24","port":8080},'
'{"protocol":"UDP","cidr":null,"port":80}]')
rv = self.ubuntu_app.put('/' + api_server.VERSION +
'/interface/192.0.2.10/rules',
data=rules_json,
content_type='application/json')
self.assertEqual('200 OK', rv.status)
mock_write_rules.assert_called_once_with('fake1',
jsonutils.loads(rules_json))
mock_load_rules.assert_called_once()

View File

@ -2762,6 +2762,14 @@ class TestListenerRepositoryTest(BaseRepositoryTest):
self.assertEqual(constants.PENDING_UPDATE, self.assertEqual(constants.PENDING_UPDATE,
new_listener.provisioning_status) new_listener.provisioning_status)
def test_get_port_protocol_cidr_for_lb(self):
self.create_listener(self.FAKE_UUID_1, 80,
provisioning_status=constants.ACTIVE)
rules = self.listener_repo.get_port_protocol_cidr_for_lb(
self.session, self.FAKE_UUID_1)
self.assertEqual([{'protocol': 'TCP', 'cidr': None, 'port': 80}],
rules)
class ListenerStatisticsRepositoryTest(BaseRepositoryTest): class ListenerStatisticsRepositoryTest(BaseRepositoryTest):

View File

@ -15,6 +15,7 @@
import errno import errno
import os import os
import socket import socket
import subprocess
from unittest import mock from unittest import mock
import pyroute2 import pyroute2
@ -448,6 +449,8 @@ class TestInterface(base.TestCase):
mock.call(["post-up", "eth1"]) mock.call(["post-up", "eth1"])
]) ])
@mock.patch('octavia.amphorae.backends.utils.network_namespace.'
'NetworkNamespace')
@mock.patch('octavia.amphorae.backends.utils.nftable_utils.' @mock.patch('octavia.amphorae.backends.utils.nftable_utils.'
'write_nftable_vip_rules_file') 'write_nftable_vip_rules_file')
@mock.patch('pyroute2.IPRoute.rule') @mock.patch('pyroute2.IPRoute.rule')
@ -459,7 +462,7 @@ class TestInterface(base.TestCase):
@mock.patch('subprocess.check_output') @mock.patch('subprocess.check_output')
def test_up_sriov(self, mock_check_output, mock_link_lookup, def test_up_sriov(self, mock_check_output, mock_link_lookup,
mock_get_links, mock_link, mock_addr, mock_route, mock_get_links, mock_link, mock_addr, mock_route,
mock_rule, mock_nftable): mock_rule, mock_nftable, mock_netns):
iface = interface_file.InterfaceFile( iface = interface_file.InterfaceFile(
name="fake-eth1", name="fake-eth1",
if_type="vip", if_type="vip",
@ -1441,3 +1444,56 @@ class TestInterface(base.TestCase):
addr = controller._normalize_ip_network(None) addr = controller._normalize_ip_network(None)
self.assertIsNone(addr) self.assertIsNone(addr)
@mock.patch('octavia.amphorae.backends.utils.nftable_utils.'
'load_nftables_file')
@mock.patch('octavia.amphorae.backends.utils.nftable_utils.'
'write_nftable_vip_rules_file')
@mock.patch('subprocess.check_output')
def test__setup_nftables_chain(self, mock_check_output, mock_write_rules,
mock_load_rules):
controller = interface.InterfaceController()
mock_check_output.side_effect = [
mock.DEFAULT, mock.DEFAULT,
subprocess.CalledProcessError(cmd=consts.NFT_CMD, returncode=-1),
mock.DEFAULT,
subprocess.CalledProcessError(cmd=consts.NFT_CMD, returncode=-1)]
interface_mock = mock.MagicMock()
interface_mock.name = 'fake2'
# Test succeessful path
controller._setup_nftables_chain(interface_mock)
mock_write_rules.assert_called_once_with('fake2', [])
mock_load_rules.assert_called_once_with()
mock_check_output.assert_has_calls([
mock.call([consts.NFT_CMD, 'add', 'table', consts.NFT_FAMILY,
consts.NFT_VIP_TABLE], stderr=subprocess.STDOUT),
mock.call([consts.NFT_CMD, 'add', 'chain', consts.NFT_FAMILY,
consts.NFT_VIP_TABLE, consts.NFT_VIP_CHAIN, '{',
'type', 'filter', 'hook', 'ingress', 'device',
'fake2', 'priority', consts.NFT_SRIOV_PRIORITY, ';',
'policy', 'drop', ';', '}'], stderr=subprocess.STDOUT)])
# Test first nft call fails
mock_write_rules.reset_mock()
mock_load_rules.reset_mock()
mock_check_output.reset_mock()
self.assertRaises(subprocess.CalledProcessError,
controller._setup_nftables_chain, interface_mock)
mock_check_output.assert_called_once()
mock_write_rules.assert_not_called()
# Test second nft call fails
mock_write_rules.reset_mock()
mock_load_rules.reset_mock()
mock_check_output.reset_mock()
self.assertRaises(subprocess.CalledProcessError,
controller._setup_nftables_chain, interface_mock)
self.assertEqual(2, mock_check_output.call_count)
mock_write_rules.assert_not_called()

View File

@ -13,10 +13,15 @@
# under the License. # under the License.
import os import os
import stat import stat
import subprocess
from unittest import mock from unittest import mock
from octavia_lib.common import constants as lib_consts
from webob import exc
from octavia.amphorae.backends.utils import nftable_utils from octavia.amphorae.backends.utils import nftable_utils
from octavia.common import constants as consts from octavia.common import constants as consts
from octavia.common import exceptions
import octavia.tests.unit.base as base import octavia.tests.unit.base as base
@ -47,10 +52,17 @@ class TestNFTableUtils(base.TestCase):
mock_isfile.return_value = True mock_isfile.return_value = True
mock_open.return_value = 'fake-fd' mock_open.return_value = 'fake-fd'
test_rule_1 = {consts.CIDR: None,
consts.PROTOCOL: lib_consts.PROTOCOL_TCP,
consts.PORT: 1234}
test_rule_2 = {consts.CIDR: '192.0.2.0/24',
consts.PROTOCOL: consts.VRRP,
consts.PORT: 4321}
mocked_open = mock.mock_open() mocked_open = mock.mock_open()
with mock.patch.object(os, 'fdopen', mocked_open): with mock.patch.object(os, 'fdopen', mocked_open):
nftable_utils.write_nftable_vip_rules_file( nftable_utils.write_nftable_vip_rules_file(
'fake-eth2', ['test rule 1', 'test rule 2']) 'fake-eth2', [test_rule_1, test_rule_2])
mocked_open.assert_called_once_with('fake-fd', 'w') mocked_open.assert_called_once_with('fake-fd', 'w')
mock_open.assert_called_once_with( mock_open.assert_called_once_with(
@ -60,15 +72,23 @@ class TestNFTableUtils(base.TestCase):
handle = mocked_open() handle = mocked_open()
handle.write.assert_has_calls([ handle.write.assert_has_calls([
mock.call(f'flush chain {consts.NFT_FAMILY} ' mock.call(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} '
f'{consts.NFT_VIP_TABLE} {consts.NFT_VIP_CHAIN}\n'), '{}\n'),
mock.call(f'delete table {consts.NFT_FAMILY} '
f'{consts.NFT_VIP_TABLE}\n'),
mock.call(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} ' mock.call(f'table {consts.NFT_FAMILY} {consts.NFT_VIP_TABLE} '
'{\n'), '{\n'),
mock.call(f' chain {consts.NFT_VIP_CHAIN} {{\n'), mock.call(f' chain {consts.NFT_VIP_CHAIN} {{\n'),
mock.call(' type filter hook ingress device fake-eth2 ' mock.call(' type filter hook ingress device fake-eth2 '
f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n'), f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n'),
mock.call(' test rule 1\n'), mock.call(' icmp type destination-unreachable accept\n'),
mock.call(' test rule 2\n'), mock.call(' icmpv6 type { nd-neighbor-solicit, '
'nd-router-advert, nd-neighbor-advert, packet-too-big, '
'destination-unreachable } accept\n'),
mock.call(' udp sport 67 udp dport 68 accept\n'),
mock.call(' udp sport 547 udp dport 546 accept\n'),
mock.call(' tcp dport 1234 accept\n'),
mock.call(' ip saddr 192.0.2.0/24 ip protocol 112 accept\n'),
mock.call(' }\n'), mock.call(' }\n'),
mock.call('}\n') mock.call('}\n')
]) ])
@ -101,6 +121,74 @@ class TestNFTableUtils(base.TestCase):
mock.call(f' chain {consts.NFT_VIP_CHAIN} {{\n'), mock.call(f' chain {consts.NFT_VIP_CHAIN} {{\n'),
mock.call(' type filter hook ingress device fake-eth2 ' mock.call(' type filter hook ingress device fake-eth2 '
f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n'), f'priority {consts.NFT_SRIOV_PRIORITY}; policy drop;\n'),
mock.call(' icmp type destination-unreachable accept\n'),
mock.call(' icmpv6 type { nd-neighbor-solicit, '
'nd-router-advert, nd-neighbor-advert, packet-too-big, '
'destination-unreachable } accept\n'),
mock.call(' udp sport 67 udp dport 68 accept\n'),
mock.call(' udp sport 547 udp dport 546 accept\n'),
mock.call(' }\n'), mock.call(' }\n'),
mock.call('}\n') mock.call('}\n')
]) ])
@mock.patch('octavia.common.utils.ip_version')
def test__build_rule_cmd(self, mock_ip_version):
mock_ip_version.side_effect = [4, 6, 99]
cmd = nftable_utils._build_rule_cmd({
consts.CIDR: '192.0.2.0/24',
consts.PROTOCOL: lib_consts.PROTOCOL_SCTP,
consts.PORT: 1234})
self.assertEqual('ip saddr 192.0.2.0/24 sctp dport 1234 accept', cmd)
cmd = nftable_utils._build_rule_cmd({
consts.CIDR: '2001:db8::/32',
consts.PROTOCOL: lib_consts.PROTOCOL_TCP,
consts.PORT: 1235})
self.assertEqual('ip6 saddr 2001:db8::/32 tcp dport 1235 accept', cmd)
self.assertRaises(exc.HTTPBadRequest, nftable_utils._build_rule_cmd,
{consts.CIDR: '192/32',
consts.PROTOCOL: lib_consts.PROTOCOL_TCP,
consts.PORT: 1237})
cmd = nftable_utils._build_rule_cmd({
consts.CIDR: None,
consts.PROTOCOL: lib_consts.PROTOCOL_UDP,
consts.PORT: 1236})
self.assertEqual('udp dport 1236 accept', cmd)
cmd = nftable_utils._build_rule_cmd({
consts.CIDR: None,
consts.PROTOCOL: consts.VRRP,
consts.PORT: 1237})
self.assertEqual('ip protocol 112 accept', cmd)
self.assertRaises(exc.HTTPBadRequest, nftable_utils._build_rule_cmd,
{consts.CIDR: None,
consts.PROTOCOL: 'bad-protocol',
consts.PORT: 1237})
@mock.patch('octavia.amphorae.backends.utils.network_namespace.'
'NetworkNamespace')
@mock.patch('subprocess.check_output')
def test_load_nftables_file(self, mock_check_output, mock_netns):
mock_netns.side_effect = [
mock.DEFAULT,
subprocess.CalledProcessError(cmd=consts.NFT_CMD, returncode=-1),
exceptions.AmphoraNetworkConfigException]
nftable_utils.load_nftables_file()
mock_netns.assert_called_once_with(consts.AMPHORA_NAMESPACE)
mock_check_output.assert_called_once_with([
consts.NFT_CMD, '-o', '-f', consts.NFT_VIP_RULES_FILE],
stderr=subprocess.STDOUT)
self.assertRaises(subprocess.CalledProcessError,
nftable_utils.load_nftables_file)
self.assertRaises(exceptions.AmphoraNetworkConfigException,
nftable_utils.load_nftables_file)

View File

@ -13,7 +13,7 @@
# under the License. # under the License.
from unittest import mock from unittest import mock
from octavia.amphorae.driver_exceptions.exceptions import AmpVersionUnsupported from octavia.amphorae.driver_exceptions import exceptions as driver_except
from octavia.amphorae.drivers.haproxy import exceptions as exc from octavia.amphorae.drivers.haproxy import exceptions as exc
from octavia.amphorae.drivers.haproxy import rest_api_driver from octavia.amphorae.drivers.haproxy import rest_api_driver
import octavia.tests.unit.base as base import octavia.tests.unit.base as base
@ -87,6 +87,28 @@ class TestHAProxyAmphoraDriver(base.TestCase):
mock_amp = mock.MagicMock() mock_amp = mock.MagicMock()
mock_amp.api_version = "0.5" mock_amp.api_version = "0.5"
self.assertRaises(AmpVersionUnsupported, self.assertRaises(driver_except.AmpVersionUnsupported,
self.driver._populate_amphora_api_version, self.driver._populate_amphora_api_version,
mock_amp) mock_amp)
@mock.patch('octavia.amphorae.drivers.haproxy.rest_api_driver.'
'HaproxyAmphoraLoadBalancerDriver.'
'_populate_amphora_api_version')
def test_set_interface_rules(self, mock_api_version):
IP_ADDRESS = '203.0.113.44'
amphora_mock = mock.MagicMock()
amphora_mock.api_version = '0'
client_mock = mock.MagicMock()
client_mock.set_interface_rules.side_effect = [mock.DEFAULT,
exc.NotFound]
self.driver.clients['0'] = client_mock
self.driver.set_interface_rules(amphora_mock, IP_ADDRESS, 'fake_rules')
mock_api_version.assert_called_once_with(amphora_mock)
client_mock.set_interface_rules.assert_called_once_with(
amphora_mock, IP_ADDRESS, 'fake_rules')
self.assertRaises(driver_except.AmpDriverNotImplementedError,
self.driver.set_interface_rules, amphora_mock,
IP_ADDRESS, 'fake_rules')

View File

@ -1548,3 +1548,13 @@ class TestAmphoraAPIClientTest(base.TestCase):
self.assertRaises(exc.InternalServerError, self.assertRaises(exc.InternalServerError,
self.driver.update_agent_config, self.amp, self.driver.update_agent_config, self.amp,
"some_file") "some_file")
@requests_mock.mock()
def test_set_interface_rules(self, m):
ip_addr = '192.0.2.44'
rules = ('[{"protocol":"TCP","cidr":"192.0.2.0/24","port":8080},'
'{"protocol":"UDP","cidr":null,"port":80}]')
m.put(f'{self.base_url_ver}/interface/{ip_addr}/rules')
self.driver.set_interface_rules(self.amp, ip_addr, rules)
self.assertTrue(m.called)

View File

@ -13,6 +13,7 @@
# under the License. # under the License.
from unittest import mock from unittest import mock
from octavia_lib.common import constants as lib_consts
from oslo_utils import uuidutils from oslo_utils import uuidutils
from octavia.common import constants from octavia.common import constants
@ -139,3 +140,41 @@ class TestConfig(base.TestCase):
expected_sg_name = constants.VIP_SECURITY_GROUP_PREFIX + FAKE_LB_ID expected_sg_name = constants.VIP_SECURITY_GROUP_PREFIX + FAKE_LB_ID
self.assertEqual(expected_sg_name, self.assertEqual(expected_sg_name,
utils.get_vip_security_group_name(FAKE_LB_ID)) utils.get_vip_security_group_name(FAKE_LB_ID))
def test_map_protocol_to_nftable_protocol(self):
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_TCP})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_HTTP})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_HTTPS})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_TERMINATED_HTTPS})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_PROXY})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_PROXYV2})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_UDP})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_UDP}, result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_SCTP})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_SCTP},
result)
result = utils.map_protocol_to_nftable_protocol(
{constants.PROTOCOL: lib_consts.PROTOCOL_PROMETHEUS})
self.assertEqual({constants.PROTOCOL: lib_consts.PROTOCOL_TCP}, result)

View File

@ -34,19 +34,31 @@ class TestListenerFlows(base.TestCase):
def test_get_create_listener_flow(self, mock_get_net_driver): def test_get_create_listener_flow(self, mock_get_net_driver):
listener_flow = self.ListenerFlow.get_create_listener_flow() flavor_dict = {
constants.SRIOV_VIP: True,
constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE}
listener_flow = self.ListenerFlow.get_create_listener_flow(
flavor_dict=flavor_dict)
self.assertIsInstance(listener_flow, flow.Flow) self.assertIsInstance(listener_flow, flow.Flow)
self.assertIn(constants.LISTENERS, listener_flow.requires) self.assertIn(constants.LISTENERS, listener_flow.requires)
self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires)
self.assertIn(constants.AMPHORAE_NETWORK_CONFIG,
listener_flow.provides)
self.assertIn(constants.AMPHORAE, listener_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides)
self.assertEqual(2, len(listener_flow.requires)) self.assertEqual(2, len(listener_flow.requires))
self.assertEqual(0, len(listener_flow.provides)) self.assertEqual(3, len(listener_flow.provides))
def test_get_delete_listener_flow(self, mock_get_net_driver): def test_get_delete_listener_flow(self, mock_get_net_driver):
flavor_dict = {
listener_flow = self.ListenerFlow.get_delete_listener_flow() constants.SRIOV_VIP: True,
constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE}
listener_flow = self.ListenerFlow.get_delete_listener_flow(
flavor_dict=flavor_dict)
self.assertIsInstance(listener_flow, flow.Flow) self.assertIsInstance(listener_flow, flow.Flow)
@ -54,25 +66,42 @@ class TestListenerFlows(base.TestCase):
self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires)
self.assertIn(constants.PROJECT_ID, listener_flow.requires) self.assertIn(constants.PROJECT_ID, listener_flow.requires)
self.assertIn(constants.AMPHORAE_NETWORK_CONFIG,
listener_flow.provides)
self.assertIn(constants.AMPHORAE, listener_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides)
self.assertEqual(3, len(listener_flow.requires)) self.assertEqual(3, len(listener_flow.requires))
self.assertEqual(0, len(listener_flow.provides)) self.assertEqual(3, len(listener_flow.provides))
def test_get_delete_listener_internal_flow(self, mock_get_net_driver): def test_get_delete_listener_internal_flow(self, mock_get_net_driver):
flavor_dict = {
constants.SRIOV_VIP: True,
constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE}
fake_listener = {constants.LISTENER_ID: uuidutils.generate_uuid()} fake_listener = {constants.LISTENER_ID: uuidutils.generate_uuid()}
listener_flow = self.ListenerFlow.get_delete_listener_internal_flow( listener_flow = self.ListenerFlow.get_delete_listener_internal_flow(
fake_listener) fake_listener, flavor_dict=flavor_dict)
self.assertIsInstance(listener_flow, flow.Flow) self.assertIsInstance(listener_flow, flow.Flow)
self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires)
self.assertIn(constants.PROJECT_ID, listener_flow.requires) self.assertIn(constants.PROJECT_ID, listener_flow.requires)
self.assertIn(constants.AMPHORAE_NETWORK_CONFIG,
listener_flow.provides)
self.assertIn(constants.AMPHORAE, listener_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides)
self.assertEqual(2, len(listener_flow.requires)) self.assertEqual(2, len(listener_flow.requires))
self.assertEqual(0, len(listener_flow.provides)) self.assertEqual(3, len(listener_flow.provides))
def test_get_update_listener_flow(self, mock_get_net_driver): def test_get_update_listener_flow(self, mock_get_net_driver):
flavor_dict = {
constants.SRIOV_VIP: True,
constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_SINGLE}
listener_flow = self.ListenerFlow.get_update_listener_flow() listener_flow = self.ListenerFlow.get_update_listener_flow(
flavor_dict=flavor_dict)
self.assertIsInstance(listener_flow, flow.Flow) self.assertIsInstance(listener_flow, flow.Flow)
@ -81,14 +110,28 @@ class TestListenerFlows(base.TestCase):
self.assertIn(constants.LISTENERS, listener_flow.requires) self.assertIn(constants.LISTENERS, listener_flow.requires)
self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listener_flow.requires)
self.assertIn(constants.AMPHORAE_NETWORK_CONFIG,
listener_flow.provides)
self.assertIn(constants.AMPHORAE, listener_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES, listener_flow.provides)
self.assertEqual(4, len(listener_flow.requires)) self.assertEqual(4, len(listener_flow.requires))
self.assertEqual(0, len(listener_flow.provides)) self.assertEqual(3, len(listener_flow.provides))
def test_get_create_all_listeners_flow(self, mock_get_net_driver): def test_get_create_all_listeners_flow(self, mock_get_net_driver):
listeners_flow = self.ListenerFlow.get_create_all_listeners_flow() flavor_dict = {
constants.SRIOV_VIP: True,
constants.LOADBALANCER_TOPOLOGY: constants.TOPOLOGY_ACTIVE_STANDBY}
listeners_flow = self.ListenerFlow.get_create_all_listeners_flow(
flavor_dict=flavor_dict)
self.assertIsInstance(listeners_flow, flow.Flow) self.assertIsInstance(listeners_flow, flow.Flow)
self.assertIn(constants.LOADBALANCER, listeners_flow.requires) self.assertIn(constants.LOADBALANCER, listeners_flow.requires)
self.assertIn(constants.LOADBALANCER_ID, listeners_flow.requires) self.assertIn(constants.LOADBALANCER_ID, listeners_flow.requires)
self.assertIn(constants.LOADBALANCER, listeners_flow.provides) self.assertIn(constants.LOADBALANCER, listeners_flow.provides)
self.assertIn(constants.AMPHORAE_NETWORK_CONFIG,
listeners_flow.provides)
self.assertIn(constants.AMPHORAE, listeners_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES,
listeners_flow.provides)
self.assertEqual(2, len(listeners_flow.requires)) self.assertEqual(2, len(listeners_flow.requires))
self.assertEqual(2, len(listeners_flow.provides)) self.assertEqual(5, len(listeners_flow.provides))

View File

@ -359,10 +359,13 @@ class TestLoadBalancerFlows(base.TestCase):
self.assertIn(constants.VIP, failover_flow.provides) self.assertIn(constants.VIP, failover_flow.provides)
self.assertIn(constants.ADDITIONAL_VIPS, failover_flow.provides) self.assertIn(constants.ADDITIONAL_VIPS, failover_flow.provides)
self.assertIn(constants.VIP_SG_ID, failover_flow.provides) self.assertIn(constants.VIP_SG_ID, failover_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES, failover_flow.provides)
self.assertIn(constants.SUBNET, failover_flow.provides)
self.assertIn(constants.NEW_AMPHORAE, failover_flow.provides)
self.assertEqual(6, len(failover_flow.requires), self.assertEqual(6, len(failover_flow.requires),
failover_flow.requires) failover_flow.requires)
self.assertEqual(14, len(failover_flow.provides), self.assertEqual(16, len(failover_flow.provides),
failover_flow.provides) failover_flow.provides)
@mock.patch('octavia.common.rpc.NOTIFIER', @mock.patch('octavia.common.rpc.NOTIFIER',
@ -435,10 +438,14 @@ class TestLoadBalancerFlows(base.TestCase):
self.assertIn(constants.VIP, failover_flow.provides) self.assertIn(constants.VIP, failover_flow.provides)
self.assertIn(constants.ADDITIONAL_VIPS, failover_flow.provides) self.assertIn(constants.ADDITIONAL_VIPS, failover_flow.provides)
self.assertIn(constants.VIP_SG_ID, failover_flow.provides) self.assertIn(constants.VIP_SG_ID, failover_flow.provides)
self.assertIn(constants.SUBNET, failover_flow.provides)
self.assertIn(constants.AMPHORA_FIREWALL_RULES, failover_flow.provides)
self.assertIn(constants.SUBNET, failover_flow.provides)
self.assertIn(constants.NEW_AMPHORAE, failover_flow.provides)
self.assertEqual(6, len(failover_flow.requires), self.assertEqual(6, len(failover_flow.requires),
failover_flow.requires) failover_flow.requires)
self.assertEqual(14, len(failover_flow.provides), self.assertEqual(16, len(failover_flow.provides),
failover_flow.provides) failover_flow.provides)
@mock.patch('octavia.common.rpc.NOTIFIER', @mock.patch('octavia.common.rpc.NOTIFIER',

View File

@ -1246,3 +1246,31 @@ class TestAmphoraDriverTasks(base.TestCase):
ret[amphora1_mock[constants.ID]][constants.UNREACHABLE]) ret[amphora1_mock[constants.ID]][constants.UNREACHABLE])
self.assertTrue( self.assertTrue(
ret[amphora2_mock[constants.ID]][constants.UNREACHABLE]) ret[amphora2_mock[constants.ID]][constants.UNREACHABLE])
def test_set_amphora_firewall_rules(self,
mock_driver,
mock_generate_uuid,
mock_log,
mock_get_session,
mock_listener_repo_get,
mock_listener_repo_update,
mock_amphora_repo_get,
mock_amphora_repo_update):
amphora = {constants.ID: AMP_ID, constants.VRRP_IP: '192.0.2.88'}
mock_amphora_repo_get.return_value = _db_amphora_mock
set_amp_fw_rules = amphora_driver_tasks.SetAmphoraFirewallRules()
# Test non-SRIOV VIP path
set_amp_fw_rules.execute([amphora], 0, [{'non-sriov-vip': True}])
mock_get_session.assert_not_called()
mock_driver.set_interface_rules.assert_not_called()
# Test SRIOV VIP path
set_amp_fw_rules.execute([amphora], 0, [{'fake_rule': True}])
mock_amphora_repo_get.assert_called_once_with(_session_mock, id=AMP_ID)
mock_driver.set_interface_rules.assert_called_once_with(
_db_amphora_mock, '192.0.2.88', [{'fake_rule': True}])

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
# #
import copy
import random import random
from unittest import mock from unittest import mock
@ -24,6 +25,7 @@ from taskflow.types import failure
from octavia.api.drivers import utils as provider_utils from octavia.api.drivers import utils as provider_utils
from octavia.common import constants from octavia.common import constants
from octavia.common import data_models from octavia.common import data_models
from octavia.common import exceptions
from octavia.common import utils from octavia.common import utils
from octavia.controller.worker.v2.tasks import database_tasks from octavia.controller.worker.v2.tasks import database_tasks
from octavia.db import repositories as repo from octavia.db import repositories as repo
@ -31,6 +33,7 @@ import octavia.tests.unit.base as base
AMP_ID = uuidutils.generate_uuid() AMP_ID = uuidutils.generate_uuid()
AMP2_ID = uuidutils.generate_uuid()
COMPUTE_ID = uuidutils.generate_uuid() COMPUTE_ID = uuidutils.generate_uuid()
LB_ID = uuidutils.generate_uuid() LB_ID = uuidutils.generate_uuid()
SERVER_GROUP_ID = uuidutils.generate_uuid() SERVER_GROUP_ID = uuidutils.generate_uuid()
@ -2987,3 +2990,100 @@ class TestDatabaseTasks(base.TestCase):
mock_session, mock_session,
POOL_ID, POOL_ID,
operating_status=constants.ONLINE) operating_status=constants.ONLINE)
@mock.patch('octavia.common.utils.ip_version')
@mock.patch('octavia.db.api.get_session')
@mock.patch('octavia.db.repositories.ListenerRepository.'
'get_port_protocol_cidr_for_lb')
def test_get_amphora_firewall_rules(self,
mock_get_port_for_lb,
mock_db_get_session,
mock_ip_version,
mock_generate_uuid,
mock_LOG,
mock_get_session,
mock_loadbalancer_repo_update,
mock_listener_repo_update,
mock_amphora_repo_update,
mock_amphora_repo_delete):
amphora_dict = {constants.ID: AMP_ID}
rules = [{'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80},
{'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80}]
vrrp_rules = [
{'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80},
{'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80},
{'cidr': '203.0.113.5/32', 'port': 112, 'protocol': 'vrrp'}]
mock_get_port_for_lb.side_effect = [
copy.deepcopy(rules), copy.deepcopy(rules), copy.deepcopy(rules),
copy.deepcopy(rules)]
mock_ip_version.side_effect = [4, 6, 55]
get_amp_fw_rules = database_tasks.GetAmphoraFirewallRules()
# Test non-SRIOV VIP
amphora_net_cfg_dict = {
AMP_ID: {constants.AMPHORA: {
'load_balancer': {constants.VIP: {
constants.VNIC_TYPE: constants.VNIC_TYPE_NORMAL}}}}}
result = get_amp_fw_rules.execute([amphora_dict], 0,
amphora_net_cfg_dict)
self.assertEqual([{'non-sriov-vip': True}], result)
# Test SRIOV VIP - Single
amphora_net_cfg_dict = {
AMP_ID: {constants.AMPHORA: {
'load_balancer': {constants.VIP: {
constants.VNIC_TYPE: constants.VNIC_TYPE_DIRECT},
constants.TOPOLOGY: constants.TOPOLOGY_SINGLE},
constants.LOAD_BALANCER_ID: LB_ID}}}
result = get_amp_fw_rules.execute([amphora_dict], 0,
amphora_net_cfg_dict)
mock_get_port_for_lb.assert_called_once_with(mock_db_get_session(),
LB_ID)
self.assertEqual(rules, result)
mock_get_port_for_lb.reset_mock()
# Test SRIOV VIP - Active/Standby
amphora_net_cfg_dict = {
AMP_ID: {constants.AMPHORA: {
'load_balancer': {constants.VIP: {
constants.VNIC_TYPE: constants.VNIC_TYPE_DIRECT},
constants.TOPOLOGY: constants.TOPOLOGY_ACTIVE_STANDBY,
constants.AMPHORAE: [{
constants.ID: AMP_ID,
constants.STATUS: constants.AMPHORA_ALLOCATED},
{constants.ID: AMP2_ID,
constants.STATUS: constants.AMPHORA_ALLOCATED,
constants.VRRP_IP: '203.0.113.5'}]},
constants.LOAD_BALANCER_ID: LB_ID}}}
# IPv4 path
mock_get_port_for_lb.reset_mock()
vrrp_rules = [
{'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80},
{'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80},
{'cidr': '203.0.113.5/32', 'port': 112, 'protocol': 'vrrp'}]
result = get_amp_fw_rules.execute([amphora_dict], 0,
amphora_net_cfg_dict)
mock_get_port_for_lb.assert_called_once_with(mock_db_get_session(),
LB_ID)
self.assertEqual(vrrp_rules, result)
# IPv6 path
mock_get_port_for_lb.reset_mock()
vrrp_rules = [
{'protocol': 'TCP', 'cidr': '192.0.2.0/24', 'port': 80},
{'protocol': 'TCP', 'cidr': '198.51.100.0/24', 'port': 80},
{'cidr': '203.0.113.5/128', 'port': 112, 'protocol': 'vrrp'}]
result = get_amp_fw_rules.execute([amphora_dict], 0,
amphora_net_cfg_dict)
mock_get_port_for_lb.assert_called_once_with(mock_db_get_session(),
LB_ID)
self.assertEqual(vrrp_rules, result)
# Bogus IP version path
self.assertRaises(exceptions.InvalidIPAddress,
get_amp_fw_rules.execute, [amphora_dict], 0,
amphora_net_cfg_dict)

View File

@ -0,0 +1,33 @@
# Copyright 2024 Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from octavia.common import constants
from octavia.controller.worker.v2.tasks import shim_tasks
import octavia.tests.unit.base as base
class TestShimTasks(base.TestCase):
def test_amphora_to_amphorae_with_vrrp_ip(self):
amp_to_amps = shim_tasks.AmphoraToAmphoraeWithVRRPIP()
base_port = {constants.FIXED_IPS:
[{constants.IP_ADDRESS: '192.0.2.43'}]}
amphora = {constants.ID: '123456'}
expected_amphora = [{constants.ID: '123456',
constants.VRRP_IP: '192.0.2.43'}]
self.assertEqual(expected_amphora,
amp_to_amps.execute(amphora, base_port))

View File

@ -70,7 +70,7 @@ _db_load_balancer_mock = mock.MagicMock()
_load_balancer_mock = { _load_balancer_mock = {
constants.LOADBALANCER_ID: LB_ID, constants.LOADBALANCER_ID: LB_ID,
constants.TOPOLOGY: constants.TOPOLOGY_SINGLE, constants.TOPOLOGY: constants.TOPOLOGY_SINGLE,
constants.FLAVOR_ID: None, constants.FLAVOR_ID: 1,
constants.AVAILABILITY_ZONE: None, constants.AVAILABILITY_ZONE: None,
constants.SERVER_GROUP_ID: None constants.SERVER_GROUP_ID: None
} }
@ -133,7 +133,7 @@ class TestControllerWorker(base.TestCase):
_db_load_balancer_mock.amphorae = _db_amphora_mock _db_load_balancer_mock.amphorae = _db_amphora_mock
_db_load_balancer_mock.vip = _vip_mock _db_load_balancer_mock.vip = _vip_mock
_db_load_balancer_mock.id = LB_ID _db_load_balancer_mock.id = LB_ID
_db_load_balancer_mock.flavor_id = None _db_load_balancer_mock.flavor_id = 1
_db_load_balancer_mock.availability_zone = None _db_load_balancer_mock.availability_zone = None
_db_load_balancer_mock.server_group_id = None _db_load_balancer_mock.server_group_id = None
_db_load_balancer_mock.project_id = PROJECT_ID _db_load_balancer_mock.project_id = PROJECT_ID
@ -331,7 +331,10 @@ class TestControllerWorker(base.TestCase):
cw.update_health_monitor(_health_mon_mock, cw.update_health_monitor(_health_mon_mock,
HEALTH_UPDATE_DICT) HEALTH_UPDATE_DICT)
@mock.patch('octavia.db.repositories.FlavorRepository.'
'get_flavor_metadata_dict', return_value={})
def test_create_listener(self, def test_create_listener(self,
mock_get_flavor_dict,
mock_api_get_session, mock_api_get_session,
mock_dyn_log_listener, mock_dyn_log_listener,
mock_taskflow_load, mock_taskflow_load,
@ -355,42 +358,19 @@ class TestControllerWorker(base.TestCase):
provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer( provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer(
_db_load_balancer_mock).to_dict(recurse=True) _db_load_balancer_mock).to_dict(recurse=True)
flavor_dict = {constants.LOADBALANCER_TOPOLOGY:
constants.TOPOLOGY_SINGLE}
(cw.services_controller.run_poster. (cw.services_controller.run_poster.
assert_called_once_with( assert_called_once_with(
flow_utils.get_create_listener_flow, store={ flow_utils.get_create_listener_flow, flavor_dict=flavor_dict,
constants.LOADBALANCER: provider_lb, store={constants.LOADBALANCER: provider_lb,
constants.LOADBALANCER_ID: LB_ID, constants.LOADBALANCER_ID: LB_ID,
constants.LISTENERS: [listener_dict]})) constants.LISTENERS: [listener_dict]}))
@mock.patch('octavia.db.repositories.FlavorRepository.'
'get_flavor_metadata_dict', return_value={})
def test_delete_listener(self, def test_delete_listener(self,
mock_api_get_session, mock_get_flavor_dict,
mock_dyn_log_listener,
mock_taskflow_load,
mock_pool_repo_get,
mock_member_repo_get,
mock_l7rule_repo_get,
mock_l7policy_repo_get,
mock_listener_repo_get,
mock_lb_repo_get,
mock_health_mon_repo_get,
mock_amp_repo_get):
_flow_mock.reset_mock()
listener_dict = {constants.LISTENER_ID: LISTENER_ID,
constants.LOADBALANCER_ID: LB_ID,
constants.PROJECT_ID: PROJECT_ID}
cw = controller_worker.ControllerWorker()
cw.delete_listener(listener_dict)
(cw.services_controller.run_poster.
assert_called_once_with(
flow_utils.get_delete_listener_flow,
store={constants.LISTENER: self.ref_listener_dict,
constants.LOADBALANCER_ID: LB_ID,
constants.PROJECT_ID: PROJECT_ID}))
def test_update_listener(self,
mock_api_get_session, mock_api_get_session,
mock_dyn_log_listener, mock_dyn_log_listener,
mock_taskflow_load, mock_taskflow_load,
@ -406,6 +386,48 @@ class TestControllerWorker(base.TestCase):
load_balancer_mock = mock.MagicMock() load_balancer_mock = mock.MagicMock()
load_balancer_mock.provisioning_status = constants.PENDING_UPDATE load_balancer_mock.provisioning_status = constants.PENDING_UPDATE
load_balancer_mock.id = LB_ID load_balancer_mock.id = LB_ID
load_balancer_mock.flavor_id = 1
load_balancer_mock.topology = constants.TOPOLOGY_SINGLE
mock_lb_repo_get.return_value = load_balancer_mock
_flow_mock.reset_mock()
listener_dict = {constants.LISTENER_ID: LISTENER_ID,
constants.LOADBALANCER_ID: LB_ID,
constants.PROJECT_ID: PROJECT_ID}
cw = controller_worker.ControllerWorker()
cw.delete_listener(listener_dict)
flavor_dict = {constants.LOADBALANCER_TOPOLOGY:
constants.TOPOLOGY_SINGLE}
(cw.services_controller.run_poster.
assert_called_once_with(
flow_utils.get_delete_listener_flow, flavor_dict=flavor_dict,
store={constants.LISTENER: self.ref_listener_dict,
constants.LOADBALANCER_ID: LB_ID,
constants.PROJECT_ID: PROJECT_ID}))
@mock.patch('octavia.db.repositories.FlavorRepository.'
'get_flavor_metadata_dict', return_value={})
def test_update_listener(self,
mock_get_flavor_dict,
mock_api_get_session,
mock_dyn_log_listener,
mock_taskflow_load,
mock_pool_repo_get,
mock_member_repo_get,
mock_l7rule_repo_get,
mock_l7policy_repo_get,
mock_listener_repo_get,
mock_lb_repo_get,
mock_health_mon_repo_get,
mock_amp_repo_get):
load_balancer_mock = mock.MagicMock()
load_balancer_mock.provisioning_status = constants.PENDING_UPDATE
load_balancer_mock.id = LB_ID
load_balancer_mock.flavor_id = None
load_balancer_mock.topology = constants.TOPOLOGY_SINGLE
mock_lb_repo_get.return_value = load_balancer_mock mock_lb_repo_get.return_value = load_balancer_mock
_flow_mock.reset_mock() _flow_mock.reset_mock()
@ -416,8 +438,11 @@ class TestControllerWorker(base.TestCase):
cw = controller_worker.ControllerWorker() cw = controller_worker.ControllerWorker()
cw.update_listener(listener_dict, LISTENER_UPDATE_DICT) cw.update_listener(listener_dict, LISTENER_UPDATE_DICT)
flavor_dict = {constants.LOADBALANCER_TOPOLOGY:
constants.TOPOLOGY_SINGLE}
(cw.services_controller.run_poster. (cw.services_controller.run_poster.
assert_called_once_with(flow_utils.get_update_listener_flow, assert_called_once_with(flow_utils.get_update_listener_flow,
flavor_dict=flavor_dict,
store={constants.LISTENER: listener_dict, store={constants.LISTENER: listener_dict,
constants.UPDATE_DICT: constants.UPDATE_DICT:
LISTENER_UPDATE_DICT, LISTENER_UPDATE_DICT,
@ -425,10 +450,13 @@ class TestControllerWorker(base.TestCase):
constants.LISTENERS: constants.LISTENERS:
[listener_dict]})) [listener_dict]}))
@mock.patch('octavia.db.repositories.FlavorRepository.'
'get_flavor_metadata_dict', return_value={})
@mock.patch("octavia.controller.worker.v2.controller_worker." @mock.patch("octavia.controller.worker.v2.controller_worker."
"ControllerWorker._get_db_obj_until_pending_update") "ControllerWorker._get_db_obj_until_pending_update")
def test_update_listener_timeout(self, def test_update_listener_timeout(self,
mock__get_db_obj_until_pending, mock__get_db_obj_until_pending,
mock_get_flavor_dict,
mock_api_get_session, mock_api_get_session,
mock_dyn_log_listener, mock_dyn_log_listener,
mock_taskflow_load, mock_taskflow_load,
@ -443,6 +471,7 @@ class TestControllerWorker(base.TestCase):
load_balancer_mock = mock.MagicMock() load_balancer_mock = mock.MagicMock()
load_balancer_mock.provisioning_status = constants.PENDING_UPDATE load_balancer_mock.provisioning_status = constants.PENDING_UPDATE
load_balancer_mock.id = LB_ID load_balancer_mock.id = LB_ID
load_balancer_mock.flavor_id = 1
_flow_mock.reset_mock() _flow_mock.reset_mock()
_listener_mock.provisioning_status = constants.PENDING_UPDATE _listener_mock.provisioning_status = constants.PENDING_UPDATE
last_attempt_mock = mock.MagicMock() last_attempt_mock = mock.MagicMock()
@ -2095,10 +2124,13 @@ class TestControllerWorker(base.TestCase):
cw._get_amphorae_for_failover, cw._get_amphorae_for_failover,
load_balancer_mock) load_balancer_mock)
@mock.patch('octavia.db.repositories.FlavorRepository.'
'get_flavor_metadata_dict')
@mock.patch('octavia.controller.worker.v2.controller_worker.' @mock.patch('octavia.controller.worker.v2.controller_worker.'
'ControllerWorker._get_amphorae_for_failover') 'ControllerWorker._get_amphorae_for_failover')
def test_failover_loadbalancer_single(self, def test_failover_loadbalancer_single(self,
mock_get_amps_for_failover, mock_get_amps_for_failover,
mock_get_flavor_dict,
mock_api_get_session, mock_api_get_session,
mock_dyn_log_listener, mock_dyn_log_listener,
mock_taskflow_load, mock_taskflow_load,
@ -2113,6 +2145,7 @@ class TestControllerWorker(base.TestCase):
_flow_mock.reset_mock() _flow_mock.reset_mock()
mock_lb_repo_get.return_value = _db_load_balancer_mock mock_lb_repo_get.return_value = _db_load_balancer_mock
mock_get_amps_for_failover.return_value = [_amphora_mock] mock_get_amps_for_failover.return_value = [_amphora_mock]
mock_get_flavor_dict.return_value = {}
provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer( provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer(
_db_load_balancer_mock).to_dict() _db_load_balancer_mock).to_dict()

View File

@ -0,0 +1,9 @@
---
features:
- |
Octavia Amphora based load balancers now support using SR-IOV virtual
functions (VF) on the VIP port(s) of the load balancer. This is enabled
by using an Octavia Flavor that includes the 'sriov_vip': True setting.
upgrade:
- |
You must update the amphora image to support the SR-IOV VIP feature.