Merge "validate the subnet id for create loadbalancer and member"
This commit is contained in:
commit
e040d5926e
@ -27,6 +27,7 @@ from octavia.api.v1.types import load_balancer as lb_types
|
||||
from octavia.common import constants
|
||||
from octavia.common import data_models
|
||||
from octavia.common import exceptions
|
||||
import octavia.common.validate as validate
|
||||
from octavia.db import prepare as db_prepare
|
||||
from octavia.i18n import _LI
|
||||
|
||||
@ -97,6 +98,12 @@ class LoadBalancersController(base.BaseController):
|
||||
body=lb_types.LoadBalancerPOST, status_code=202)
|
||||
def post(self, load_balancer):
|
||||
"""Creates a load balancer."""
|
||||
# Validate the subnet id
|
||||
if load_balancer.vip.subnet_id:
|
||||
if not validate.subnet_exists(load_balancer.vip.subnet_id):
|
||||
raise exceptions.NotFound(resource='Subnet',
|
||||
id=load_balancer.vip.subnet_id)
|
||||
|
||||
context = pecan.request.context.get('octavia_context')
|
||||
if load_balancer.listeners:
|
||||
return self._create_load_balancer_graph(context, load_balancer)
|
||||
|
@ -25,6 +25,7 @@ from octavia.api.v1.controllers import base
|
||||
from octavia.api.v1.types import member as member_types
|
||||
from octavia.common import constants
|
||||
from octavia.common import exceptions
|
||||
import octavia.common.validate as validate
|
||||
from octavia.db import prepare as db_prepare
|
||||
from octavia.i18n import _LI
|
||||
|
||||
@ -91,6 +92,10 @@ class MembersController(base.BaseController):
|
||||
def post(self, member):
|
||||
"""Creates a pool member on a pool."""
|
||||
context = pecan.request.context.get('octavia_context')
|
||||
# Validate member subnet
|
||||
if member.subnet_id and not validate.subnet_exists(member.subnet_id):
|
||||
raise exceptions.NotFound(resource='Subnet',
|
||||
id=member.subnet_id)
|
||||
member_dict = db_prepare.create_member(member.to_dict(
|
||||
render_unsets=True), self.pool_id)
|
||||
self._test_lb_and_listener_statuses(context.session)
|
||||
|
@ -24,8 +24,12 @@ import hashlib
|
||||
import random
|
||||
import socket
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log as logging
|
||||
from oslo_utils import excutils
|
||||
from stevedore import driver as stevedore_driver
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -56,6 +60,16 @@ def base64_sha1_string(string_to_hash):
|
||||
return b64_str.decode('UTF-8')
|
||||
|
||||
|
||||
def get_network_driver():
|
||||
CONF.import_group('controller_worker', 'octavia.common.config')
|
||||
network_driver = stevedore_driver.DriverManager(
|
||||
namespace='octavia.network.drivers',
|
||||
name=CONF.controller_worker.network_driver,
|
||||
invoke_on_load=True
|
||||
).driver
|
||||
return network_driver
|
||||
|
||||
|
||||
class exception_logger(object):
|
||||
"""Wrap a function and log raised exception
|
||||
|
||||
|
@ -18,12 +18,14 @@ Several handy validation functions that go beyond simple type checking.
|
||||
Defined here so these can also be used at deeper levels than the API.
|
||||
"""
|
||||
|
||||
|
||||
import re
|
||||
|
||||
import rfc3986
|
||||
|
||||
from octavia.common import constants
|
||||
from octavia.common import exceptions
|
||||
from octavia.common import utils
|
||||
|
||||
|
||||
def url(url):
|
||||
@ -198,3 +200,14 @@ def sanitize_l7policy_api_args(l7policy, create=False):
|
||||
if len(l7policy.keys()) == 0:
|
||||
raise exceptions.InvalidL7PolicyArgs(msg='Invalid update options')
|
||||
return l7policy
|
||||
|
||||
|
||||
def subnet_exists(subnet_id):
|
||||
network_driver = utils.get_network_driver()
|
||||
# Throws an exception when trying to get a subnet which
|
||||
# does not exist.
|
||||
try:
|
||||
network_driver.get_subnet(subnet_id)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
@ -168,11 +168,11 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
|
||||
return response.json
|
||||
|
||||
def create_member(self, lb_id, pool_id, ip_address,
|
||||
protocol_port, **optionals):
|
||||
protocol_port, expect_error=False, **optionals):
|
||||
req_dict = {'ip_address': ip_address, 'protocol_port': protocol_port}
|
||||
req_dict.update(optionals)
|
||||
path = self.MEMBERS_PATH.format(lb_id=lb_id, pool_id=pool_id)
|
||||
response = self.post(path, req_dict)
|
||||
response = self.post(path, req_dict, expect_errors=expect_error)
|
||||
return response.json
|
||||
|
||||
def create_member_with_listener(self, lb_id, listener_id, pool_id,
|
||||
|
@ -14,9 +14,11 @@
|
||||
|
||||
import copy
|
||||
|
||||
import mock
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
from octavia.common import constants
|
||||
from octavia.network import base as network_base
|
||||
from octavia.tests.functional.api.v1 import base
|
||||
|
||||
|
||||
@ -229,6 +231,32 @@ class TestLoadBalancer(base.BaseAPITest):
|
||||
path = self.LB_PATH.format(lb_id='bad_uuid')
|
||||
self.delete(path, status=404)
|
||||
|
||||
def test_create_with_bad_subnet(self, **optionals):
|
||||
with mock.patch(
|
||||
'octavia.common.utils.get_network_driver') as net_mock:
|
||||
net_mock.return_value.get_subnet = mock.Mock(
|
||||
side_effect=network_base.SubnetNotFound('Subnet not found'))
|
||||
subnet_id = uuidutils.generate_uuid()
|
||||
lb_json = {'name': 'test1', 'vip': {'subnet_id': subnet_id,
|
||||
'ip_address': '10.0.0.1'}}
|
||||
lb_json.update(optionals)
|
||||
response = self.post(self.LBS_PATH, lb_json, expect_errors=True)
|
||||
err_msg = 'Subnet ' + subnet_id + ' not found.'
|
||||
self.assertEqual(response.json.get('faultstring'), err_msg)
|
||||
|
||||
def test_create_with_valid_subnet(self, **optionals):
|
||||
subnet_id = uuidutils.generate_uuid()
|
||||
with mock.patch(
|
||||
'octavia.common.utils.get_network_driver') as net_mock:
|
||||
net_mock.return_value.get_subnet.return_value = subnet_id
|
||||
lb_json = {'name': 'test1', 'vip': {'subnet_id': subnet_id,
|
||||
'ip_address': '10.0.0.1'}}
|
||||
lb_json.update(optionals)
|
||||
response = self.post(self.LBS_PATH, lb_json)
|
||||
api_lb = response.json
|
||||
self.assertEqual(lb_json.get('vip')['subnet_id'],
|
||||
api_lb.get('vip')['subnet_id'])
|
||||
|
||||
|
||||
class TestLoadBalancerGraph(base.BaseAPITest):
|
||||
|
||||
|
@ -12,9 +12,11 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
from octavia.common import constants
|
||||
from octavia.network import base as network_base
|
||||
from octavia.tests.functional.api.v1 import base
|
||||
|
||||
|
||||
@ -177,6 +179,32 @@ class TestMember(base.BaseAPITest):
|
||||
self.set_lb_status(self.lb.get('id'))
|
||||
self.post(self.members_path, member, status=409)
|
||||
|
||||
def test_create_with_bad_subnet(self, **optionals):
|
||||
with mock.patch(
|
||||
'octavia.common.utils.get_network_driver') as net_mock:
|
||||
net_mock.return_value.get_subnet = mock.Mock(
|
||||
side_effect=network_base.SubnetNotFound('Subnet not found'))
|
||||
subnet_id = uuidutils.generate_uuid()
|
||||
response = self.create_member(self.lb.get('id'),
|
||||
self.pool.get('id'),
|
||||
'10.0.0.1', 80, expect_error=True,
|
||||
subnet_id=subnet_id)
|
||||
err_msg = 'Subnet ' + subnet_id + ' not found.'
|
||||
self.assertEqual(response.get('faultstring'), err_msg)
|
||||
|
||||
def test_create_with_valid_subnet(self, **optionals):
|
||||
with mock.patch(
|
||||
'octavia.common.utils.get_network_driver') as net_mock:
|
||||
subnet_id = uuidutils.generate_uuid()
|
||||
net_mock.return_value.get_subnet.return_value = subnet_id
|
||||
response = self.create_member(self.lb.get('id'),
|
||||
self.pool.get('id'),
|
||||
'10.0.0.1', 80, expect_error=True,
|
||||
subnet_id=subnet_id)
|
||||
self.assertEqual('10.0.0.1', response.get('ip_address'))
|
||||
self.assertEqual(80, response.get('protocol_port'))
|
||||
self.assertEqual(subnet_id, response.get('subnet_id'))
|
||||
|
||||
def test_update(self):
|
||||
old_port = 80
|
||||
new_port = 88
|
||||
|
@ -12,9 +12,13 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
import octavia.common.constants as constants
|
||||
import octavia.common.exceptions as exceptions
|
||||
import octavia.common.validate as validate
|
||||
from octavia.network import base as network_base
|
||||
import octavia.tests.unit.base as base
|
||||
|
||||
|
||||
@ -238,3 +242,18 @@ class TestValidations(base.TestCase):
|
||||
l7p = {}
|
||||
self.assertRaises(exceptions.InvalidL7PolicyArgs,
|
||||
validate.sanitize_l7policy_api_args, l7p)
|
||||
|
||||
def test_subnet_exists_with_bad_subnet(self):
|
||||
with mock.patch(
|
||||
'octavia.common.utils.get_network_driver') as net_mock:
|
||||
net_mock.return_value.get_subnet = mock.Mock(
|
||||
side_effect=network_base.SubnetNotFound('Subnet not found'))
|
||||
subnet_id = uuidutils.generate_uuid()
|
||||
self.assertEqual(validate.subnet_exists(subnet_id), False)
|
||||
|
||||
def test_subnet_exists_with_valid_subnet(self):
|
||||
subnet_id = uuidutils.generate_uuid()
|
||||
with mock.patch(
|
||||
'octavia.common.utils.get_network_driver') as net_mock:
|
||||
net_mock.return_value.get_subnet.return_value = subnet_id
|
||||
self.assertEqual(validate.subnet_exists(subnet_id), True)
|
||||
|
Loading…
Reference in New Issue
Block a user