diff --git a/neutron/db/portsecurity_db.py b/neutron/db/portsecurity_db.py index 343b5371050..944f5f89c01 100644 --- a/neutron/db/portsecurity_db.py +++ b/neutron/db/portsecurity_db.py @@ -29,8 +29,8 @@ class PortSecurityDbMixin(portsecurity_db_common.PortSecurityDbCommon): def _extend_port_security_dict(self, response_data, db_data): if ('port-security' in getattr(self, 'supported_extension_aliases', [])): - psec_value = db_data['port_security'][psec.PORTSECURITY] - response_data[psec.PORTSECURITY] = psec_value + super(PortSecurityDbMixin, self)._extend_port_security_dict( + response_data, db_data) def _determine_port_security_and_has_ip(self, context, port): """Returns a tuple of booleans (port_security_enabled, has_ip). diff --git a/neutron/db/portsecurity_db_common.py b/neutron/db/portsecurity_db_common.py index e348f81b19c..46e769bb5fb 100644 --- a/neutron/db/portsecurity_db_common.py +++ b/neutron/db/portsecurity_db_common.py @@ -56,6 +56,13 @@ class NetworkSecurityBinding(model_base.BASEV2): class PortSecurityDbCommon(object): """Mixin class to add port security.""" + def _extend_port_security_dict(self, response_data, db_data): + if db_data.get('port_security') is None: + response_data[psec.PORTSECURITY] = psec.DEFAULT_PORT_SECURITY + else: + response_data[psec.PORTSECURITY] = ( + db_data['port_security'][psec.PORTSECURITY]) + def _process_network_port_security_create( self, context, network_req, network_res): with context.session.begin(subtransactions=True): @@ -81,52 +88,58 @@ class PortSecurityDbCommon(object): query = self._model_query(context, NetworkSecurityBinding) binding = query.filter( NetworkSecurityBinding.network_id == network_id).one() + return binding.port_security_enabled except exc.NoResultFound: - raise psec.PortSecurityBindingNotFound() - return binding.port_security_enabled + # NOTE(ihrachys) the resource may have been created before port + # security extension was enabled; return default value + return psec.DEFAULT_PORT_SECURITY def _get_port_security_binding(self, context, port_id): try: query = self._model_query(context, PortSecurityBinding) binding = query.filter( PortSecurityBinding.port_id == port_id).one() + return binding.port_security_enabled except exc.NoResultFound: - raise psec.PortSecurityBindingNotFound() - return binding.port_security_enabled + # NOTE(ihrachys) the resource may have been created before port + # security extension was enabled; return default value + return psec.DEFAULT_PORT_SECURITY def _process_port_port_security_update( self, context, port_req, port_res): - if psec.PORTSECURITY in port_req: - port_security_enabled = port_req[psec.PORTSECURITY] - else: + if psec.PORTSECURITY not in port_req: return + port_security_enabled = port_req[psec.PORTSECURITY] try: query = self._model_query(context, PortSecurityBinding) port_id = port_res['id'] binding = query.filter( PortSecurityBinding.port_id == port_id).one() - binding.port_security_enabled = port_security_enabled port_res[psec.PORTSECURITY] = port_security_enabled except exc.NoResultFound: - raise psec.PortSecurityBindingNotFound() + # NOTE(ihrachys) the resource may have been created before port + # security extension was enabled; create the binding model + self._process_port_port_security_create( + context, port_req, port_res) def _process_network_port_security_update( self, context, network_req, network_res): - if psec.PORTSECURITY in network_req: - port_security_enabled = network_req[psec.PORTSECURITY] - else: + if psec.PORTSECURITY not in network_req: return + port_security_enabled = network_req[psec.PORTSECURITY] try: query = self._model_query(context, NetworkSecurityBinding) network_id = network_res['id'] binding = query.filter( NetworkSecurityBinding.network_id == network_id).one() - binding.port_security_enabled = port_security_enabled network_res[psec.PORTSECURITY] = port_security_enabled except exc.NoResultFound: - raise psec.PortSecurityBindingNotFound() + # NOTE(ihrachys) the resource may have been created before port + # security extension was enabled; create the binding model + self._process_network_port_security_create( + context, network_req, network_res) def _make_network_port_security_dict(self, port_security, fields=None): res = {'network_id': port_security['network_id'], diff --git a/neutron/extensions/portsecurity.py b/neutron/extensions/portsecurity.py index e743c0c025c..b22fa31fdea 100644 --- a/neutron/extensions/portsecurity.py +++ b/neutron/extensions/portsecurity.py @@ -17,6 +17,9 @@ from neutron.api.v2 import attributes from neutron.common import exceptions as nexception +DEFAULT_PORT_SECURITY = True + + class PortSecurityPortHasSecurityGroup(nexception.InUse): message = _("Port has security group associated. Cannot disable port " "security or ip address until security group is removed") @@ -27,16 +30,13 @@ class PortSecurityAndIPRequiredForSecurityGroups(nexception.InvalidInput): " address in order to use security groups.") -class PortSecurityBindingNotFound(nexception.InvalidExtensionEnv): - message = _("Port does not have port security binding.") - PORTSECURITY = 'port_security_enabled' EXTENDED_ATTRIBUTES_2_0 = { 'networks': { PORTSECURITY: {'allow_post': True, 'allow_put': True, 'convert_to': attributes.convert_to_boolean, 'enforce_policy': True, - 'default': True, + 'default': DEFAULT_PORT_SECURITY, 'is_visible': True}, }, 'ports': { diff --git a/neutron/plugins/ml2/extensions/port_security.py b/neutron/plugins/ml2/extensions/port_security.py index 6a10a41a617..f6b581b29a2 100644 --- a/neutron/plugins/ml2/extensions/port_security.py +++ b/neutron/plugins/ml2/extensions/port_security.py @@ -40,8 +40,7 @@ class PortSecurityExtensionDriver(api.ExtensionDriver, def process_create_network(self, context, data, result): # Create the network extension attributes. if psec.PORTSECURITY not in data: - data[psec.PORTSECURITY] = (psec.EXTENDED_ATTRIBUTES_2_0['networks'] - [psec.PORTSECURITY]['default']) + data[psec.PORTSECURITY] = psec.DEFAULT_PORT_SECURITY self._process_network_port_security_create(context, data, result) def process_update_network(self, context, data, result): @@ -65,15 +64,6 @@ class PortSecurityExtensionDriver(api.ExtensionDriver, def extend_port_dict(self, session, db_data, result): self._extend_port_security_dict(result, db_data) - def _extend_port_security_dict(self, response_data, db_data): - if db_data.get('port_security') is None: - response_data[psec.PORTSECURITY] = ( - psec.EXTENDED_ATTRIBUTES_2_0['networks'] - [psec.PORTSECURITY]['default']) - else: - response_data[psec.PORTSECURITY] = ( - db_data['port_security'][psec.PORTSECURITY]) - def _determine_port_security(self, context, port): """Returns a boolean (port_security_enabled). diff --git a/neutron/tests/unit/db/test_portsecurity_db.py b/neutron/tests/unit/db/test_portsecurity_db.py new file mode 100644 index 00000000000..5ed483070f1 --- /dev/null +++ b/neutron/tests/unit/db/test_portsecurity_db.py @@ -0,0 +1,47 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from neutron.db import portsecurity_db as pd +from neutron.db import portsecurity_db_common as pdc +from neutron.tests import base + +common = pdc.PortSecurityDbCommon + + +class FakePlugin(pd.PortSecurityDbMixin): + + supported_extension_aliases = ['port-security'] + + +class PortSecurityDbMixinTestCase(base.BaseTestCase): + + def setUp(self): + super(PortSecurityDbMixinTestCase, self).setUp() + self.plugin = FakePlugin() + + @mock.patch.object(common, '_extend_port_security_dict') + def test__extend_port_security_dict_relies_on_common(self, extend): + response = mock.Mock() + dbdata = mock.Mock() + self.plugin._extend_port_security_dict(response, dbdata) + extend.assert_called_once_with(response, dbdata) + + @mock.patch.object(common, '_extend_port_security_dict') + def test__extend_port_security_dict_ignored_if_extension_disabled(self, + extend): + response = mock.Mock() + dbdata = mock.Mock() + self.plugin.supported_extension_aliases = [] + self.plugin._extend_port_security_dict(response, dbdata) + self.assertFalse(extend.called) diff --git a/neutron/tests/unit/db/test_portsecurity_db_common.py b/neutron/tests/unit/db/test_portsecurity_db_common.py new file mode 100644 index 00000000000..a63ddcf19e0 --- /dev/null +++ b/neutron/tests/unit/db/test_portsecurity_db_common.py @@ -0,0 +1,76 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock +from sqlalchemy.orm import exc + +from neutron.db import portsecurity_db_common as pdc +from neutron.extensions import portsecurity as psec +from neutron.tests import base + + +common = pdc.PortSecurityDbCommon + + +class PortSecurityDbCommonTestCase(base.BaseTestCase): + + def setUp(self): + super(PortSecurityDbCommonTestCase, self).setUp() + self.common = common() + + def _test__get_security_binding_no_binding(self, getter): + port_sec_enabled = True + req = {psec.PORTSECURITY: port_sec_enabled} + res = {} + with mock.patch.object( + self.common, '_model_query', + create=True, + side_effect=exc.NoResultFound): + val = getter(req, res) + self.assertEqual(port_sec_enabled, val) + + def test__get_port_security_binding_no_binding(self): + self._test__get_security_binding_no_binding( + self.common._get_port_security_binding) + + def test__get_network_security_binding_no_binding(self): + self._test__get_security_binding_no_binding( + self.common._get_network_security_binding) + + def _test__process_security_update_no_binding(self, creator, updater): + req = {psec.PORTSECURITY: False} + res = {} + context = mock.Mock() + with mock.patch.object( + self.common, '_model_query', + create=True, + side_effect=exc.NoResultFound): + updater(context, req, res) + creator.assert_called_with(context, req, res) + + @mock.patch.object(common, '_process_port_port_security_create') + def test__process_port_port_security_update_no_binding(self, creator): + self._test__process_security_update_no_binding( + creator, + self.common._process_port_port_security_update) + + @mock.patch.object(common, '_process_network_port_security_create') + def test__process_network_port_security_update_no_binding(self, creator): + self._test__process_security_update_no_binding( + creator, + self.common._process_network_port_security_update) + + def test__extend_port_security_dict_no_port_security(self): + for db_data in ({'port_security': None, 'name': 'net1'}, {}): + response_data = {} + self.common._extend_port_security_dict(response_data, db_data) + self.assertTrue(response_data[psec.PORTSECURITY]) diff --git a/neutron/tests/unit/plugins/ml2/extensions/test_port_security.py b/neutron/tests/unit/plugins/ml2/extensions/test_port_security.py index 06ebd5baaf3..2f81d71445e 100644 --- a/neutron/tests/unit/plugins/ml2/extensions/test_port_security.py +++ b/neutron/tests/unit/plugins/ml2/extensions/test_port_security.py @@ -13,20 +13,27 @@ # License for the specific language governing permissions and limitations # under the License. +import mock + from neutron.extensions import portsecurity as psec from neutron.plugins.ml2.extensions import port_security from neutron.tests.unit.plugins.ml2 import test_plugin class TestML2ExtensionPortSecurity(test_plugin.Ml2PluginV2TestCase): - def test_extend_port_dict_no_port_security(self): - """Test _extend_port_security_dict won't crash - if port_security item is None - """ + def _test_extend_dict_no_port_security(self, func): + """Test extend_*_dict won't crash if port_security item is None.""" for db_data in ({'port_security': None, 'name': 'net1'}, {}): response_data = {} + session = mock.Mock() driver = port_security.PortSecurityExtensionDriver() - driver._extend_port_security_dict(response_data, db_data) + getattr(driver, func)(session, db_data, response_data) self.assertTrue(response_data[psec.PORTSECURITY]) + + def test_extend_port_dict_no_port_security(self): + self._test_extend_dict_no_port_security('extend_port_dict') + + def test_extend_network_dict_no_port_security(self): + self._test_extend_dict_no_port_security('extend_network_dict') diff --git a/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py b/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py index e6ea22e81fe..4c0d6ca314e 100644 --- a/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py +++ b/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py @@ -34,8 +34,6 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase, def test_create_net_port_security_default(self): _core_plugin = manager.NeutronManager.get_plugin() admin_ctx = context.get_admin_context() - _default_value = (psec.EXTENDED_ATTRIBUTES_2_0['networks'] - [psec.PORTSECURITY]['default']) args = {'network': {'name': 'test', 'tenant_id': '', @@ -48,7 +46,7 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase, finally: if network: _core_plugin.delete_network(admin_ctx, network['id']) - self.assertEqual(_default_value, _value) + self.assertEqual(psec.DEFAULT_PORT_SECURITY, _value) def test_create_port_with_secgroup_none_and_port_security_false(self): if self._skip_security_group: