Merge "validate the subnet id for create loadbalancer and member"

This commit is contained in:
Jenkins 2016-06-03 06:19:34 +00:00 committed by Gerrit Code Review
commit e040d5926e
8 changed files with 116 additions and 2 deletions

View File

@ -27,6 +27,7 @@ from octavia.api.v1.types import load_balancer as lb_types
from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
import octavia.common.validate as validate
from octavia.db import prepare as db_prepare
from octavia.i18n import _LI
@ -97,6 +98,12 @@ class LoadBalancersController(base.BaseController):
body=lb_types.LoadBalancerPOST, status_code=202)
def post(self, load_balancer):
"""Creates a load balancer."""
# Validate the subnet id
if load_balancer.vip.subnet_id:
if not validate.subnet_exists(load_balancer.vip.subnet_id):
raise exceptions.NotFound(resource='Subnet',
id=load_balancer.vip.subnet_id)
context = pecan.request.context.get('octavia_context')
if load_balancer.listeners:
return self._create_load_balancer_graph(context, load_balancer)

View File

@ -25,6 +25,7 @@ from octavia.api.v1.controllers import base
from octavia.api.v1.types import member as member_types
from octavia.common import constants
from octavia.common import exceptions
import octavia.common.validate as validate
from octavia.db import prepare as db_prepare
from octavia.i18n import _LI
@ -91,6 +92,10 @@ class MembersController(base.BaseController):
def post(self, member):
"""Creates a pool member on a pool."""
context = pecan.request.context.get('octavia_context')
# Validate member subnet
if member.subnet_id and not validate.subnet_exists(member.subnet_id):
raise exceptions.NotFound(resource='Subnet',
id=member.subnet_id)
member_dict = db_prepare.create_member(member.to_dict(
render_unsets=True), self.pool_id)
self._test_lb_and_listener_statuses(context.session)

View File

@ -24,8 +24,12 @@ import hashlib
import random
import socket
from oslo_config import cfg
from oslo_log import log as logging
from oslo_utils import excutils
from stevedore import driver as stevedore_driver
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
@ -56,6 +60,16 @@ def base64_sha1_string(string_to_hash):
return b64_str.decode('UTF-8')
def get_network_driver():
CONF.import_group('controller_worker', 'octavia.common.config')
network_driver = stevedore_driver.DriverManager(
namespace='octavia.network.drivers',
name=CONF.controller_worker.network_driver,
invoke_on_load=True
).driver
return network_driver
class exception_logger(object):
"""Wrap a function and log raised exception

View File

@ -18,12 +18,14 @@ Several handy validation functions that go beyond simple type checking.
Defined here so these can also be used at deeper levels than the API.
"""
import re
import rfc3986
from octavia.common import constants
from octavia.common import exceptions
from octavia.common import utils
def url(url):
@ -198,3 +200,14 @@ def sanitize_l7policy_api_args(l7policy, create=False):
if len(l7policy.keys()) == 0:
raise exceptions.InvalidL7PolicyArgs(msg='Invalid update options')
return l7policy
def subnet_exists(subnet_id):
network_driver = utils.get_network_driver()
# Throws an exception when trying to get a subnet which
# does not exist.
try:
network_driver.get_subnet(subnet_id)
except Exception:
return False
return True

View File

@ -168,11 +168,11 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
return response.json
def create_member(self, lb_id, pool_id, ip_address,
protocol_port, **optionals):
protocol_port, expect_error=False, **optionals):
req_dict = {'ip_address': ip_address, 'protocol_port': protocol_port}
req_dict.update(optionals)
path = self.MEMBERS_PATH.format(lb_id=lb_id, pool_id=pool_id)
response = self.post(path, req_dict)
response = self.post(path, req_dict, expect_errors=expect_error)
return response.json
def create_member_with_listener(self, lb_id, listener_id, pool_id,

View File

@ -14,9 +14,11 @@
import copy
import mock
from oslo_utils import uuidutils
from octavia.common import constants
from octavia.network import base as network_base
from octavia.tests.functional.api.v1 import base
@ -229,6 +231,32 @@ class TestLoadBalancer(base.BaseAPITest):
path = self.LB_PATH.format(lb_id='bad_uuid')
self.delete(path, status=404)
def test_create_with_bad_subnet(self, **optionals):
with mock.patch(
'octavia.common.utils.get_network_driver') as net_mock:
net_mock.return_value.get_subnet = mock.Mock(
side_effect=network_base.SubnetNotFound('Subnet not found'))
subnet_id = uuidutils.generate_uuid()
lb_json = {'name': 'test1', 'vip': {'subnet_id': subnet_id,
'ip_address': '10.0.0.1'}}
lb_json.update(optionals)
response = self.post(self.LBS_PATH, lb_json, expect_errors=True)
err_msg = 'Subnet ' + subnet_id + ' not found.'
self.assertEqual(response.json.get('faultstring'), err_msg)
def test_create_with_valid_subnet(self, **optionals):
subnet_id = uuidutils.generate_uuid()
with mock.patch(
'octavia.common.utils.get_network_driver') as net_mock:
net_mock.return_value.get_subnet.return_value = subnet_id
lb_json = {'name': 'test1', 'vip': {'subnet_id': subnet_id,
'ip_address': '10.0.0.1'}}
lb_json.update(optionals)
response = self.post(self.LBS_PATH, lb_json)
api_lb = response.json
self.assertEqual(lb_json.get('vip')['subnet_id'],
api_lb.get('vip')['subnet_id'])
class TestLoadBalancerGraph(base.BaseAPITest):

View File

@ -12,9 +12,11 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
from oslo_utils import uuidutils
from octavia.common import constants
from octavia.network import base as network_base
from octavia.tests.functional.api.v1 import base
@ -177,6 +179,32 @@ class TestMember(base.BaseAPITest):
self.set_lb_status(self.lb.get('id'))
self.post(self.members_path, member, status=409)
def test_create_with_bad_subnet(self, **optionals):
with mock.patch(
'octavia.common.utils.get_network_driver') as net_mock:
net_mock.return_value.get_subnet = mock.Mock(
side_effect=network_base.SubnetNotFound('Subnet not found'))
subnet_id = uuidutils.generate_uuid()
response = self.create_member(self.lb.get('id'),
self.pool.get('id'),
'10.0.0.1', 80, expect_error=True,
subnet_id=subnet_id)
err_msg = 'Subnet ' + subnet_id + ' not found.'
self.assertEqual(response.get('faultstring'), err_msg)
def test_create_with_valid_subnet(self, **optionals):
with mock.patch(
'octavia.common.utils.get_network_driver') as net_mock:
subnet_id = uuidutils.generate_uuid()
net_mock.return_value.get_subnet.return_value = subnet_id
response = self.create_member(self.lb.get('id'),
self.pool.get('id'),
'10.0.0.1', 80, expect_error=True,
subnet_id=subnet_id)
self.assertEqual('10.0.0.1', response.get('ip_address'))
self.assertEqual(80, response.get('protocol_port'))
self.assertEqual(subnet_id, response.get('subnet_id'))
def test_update(self):
old_port = 80
new_port = 88

View File

@ -12,9 +12,13 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
from oslo_utils import uuidutils
import octavia.common.constants as constants
import octavia.common.exceptions as exceptions
import octavia.common.validate as validate
from octavia.network import base as network_base
import octavia.tests.unit.base as base
@ -238,3 +242,18 @@ class TestValidations(base.TestCase):
l7p = {}
self.assertRaises(exceptions.InvalidL7PolicyArgs,
validate.sanitize_l7policy_api_args, l7p)
def test_subnet_exists_with_bad_subnet(self):
with mock.patch(
'octavia.common.utils.get_network_driver') as net_mock:
net_mock.return_value.get_subnet = mock.Mock(
side_effect=network_base.SubnetNotFound('Subnet not found'))
subnet_id = uuidutils.generate_uuid()
self.assertEqual(validate.subnet_exists(subnet_id), False)
def test_subnet_exists_with_valid_subnet(self):
subnet_id = uuidutils.generate_uuid()
with mock.patch(
'octavia.common.utils.get_network_driver') as net_mock:
net_mock.return_value.get_subnet.return_value = subnet_id
self.assertEqual(validate.subnet_exists(subnet_id), True)