port security: gracefully handle resources with no bindings
Resources could be created before the extension was enabled in the setup. In that case, no bindings are created for them. In that case, we should gracefully return default (True) value when extracting the value using the mixin; and we should also create binding model on update request, if there is no existing binding model for the resource. While at it, introduced a constant to store the default value for port security (True) and changed several tests to use the constant instead of extracting it from extension resource map. Change-Id: I8607cdecdc16c5f94635c94e2f02700c732806eb Closes-Bug: #1509312
This commit is contained in:
parent
89e8e2a9c5
commit
b0519cf0ad
|
@ -29,8 +29,8 @@ class PortSecurityDbMixin(portsecurity_db_common.PortSecurityDbCommon):
|
|||
def _extend_port_security_dict(self, response_data, db_data):
|
||||
if ('port-security' in
|
||||
getattr(self, 'supported_extension_aliases', [])):
|
||||
psec_value = db_data['port_security'][psec.PORTSECURITY]
|
||||
response_data[psec.PORTSECURITY] = psec_value
|
||||
super(PortSecurityDbMixin, self)._extend_port_security_dict(
|
||||
response_data, db_data)
|
||||
|
||||
def _determine_port_security_and_has_ip(self, context, port):
|
||||
"""Returns a tuple of booleans (port_security_enabled, has_ip).
|
||||
|
|
|
@ -53,6 +53,13 @@ class NetworkSecurityBinding(model_base.BASEV2):
|
|||
class PortSecurityDbCommon(object):
|
||||
"""Mixin class to add port security."""
|
||||
|
||||
def _extend_port_security_dict(self, response_data, db_data):
|
||||
if db_data.get('port_security') is None:
|
||||
response_data[psec.PORTSECURITY] = psec.DEFAULT_PORT_SECURITY
|
||||
else:
|
||||
response_data[psec.PORTSECURITY] = (
|
||||
db_data['port_security'][psec.PORTSECURITY])
|
||||
|
||||
def _process_network_port_security_create(
|
||||
self, context, network_req, network_res):
|
||||
with context.session.begin(subtransactions=True):
|
||||
|
@ -78,52 +85,58 @@ class PortSecurityDbCommon(object):
|
|||
query = self._model_query(context, NetworkSecurityBinding)
|
||||
binding = query.filter(
|
||||
NetworkSecurityBinding.network_id == network_id).one()
|
||||
return binding.port_security_enabled
|
||||
except exc.NoResultFound:
|
||||
raise psec.PortSecurityBindingNotFound()
|
||||
return binding.port_security_enabled
|
||||
# NOTE(ihrachys) the resource may have been created before port
|
||||
# security extension was enabled; return default value
|
||||
return psec.DEFAULT_PORT_SECURITY
|
||||
|
||||
def _get_port_security_binding(self, context, port_id):
|
||||
try:
|
||||
query = self._model_query(context, PortSecurityBinding)
|
||||
binding = query.filter(
|
||||
PortSecurityBinding.port_id == port_id).one()
|
||||
return binding.port_security_enabled
|
||||
except exc.NoResultFound:
|
||||
raise psec.PortSecurityBindingNotFound()
|
||||
return binding.port_security_enabled
|
||||
# NOTE(ihrachys) the resource may have been created before port
|
||||
# security extension was enabled; return default value
|
||||
return psec.DEFAULT_PORT_SECURITY
|
||||
|
||||
def _process_port_port_security_update(
|
||||
self, context, port_req, port_res):
|
||||
if psec.PORTSECURITY in port_req:
|
||||
port_security_enabled = port_req[psec.PORTSECURITY]
|
||||
else:
|
||||
if psec.PORTSECURITY not in port_req:
|
||||
return
|
||||
port_security_enabled = port_req[psec.PORTSECURITY]
|
||||
try:
|
||||
query = self._model_query(context, PortSecurityBinding)
|
||||
port_id = port_res['id']
|
||||
binding = query.filter(
|
||||
PortSecurityBinding.port_id == port_id).one()
|
||||
|
||||
binding.port_security_enabled = port_security_enabled
|
||||
port_res[psec.PORTSECURITY] = port_security_enabled
|
||||
except exc.NoResultFound:
|
||||
raise psec.PortSecurityBindingNotFound()
|
||||
# NOTE(ihrachys) the resource may have been created before port
|
||||
# security extension was enabled; create the binding model
|
||||
self._process_port_port_security_create(
|
||||
context, port_req, port_res)
|
||||
|
||||
def _process_network_port_security_update(
|
||||
self, context, network_req, network_res):
|
||||
if psec.PORTSECURITY in network_req:
|
||||
port_security_enabled = network_req[psec.PORTSECURITY]
|
||||
else:
|
||||
if psec.PORTSECURITY not in network_req:
|
||||
return
|
||||
port_security_enabled = network_req[psec.PORTSECURITY]
|
||||
try:
|
||||
query = self._model_query(context, NetworkSecurityBinding)
|
||||
network_id = network_res['id']
|
||||
binding = query.filter(
|
||||
NetworkSecurityBinding.network_id == network_id).one()
|
||||
|
||||
binding.port_security_enabled = port_security_enabled
|
||||
network_res[psec.PORTSECURITY] = port_security_enabled
|
||||
except exc.NoResultFound:
|
||||
raise psec.PortSecurityBindingNotFound()
|
||||
# NOTE(ihrachys) the resource may have been created before port
|
||||
# security extension was enabled; create the binding model
|
||||
self._process_network_port_security_create(
|
||||
context, network_req, network_res)
|
||||
|
||||
def _make_network_port_security_dict(self, port_security, fields=None):
|
||||
res = {'network_id': port_security['network_id'],
|
||||
|
|
|
@ -18,6 +18,9 @@ from neutron.api.v2 import attributes
|
|||
from neutron.common import exceptions as nexception
|
||||
|
||||
|
||||
DEFAULT_PORT_SECURITY = True
|
||||
|
||||
|
||||
class PortSecurityPortHasSecurityGroup(nexception.InUse):
|
||||
message = _("Port has security group associated. Cannot disable port "
|
||||
"security or ip address until security group is removed")
|
||||
|
@ -28,16 +31,13 @@ class PortSecurityAndIPRequiredForSecurityGroups(nexception.InvalidInput):
|
|||
" address in order to use security groups.")
|
||||
|
||||
|
||||
class PortSecurityBindingNotFound(nexception.InvalidExtensionEnv):
|
||||
message = _("Port does not have port security binding.")
|
||||
|
||||
PORTSECURITY = 'port_security_enabled'
|
||||
EXTENDED_ATTRIBUTES_2_0 = {
|
||||
'networks': {
|
||||
PORTSECURITY: {'allow_post': True, 'allow_put': True,
|
||||
'convert_to': attributes.convert_to_boolean,
|
||||
'enforce_policy': True,
|
||||
'default': True,
|
||||
'default': DEFAULT_PORT_SECURITY,
|
||||
'is_visible': True},
|
||||
},
|
||||
'ports': {
|
||||
|
|
|
@ -41,8 +41,7 @@ class PortSecurityExtensionDriver(api.ExtensionDriver,
|
|||
def process_create_network(self, context, data, result):
|
||||
# Create the network extension attributes.
|
||||
if psec.PORTSECURITY not in data:
|
||||
data[psec.PORTSECURITY] = (psec.EXTENDED_ATTRIBUTES_2_0['networks']
|
||||
[psec.PORTSECURITY]['default'])
|
||||
data[psec.PORTSECURITY] = psec.DEFAULT_PORT_SECURITY
|
||||
self._process_network_port_security_create(context, data, result)
|
||||
|
||||
def process_update_network(self, context, data, result):
|
||||
|
@ -66,15 +65,6 @@ class PortSecurityExtensionDriver(api.ExtensionDriver,
|
|||
def extend_port_dict(self, session, db_data, result):
|
||||
self._extend_port_security_dict(result, db_data)
|
||||
|
||||
def _extend_port_security_dict(self, response_data, db_data):
|
||||
if db_data.get('port_security') is None:
|
||||
response_data[psec.PORTSECURITY] = (
|
||||
psec.EXTENDED_ATTRIBUTES_2_0['networks']
|
||||
[psec.PORTSECURITY]['default'])
|
||||
else:
|
||||
response_data[psec.PORTSECURITY] = (
|
||||
db_data['port_security'][psec.PORTSECURITY])
|
||||
|
||||
def _determine_port_security(self, context, port):
|
||||
"""Returns a boolean (port_security_enabled).
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from neutron.db import portsecurity_db as pd
|
||||
from neutron.db import portsecurity_db_common as pdc
|
||||
from neutron.tests import base
|
||||
|
||||
common = pdc.PortSecurityDbCommon
|
||||
|
||||
|
||||
class FakePlugin(pd.PortSecurityDbMixin):
|
||||
|
||||
supported_extension_aliases = ['port-security']
|
||||
|
||||
|
||||
class PortSecurityDbMixinTestCase(base.BaseTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(PortSecurityDbMixinTestCase, self).setUp()
|
||||
self.plugin = FakePlugin()
|
||||
|
||||
@mock.patch.object(common, '_extend_port_security_dict')
|
||||
def test__extend_port_security_dict_relies_on_common(self, extend):
|
||||
response = mock.Mock()
|
||||
dbdata = mock.Mock()
|
||||
self.plugin._extend_port_security_dict(response, dbdata)
|
||||
extend.assert_called_once_with(response, dbdata)
|
||||
|
||||
@mock.patch.object(common, '_extend_port_security_dict')
|
||||
def test__extend_port_security_dict_ignored_if_extension_disabled(self,
|
||||
extend):
|
||||
response = mock.Mock()
|
||||
dbdata = mock.Mock()
|
||||
self.plugin.supported_extension_aliases = []
|
||||
self.plugin._extend_port_security_dict(response, dbdata)
|
||||
self.assertFalse(extend.called)
|
|
@ -0,0 +1,76 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
from sqlalchemy.orm import exc
|
||||
|
||||
from neutron.db import portsecurity_db_common as pdc
|
||||
from neutron.extensions import portsecurity as psec
|
||||
from neutron.tests import base
|
||||
|
||||
|
||||
common = pdc.PortSecurityDbCommon
|
||||
|
||||
|
||||
class PortSecurityDbCommonTestCase(base.BaseTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(PortSecurityDbCommonTestCase, self).setUp()
|
||||
self.common = common()
|
||||
|
||||
def _test__get_security_binding_no_binding(self, getter):
|
||||
port_sec_enabled = True
|
||||
req = {psec.PORTSECURITY: port_sec_enabled}
|
||||
res = {}
|
||||
with mock.patch.object(
|
||||
self.common, '_model_query',
|
||||
create=True,
|
||||
side_effect=exc.NoResultFound):
|
||||
val = getter(req, res)
|
||||
self.assertEqual(port_sec_enabled, val)
|
||||
|
||||
def test__get_port_security_binding_no_binding(self):
|
||||
self._test__get_security_binding_no_binding(
|
||||
self.common._get_port_security_binding)
|
||||
|
||||
def test__get_network_security_binding_no_binding(self):
|
||||
self._test__get_security_binding_no_binding(
|
||||
self.common._get_network_security_binding)
|
||||
|
||||
def _test__process_security_update_no_binding(self, creator, updater):
|
||||
req = {psec.PORTSECURITY: False}
|
||||
res = {}
|
||||
context = mock.Mock()
|
||||
with mock.patch.object(
|
||||
self.common, '_model_query',
|
||||
create=True,
|
||||
side_effect=exc.NoResultFound):
|
||||
updater(context, req, res)
|
||||
creator.assert_called_with(context, req, res)
|
||||
|
||||
@mock.patch.object(common, '_process_port_port_security_create')
|
||||
def test__process_port_port_security_update_no_binding(self, creator):
|
||||
self._test__process_security_update_no_binding(
|
||||
creator,
|
||||
self.common._process_port_port_security_update)
|
||||
|
||||
@mock.patch.object(common, '_process_network_port_security_create')
|
||||
def test__process_network_port_security_update_no_binding(self, creator):
|
||||
self._test__process_security_update_no_binding(
|
||||
creator,
|
||||
self.common._process_network_port_security_update)
|
||||
|
||||
def test__extend_port_security_dict_no_port_security(self):
|
||||
for db_data in ({'port_security': None, 'name': 'net1'}, {}):
|
||||
response_data = {}
|
||||
self.common._extend_port_security_dict(response_data, db_data)
|
||||
self.assertTrue(response_data[psec.PORTSECURITY])
|
|
@ -13,20 +13,27 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from neutron.extensions import portsecurity as psec
|
||||
from neutron.plugins.ml2.extensions import port_security
|
||||
from neutron.tests.unit.plugins.ml2 import test_plugin
|
||||
|
||||
|
||||
class TestML2ExtensionPortSecurity(test_plugin.Ml2PluginV2TestCase):
|
||||
def test_extend_port_dict_no_port_security(self):
|
||||
"""Test _extend_port_security_dict won't crash
|
||||
if port_security item is None
|
||||
"""
|
||||
def _test_extend_dict_no_port_security(self, func):
|
||||
"""Test extend_*_dict won't crash if port_security item is None."""
|
||||
for db_data in ({'port_security': None, 'name': 'net1'}, {}):
|
||||
response_data = {}
|
||||
session = mock.Mock()
|
||||
|
||||
driver = port_security.PortSecurityExtensionDriver()
|
||||
driver._extend_port_security_dict(response_data, db_data)
|
||||
getattr(driver, func)(session, db_data, response_data)
|
||||
|
||||
self.assertTrue(response_data[psec.PORTSECURITY])
|
||||
|
||||
def test_extend_port_dict_no_port_security(self):
|
||||
self._test_extend_dict_no_port_security('extend_port_dict')
|
||||
|
||||
def test_extend_network_dict_no_port_security(self):
|
||||
self._test_extend_dict_no_port_security('extend_network_dict')
|
||||
|
|
|
@ -34,8 +34,6 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase,
|
|||
def test_create_net_port_security_default(self):
|
||||
_core_plugin = manager.NeutronManager.get_plugin()
|
||||
admin_ctx = context.get_admin_context()
|
||||
_default_value = (psec.EXTENDED_ATTRIBUTES_2_0['networks']
|
||||
[psec.PORTSECURITY]['default'])
|
||||
args = {'network':
|
||||
{'name': 'test',
|
||||
'tenant_id': '',
|
||||
|
@ -48,7 +46,7 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase,
|
|||
finally:
|
||||
if network:
|
||||
_core_plugin.delete_network(admin_ctx, network['id'])
|
||||
self.assertEqual(_default_value, _value)
|
||||
self.assertEqual(psec.DEFAULT_PORT_SECURITY, _value)
|
||||
|
||||
def test_create_port_with_secgroup_none_and_port_security_false(self):
|
||||
if self._skip_security_group:
|
||||
|
|
Loading…
Reference in New Issue