port security: gracefully handle resources with no bindings

Resources could be created before the extension was enabled in the
setup. In that case, no bindings are created for them. In that case, we
should gracefully return default (True) value when extracting the value
using the mixin; and we should also create binding model on update
request, if there is no existing binding model for the resource.

While at it, introduced a constant to store the default value for port
security (True) and changed several tests to use the constant instead of
extracting it from extension resource map.

Change-Id: I8607cdecdc16c5f94635c94e2f02700c732806eb
Closes-Bug: #1509312
This commit is contained in:
Ihar Hrachyshka 2016-03-17 16:20:52 +01:00
parent 89e8e2a9c5
commit b0519cf0ad
8 changed files with 170 additions and 39 deletions

View File

@ -29,8 +29,8 @@ class PortSecurityDbMixin(portsecurity_db_common.PortSecurityDbCommon):
def _extend_port_security_dict(self, response_data, db_data):
if ('port-security' in
getattr(self, 'supported_extension_aliases', [])):
psec_value = db_data['port_security'][psec.PORTSECURITY]
response_data[psec.PORTSECURITY] = psec_value
super(PortSecurityDbMixin, self)._extend_port_security_dict(
response_data, db_data)
def _determine_port_security_and_has_ip(self, context, port):
"""Returns a tuple of booleans (port_security_enabled, has_ip).

View File

@ -53,6 +53,13 @@ class NetworkSecurityBinding(model_base.BASEV2):
class PortSecurityDbCommon(object):
"""Mixin class to add port security."""
def _extend_port_security_dict(self, response_data, db_data):
if db_data.get('port_security') is None:
response_data[psec.PORTSECURITY] = psec.DEFAULT_PORT_SECURITY
else:
response_data[psec.PORTSECURITY] = (
db_data['port_security'][psec.PORTSECURITY])
def _process_network_port_security_create(
self, context, network_req, network_res):
with context.session.begin(subtransactions=True):
@ -78,52 +85,58 @@ class PortSecurityDbCommon(object):
query = self._model_query(context, NetworkSecurityBinding)
binding = query.filter(
NetworkSecurityBinding.network_id == network_id).one()
return binding.port_security_enabled
except exc.NoResultFound:
raise psec.PortSecurityBindingNotFound()
return binding.port_security_enabled
# NOTE(ihrachys) the resource may have been created before port
# security extension was enabled; return default value
return psec.DEFAULT_PORT_SECURITY
def _get_port_security_binding(self, context, port_id):
try:
query = self._model_query(context, PortSecurityBinding)
binding = query.filter(
PortSecurityBinding.port_id == port_id).one()
return binding.port_security_enabled
except exc.NoResultFound:
raise psec.PortSecurityBindingNotFound()
return binding.port_security_enabled
# NOTE(ihrachys) the resource may have been created before port
# security extension was enabled; return default value
return psec.DEFAULT_PORT_SECURITY
def _process_port_port_security_update(
self, context, port_req, port_res):
if psec.PORTSECURITY in port_req:
port_security_enabled = port_req[psec.PORTSECURITY]
else:
if psec.PORTSECURITY not in port_req:
return
port_security_enabled = port_req[psec.PORTSECURITY]
try:
query = self._model_query(context, PortSecurityBinding)
port_id = port_res['id']
binding = query.filter(
PortSecurityBinding.port_id == port_id).one()
binding.port_security_enabled = port_security_enabled
port_res[psec.PORTSECURITY] = port_security_enabled
except exc.NoResultFound:
raise psec.PortSecurityBindingNotFound()
# NOTE(ihrachys) the resource may have been created before port
# security extension was enabled; create the binding model
self._process_port_port_security_create(
context, port_req, port_res)
def _process_network_port_security_update(
self, context, network_req, network_res):
if psec.PORTSECURITY in network_req:
port_security_enabled = network_req[psec.PORTSECURITY]
else:
if psec.PORTSECURITY not in network_req:
return
port_security_enabled = network_req[psec.PORTSECURITY]
try:
query = self._model_query(context, NetworkSecurityBinding)
network_id = network_res['id']
binding = query.filter(
NetworkSecurityBinding.network_id == network_id).one()
binding.port_security_enabled = port_security_enabled
network_res[psec.PORTSECURITY] = port_security_enabled
except exc.NoResultFound:
raise psec.PortSecurityBindingNotFound()
# NOTE(ihrachys) the resource may have been created before port
# security extension was enabled; create the binding model
self._process_network_port_security_create(
context, network_req, network_res)
def _make_network_port_security_dict(self, port_security, fields=None):
res = {'network_id': port_security['network_id'],

View File

@ -18,6 +18,9 @@ from neutron.api.v2 import attributes
from neutron.common import exceptions as nexception
DEFAULT_PORT_SECURITY = True
class PortSecurityPortHasSecurityGroup(nexception.InUse):
message = _("Port has security group associated. Cannot disable port "
"security or ip address until security group is removed")
@ -28,16 +31,13 @@ class PortSecurityAndIPRequiredForSecurityGroups(nexception.InvalidInput):
" address in order to use security groups.")
class PortSecurityBindingNotFound(nexception.InvalidExtensionEnv):
message = _("Port does not have port security binding.")
PORTSECURITY = 'port_security_enabled'
EXTENDED_ATTRIBUTES_2_0 = {
'networks': {
PORTSECURITY: {'allow_post': True, 'allow_put': True,
'convert_to': attributes.convert_to_boolean,
'enforce_policy': True,
'default': True,
'default': DEFAULT_PORT_SECURITY,
'is_visible': True},
},
'ports': {

View File

@ -41,8 +41,7 @@ class PortSecurityExtensionDriver(api.ExtensionDriver,
def process_create_network(self, context, data, result):
# Create the network extension attributes.
if psec.PORTSECURITY not in data:
data[psec.PORTSECURITY] = (psec.EXTENDED_ATTRIBUTES_2_0['networks']
[psec.PORTSECURITY]['default'])
data[psec.PORTSECURITY] = psec.DEFAULT_PORT_SECURITY
self._process_network_port_security_create(context, data, result)
def process_update_network(self, context, data, result):
@ -66,15 +65,6 @@ class PortSecurityExtensionDriver(api.ExtensionDriver,
def extend_port_dict(self, session, db_data, result):
self._extend_port_security_dict(result, db_data)
def _extend_port_security_dict(self, response_data, db_data):
if db_data.get('port_security') is None:
response_data[psec.PORTSECURITY] = (
psec.EXTENDED_ATTRIBUTES_2_0['networks']
[psec.PORTSECURITY]['default'])
else:
response_data[psec.PORTSECURITY] = (
db_data['port_security'][psec.PORTSECURITY])
def _determine_port_security(self, context, port):
"""Returns a boolean (port_security_enabled).

View File

@ -0,0 +1,47 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from neutron.db import portsecurity_db as pd
from neutron.db import portsecurity_db_common as pdc
from neutron.tests import base
common = pdc.PortSecurityDbCommon
class FakePlugin(pd.PortSecurityDbMixin):
supported_extension_aliases = ['port-security']
class PortSecurityDbMixinTestCase(base.BaseTestCase):
def setUp(self):
super(PortSecurityDbMixinTestCase, self).setUp()
self.plugin = FakePlugin()
@mock.patch.object(common, '_extend_port_security_dict')
def test__extend_port_security_dict_relies_on_common(self, extend):
response = mock.Mock()
dbdata = mock.Mock()
self.plugin._extend_port_security_dict(response, dbdata)
extend.assert_called_once_with(response, dbdata)
@mock.patch.object(common, '_extend_port_security_dict')
def test__extend_port_security_dict_ignored_if_extension_disabled(self,
extend):
response = mock.Mock()
dbdata = mock.Mock()
self.plugin.supported_extension_aliases = []
self.plugin._extend_port_security_dict(response, dbdata)
self.assertFalse(extend.called)

View File

@ -0,0 +1,76 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from sqlalchemy.orm import exc
from neutron.db import portsecurity_db_common as pdc
from neutron.extensions import portsecurity as psec
from neutron.tests import base
common = pdc.PortSecurityDbCommon
class PortSecurityDbCommonTestCase(base.BaseTestCase):
def setUp(self):
super(PortSecurityDbCommonTestCase, self).setUp()
self.common = common()
def _test__get_security_binding_no_binding(self, getter):
port_sec_enabled = True
req = {psec.PORTSECURITY: port_sec_enabled}
res = {}
with mock.patch.object(
self.common, '_model_query',
create=True,
side_effect=exc.NoResultFound):
val = getter(req, res)
self.assertEqual(port_sec_enabled, val)
def test__get_port_security_binding_no_binding(self):
self._test__get_security_binding_no_binding(
self.common._get_port_security_binding)
def test__get_network_security_binding_no_binding(self):
self._test__get_security_binding_no_binding(
self.common._get_network_security_binding)
def _test__process_security_update_no_binding(self, creator, updater):
req = {psec.PORTSECURITY: False}
res = {}
context = mock.Mock()
with mock.patch.object(
self.common, '_model_query',
create=True,
side_effect=exc.NoResultFound):
updater(context, req, res)
creator.assert_called_with(context, req, res)
@mock.patch.object(common, '_process_port_port_security_create')
def test__process_port_port_security_update_no_binding(self, creator):
self._test__process_security_update_no_binding(
creator,
self.common._process_port_port_security_update)
@mock.patch.object(common, '_process_network_port_security_create')
def test__process_network_port_security_update_no_binding(self, creator):
self._test__process_security_update_no_binding(
creator,
self.common._process_network_port_security_update)
def test__extend_port_security_dict_no_port_security(self):
for db_data in ({'port_security': None, 'name': 'net1'}, {}):
response_data = {}
self.common._extend_port_security_dict(response_data, db_data)
self.assertTrue(response_data[psec.PORTSECURITY])

View File

@ -13,20 +13,27 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
from neutron.extensions import portsecurity as psec
from neutron.plugins.ml2.extensions import port_security
from neutron.tests.unit.plugins.ml2 import test_plugin
class TestML2ExtensionPortSecurity(test_plugin.Ml2PluginV2TestCase):
def test_extend_port_dict_no_port_security(self):
"""Test _extend_port_security_dict won't crash
if port_security item is None
"""
def _test_extend_dict_no_port_security(self, func):
"""Test extend_*_dict won't crash if port_security item is None."""
for db_data in ({'port_security': None, 'name': 'net1'}, {}):
response_data = {}
session = mock.Mock()
driver = port_security.PortSecurityExtensionDriver()
driver._extend_port_security_dict(response_data, db_data)
getattr(driver, func)(session, db_data, response_data)
self.assertTrue(response_data[psec.PORTSECURITY])
def test_extend_port_dict_no_port_security(self):
self._test_extend_dict_no_port_security('extend_port_dict')
def test_extend_network_dict_no_port_security(self):
self._test_extend_dict_no_port_security('extend_network_dict')

View File

@ -34,8 +34,6 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase,
def test_create_net_port_security_default(self):
_core_plugin = manager.NeutronManager.get_plugin()
admin_ctx = context.get_admin_context()
_default_value = (psec.EXTENDED_ATTRIBUTES_2_0['networks']
[psec.PORTSECURITY]['default'])
args = {'network':
{'name': 'test',
'tenant_id': '',
@ -48,7 +46,7 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase,
finally:
if network:
_core_plugin.delete_network(admin_ctx, network['id'])
self.assertEqual(_default_value, _value)
self.assertEqual(psec.DEFAULT_PORT_SECURITY, _value)
def test_create_port_with_secgroup_none_and_port_security_false(self):
if self._skip_security_group: