From 3c61304fb5693a70e4668219f97361fda5e3aaae Mon Sep 17 00:00:00 2001 From: Matt Dietz Date: Fri, 19 Jul 2013 19:58:23 +0000 Subject: [PATCH] Refactors Security Groups into their own module Pulls security groups and their associated unit tests into separate submodules intended to make plugin.py easier to read. --- quark/__init__.py | 44 +++ quark/plugin.py | 267 +++-------------- quark/plugin_modules/security_groups.py | 227 +++++++++++++++ .../plugin_modules/test_security_groups.py | 272 ++++++++++++++++++ quark/tests/test_quark_plugin.py | 248 ---------------- 5 files changed, 586 insertions(+), 472 deletions(-) create mode 100644 quark/plugin_modules/security_groups.py create mode 100644 quark/tests/plugin_modules/test_security_groups.py diff --git a/quark/__init__.py b/quark/__init__.py index 0982b06..94b4de0 100644 --- a/quark/__init__.py +++ b/quark/__init__.py @@ -12,3 +12,47 @@ # implied. # See the License for the specific language governing permissions and # limitations under the License. + +from neutron import quota +from oslo.config import cfg + + +CONF = cfg.CONF + + +quark_opts = [ + cfg.StrOpt('net_driver', + default='quark.drivers.base.BaseDriver', + help=_('The client to use to talk to the backend')), + cfg.StrOpt('ipam_driver', default='quark.ipam.QuarkIpam', + help=_('IPAM Implementation to use')), + cfg.BoolOpt('ipam_reuse_after', default=7200, + help=_("Time in seconds til IP and MAC reuse" + "after deallocation.")), + cfg.StrOpt("strategy_driver", + default='quark.network_strategy.JSONStrategy', + help=_("Tree of network assignment strategy")), + cfg.StrOpt('net_driver_cfg', default='/etc/neutron/quark.ini', + help=_("Path to the config for the net driver")) +] + +quark_quota_opts = [ + cfg.IntOpt('quota_ports_per_network', + default=64, + help=_('Maximum ports per network per tenant')), + cfg.IntOpt('quota_security_rules_per_group', + default=20, + help=_('Maximum security group rules in a group')), +] + +quark_resources = [ + quota.BaseResource('ports_per_network', + 'quota_ports_per_network'), + quota.BaseResource('security_rules_per_group', + 'quota_security_rules_per_group'), +] + +CONF.register_opts(quark_opts, "QUARK") +CONF.register_opts(quark_quota_opts, "QUOTAS") + +quota.QUOTAS.register_resources(quark_resources) diff --git a/quark/plugin.py b/quark/plugin.py index 30cdabf..47004cd 100644 --- a/quark/plugin.py +++ b/quark/plugin.py @@ -42,49 +42,13 @@ from quark.db import models from quark import exceptions as quark_exceptions from quark import network_strategy from quark.plugin_modules import mac_address_ranges +from quark.plugin_modules import security_groups from quark import plugin_views as v LOG = logging.getLogger("neutron.quark") CONF = cfg.CONF DEFAULT_ROUTE = netaddr.IPNetwork("0.0.0.0/0") -DEFAULT_SG_UUID = "00000000-0000-0000-0000-000000000000" - -quark_opts = [ - cfg.StrOpt('net_driver', - default='quark.drivers.base.BaseDriver', - help=_('The client to use to talk to the backend')), - cfg.StrOpt('ipam_driver', default='quark.ipam.QuarkIpam', - help=_('IPAM Implementation to use')), - cfg.BoolOpt('ipam_reuse_after', default=7200, - help=_("Time in seconds til IP and MAC reuse" - "after deallocation.")), - cfg.StrOpt("strategy_driver", - default='quark.network_strategy.JSONStrategy', - help=_("Tree of network assignment strategy")), - cfg.StrOpt('net_driver_cfg', default='/etc/neutron/quark.ini', - help=_("Path to the config for the net driver")) -] - -quark_quota_opts = [ - cfg.IntOpt('quota_ports_per_network', - default=64, - help=_('Maximum ports per network per tenant')), - cfg.IntOpt('quota_security_rules_per_group', - default=20, - help=_('Maximum security group rules in a group')), -] -quark_resources = [ - quota.BaseResource('ports_per_network', - 'quota_ports_per_network'), - quota.BaseResource('security_rules_per_group', - 'quota_security_rules_per_group'), -] - STRATEGY = network_strategy.STRATEGY -CONF.register_opts(quark_opts, "QUARK") -CONF.register_opts(quark_quota_opts, "QUOTAS") - -quota.QUOTAS.register_resources(quark_resources) def _pop_param(attrs, param, default=None): @@ -133,49 +97,14 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, if not group_ids or group_ids is attributes.ATTR_NOT_SPECIFIED: return ([], []) group_ids = list(set(group_ids)) - security_groups = [] + groups = [] for gid in group_ids: group = db_api.security_group_find(context, id=gid, scope=db_api.ONE) if not group: raise sg_ext.SecurityGroupNotFound(id=gid) - security_groups.append(group) - return (group_ids, security_groups) - - def _validate_security_group_rule(self, context, rule): - PROTOCOLS = {"icmp": 1, "tcp": 6, "udp": 17} - ALLOWED_WITH_RANGE = [6, 17] - - if rule.get("remote_ip_prefix") and rule.get("remote_group_id"): - raise sg_ext.SecurityGroupRemoteGroupAndRemoteIpPrefix() - - protocol = rule.pop('protocol') - port_range_min = rule['port_range_min'] - port_range_max = rule['port_range_max'] - - if protocol: - if isinstance(protocol, str): - protocol = protocol.lower() - protocol = PROTOCOLS.get(protocol) - - if not protocol: - raise sg_ext.SecurityGroupRuleInvalidProtocol() - - if protocol in ALLOWED_WITH_RANGE: - if (port_range_min is None) != (port_range_max is None): - raise exceptions.InvalidInput( - error_message="For TCP/UDP rules, cannot wildcard " - "only one end of port range.") - if port_range_min is not None and port_range_max is not None: - if port_range_min > port_range_max: - raise sg_ext.SecurityGroupInvalidPortRange() - - rule['protocol'] = protocol - else: - if port_range_min is not None or port_range_max is not None: - raise sg_ext.SecurityGroupProtocolRequiredWithPorts() - - return rule + groups.append(group) + return (group_ids, groups) def _validate_subnet_cidr(self, context, network_id, new_subnet_cidr): """Validate the CIDR for a subnet. @@ -470,8 +399,8 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, if not self.get_security_groups( context, - filters={"id": DEFAULT_SG_UUID}): - self._create_default_security_group(context) + filters={"id": security_groups.DEFAULT_SG_UUID}): + security_groups._create_default_security_group(context) return v._make_network_dict(new_net) def update_network(self, context, id, network): @@ -1013,153 +942,6 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, return v._make_ip_dict(address) - def create_security_group(self, context, security_group): - LOG.info("create_security_group for tenant %s" % - (context.tenant_id)) - group = security_group["security_group"] - group_name = group.get('name', '') - if group_name == "default": - raise sg_ext.SecurityGroupDefaultAlreadyExists() - group_id = uuidutils.generate_uuid() - - self.net_driver.create_security_group( - context, - group_name, - group_id=group_id, - **group) - - group["id"] = group_id - group["name"] = group_name - group["tenant_id"] = context.tenant_id - dbgroup = db_api.security_group_create(context, **group) - return v._make_security_group_dict(dbgroup) - - def _create_default_security_group(self, context): - default_group = { - "name": "default", "description": "", - "group_id": DEFAULT_SG_UUID, - "port_egress_rules": [], - "port_ingress_rules": [ - {"ethertype": "IPv4", "protocol": 1}, - {"ethertype": "IPv4", "protocol": 6}, - {"ethertype": "IPv4", "protocol": 17}, - {"ethertype": "IPv6", "protocol": 1}, - {"ethertype": "IPv6", "protocol": 6}, - {"ethertype": "IPv6", "protocol": 17}, - ]} - - self.net_driver.create_security_group( - context, - "default", - **default_group) - - default_group["id"] = DEFAULT_SG_UUID - default_group["tenant_id"] = context.tenant_id - for rule in default_group.pop("port_ingress_rules"): - db_api.security_group_rule_create( - context, security_group_id=default_group["id"], - tenant_id=context.tenant_id, direction="ingress", - **rule) - db_api.security_group_create(context, **default_group) - - def create_security_group_rule(self, context, security_group_rule): - LOG.info("create_security_group for tenant %s" % - (context.tenant_id)) - rule = self._validate_security_group_rule( - context, security_group_rule["security_group_rule"]) - rule["id"] = uuidutils.generate_uuid() - - group_id = rule["security_group_id"] - group = db_api.security_group_find(context, id=group_id, - scope=db_api.ONE) - if not group: - raise sg_ext.SecurityGroupNotFound(group_id=group_id) - - quota.QUOTAS.limit_check( - context, context.tenant_id, - security_rules_per_group=len(group.get("rules", [])) + 1) - - self.net_driver.create_security_group_rule(context, group_id, rule) - - return v._make_security_group_rule_dict( - db_api.security_group_rule_create(context, **rule)) - - def delete_security_group(self, context, id): - LOG.info("delete_security_group %s for tenant %s" % - (id, context.tenant_id)) - - group = db_api.security_group_find(context, id=id, scope=db_api.ONE) - - #TODO(anyone): name and ports are lazy-loaded. Could be good op later - if not group: - raise sg_ext.SecurityGroupNotFound(group_id=id) - if id == DEFAULT_SG_UUID or group.name == "default": - raise sg_ext.SecurityGroupCannotRemoveDefault() - if group.ports: - raise sg_ext.SecurityGroupInUse(id=id) - self.net_driver.delete_security_group(context, id) - db_api.security_group_delete(context, group) - - def delete_security_group_rule(self, context, id): - LOG.info("delete_security_group %s for tenant %s" % - (id, context.tenant_id)) - rule = db_api.security_group_rule_find(context, id=id, - scope=db_api.ONE) - if not rule: - raise sg_ext.SecurityGroupRuleNotFound(group_id=id) - - group = db_api.security_group_find(context, id=rule["group_id"], - scope=db_api.ONE) - if not group: - raise sg_ext.SecurityGroupNotFound(id=id) - - self.net_driver.delete_security_group_rule( - context, group.id, v._make_security_group_rule_dict(rule)) - - rule["id"] = id - db_api.security_group_rule_delete(context, rule) - - def get_security_group(self, context, id, fields=None): - LOG.info("get_security_group %s for tenant %s" % - (id, context.tenant_id)) - group = db_api.security_group_find(context, id=id, scope=db_api.ONE) - if not group: - raise sg_ext.SecurityGroupNotFound(group_id=id) - return v._make_security_group_dict(group, fields) - - def get_security_group_rule(self, context, id, fields=None): - LOG.info("get_security_group_rule %s for tenant %s" % - (id, context.tenant_id)) - rule = db_api.security_group_rule_find(context, id=id, - scope=db_api.ONE) - if not rule: - raise sg_ext.SecurityGroupRuleNotFound(rule_id=id) - return v._make_security_group_rule_dict(rule, fields) - - def get_security_groups(self, context, filters=None, fields=None, - sorts=None, limit=None, marker=None, - page_reverse=False): - LOG.info("get_security_groups for tenant %s" % - (context.tenant_id)) - groups = db_api.security_group_find(context, **filters) - return [v._make_security_group_dict(group) for group in groups] - - def get_security_group_rules(self, context, filters=None, fields=None, - sorts=None, limit=None, marker=None, - page_reverse=False): - LOG.info("get_security_group_rules for tenant %s" % - (context.tenant_id)) - rules = db_api.security_group_rule_find(context, filters=filters) - return [v._make_security_group_rule_dict(rule) for rule in rules] - - def update_security_group(self, context, id, security_group): - new_group = security_group["security_group"] - group = db_api.security_group_find(context, id=id, scope=db_api.ONE) - self.net_driver.update_security_group(context, id, **new_group) - - db_group = db_api.security_group_update(context, group, **new_group) - return v._make_security_group_dict(db_group) - def create_ip_policy(self, context, ip_policy): LOG.info("create_ip_policy for tenant %s" % context.tenant_id) @@ -1224,3 +1006,40 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, def delete_mac_address_range(self, context, id): mac_address_ranges.delete_mac_address_range(context, id) + + def create_security_group(self, context, security_group): + return security_groups.create_security_group(context, security_group) + + def create_security_group_rule(self, context, security_group_rule): + return security_groups.create_security_group_rule(context, + security_group_rule) + + def delete_security_group(self, context, id): + security_groups.delete_security_group(context, id) + + def delete_security_group_rule(self, context, id): + security_groups.delete_security_group_rule(context, id) + + def get_security_group(self, context, id, fields=None): + return security_groups.get_security_group(context, id, fields) + + def get_security_group_rule(self, context, id, fields=None): + return security_groups.get_security_group_rule(context, id, fields) + + def get_security_groups(self, context, filters=None, fields=None, + sorts=None, limit=None, marker=None, + page_reverse=False): + return security_groups.get_security_groups(context, filters, fields, + sorts, limit, marker, + page_reverse) + + def get_security_group_rules(self, context, filters=None, fields=None, + sorts=None, limit=None, marker=None, + page_reverse=False): + return security_groups.get_security_group_rules(context, filters, + fields, sorts, limit, + marker, page_reverse) + + def update_security_group(self, context, id, security_group): + return security_groups.update_security_group(context, id, + security_group) diff --git a/quark/plugin_modules/security_groups.py b/quark/plugin_modules/security_groups.py new file mode 100644 index 0000000..7c4e5f5 --- /dev/null +++ b/quark/plugin_modules/security_groups.py @@ -0,0 +1,227 @@ +# Copyright 2013 Openstack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from neutron.common import exceptions +from neutron.extensions import securitygroup as sg_ext +from neutron.openstack.common import importutils +from neutron.openstack.common import log as logging +from neutron.openstack.common import uuidutils +from neutron import quota +from oslo.config import cfg + +from quark.db import api as db_api +from quark import plugin_views as v + + +CONF = cfg.CONF +LOG = logging.getLogger("neutron.quark") +DEFAULT_SG_UUID = "00000000-0000-0000-0000-000000000000" + + +net_driver = (importutils.import_class(CONF.QUARK.net_driver))() +net_driver.load_config(CONF.QUARK.net_driver_cfg) + + +def _validate_security_group_rule(context, rule): + PROTOCOLS = {"icmp": 1, "tcp": 6, "udp": 17} + ALLOWED_WITH_RANGE = [6, 17] + + if rule.get("remote_ip_prefix") and rule.get("remote_group_id"): + raise sg_ext.SecurityGroupRemoteGroupAndRemoteIpPrefix() + + protocol = rule.pop('protocol') + port_range_min = rule['port_range_min'] + port_range_max = rule['port_range_max'] + + if protocol: + if isinstance(protocol, str): + protocol = protocol.lower() + protocol = PROTOCOLS.get(protocol) + + if not protocol: + raise sg_ext.SecurityGroupRuleInvalidProtocol() + + if protocol in ALLOWED_WITH_RANGE: + if (port_range_min is None) != (port_range_max is None): + raise exceptions.InvalidInput( + error_message="For TCP/UDP rules, cannot wildcard " + "only one end of port range.") + if port_range_min is not None and port_range_max is not None: + if port_range_min > port_range_max: + raise sg_ext.SecurityGroupInvalidPortRange() + + rule['protocol'] = protocol + else: + if port_range_min is not None or port_range_max is not None: + raise sg_ext.SecurityGroupProtocolRequiredWithPorts() + + return rule + + +def create_security_group(context, security_group): + LOG.info("create_security_group for tenant %s" % + (context.tenant_id)) + group = security_group["security_group"] + group_name = group.get('name', '') + if group_name == "default": + raise sg_ext.SecurityGroupDefaultAlreadyExists() + group_id = uuidutils.generate_uuid() + + net_driver.create_security_group( + context, + group_name, + group_id=group_id, + **group) + + group["id"] = group_id + group["name"] = group_name + group["tenant_id"] = context.tenant_id + dbgroup = db_api.security_group_create(context, **group) + return v._make_security_group_dict(dbgroup) + + +def _create_default_security_group(context): + default_group = { + "name": "default", "description": "", + "group_id": DEFAULT_SG_UUID, + "port_egress_rules": [], + "port_ingress_rules": [ + {"ethertype": "IPv4", "protocol": 1}, + {"ethertype": "IPv4", "protocol": 6}, + {"ethertype": "IPv4", "protocol": 17}, + {"ethertype": "IPv6", "protocol": 1}, + {"ethertype": "IPv6", "protocol": 6}, + {"ethertype": "IPv6", "protocol": 17}, + ]} + + net_driver.create_security_group( + context, + "default", + **default_group) + + default_group["id"] = DEFAULT_SG_UUID + default_group["tenant_id"] = context.tenant_id + for rule in default_group.pop("port_ingress_rules"): + db_api.security_group_rule_create( + context, security_group_id=default_group["id"], + tenant_id=context.tenant_id, direction="ingress", + **rule) + db_api.security_group_create(context, **default_group) + + +def create_security_group_rule(context, security_group_rule): + LOG.info("create_security_group for tenant %s" % + (context.tenant_id)) + rule = _validate_security_group_rule( + context, security_group_rule["security_group_rule"]) + rule["id"] = uuidutils.generate_uuid() + + group_id = rule["security_group_id"] + group = db_api.security_group_find(context, id=group_id, + scope=db_api.ONE) + if not group: + raise sg_ext.SecurityGroupNotFound(group_id=group_id) + + quota.QUOTAS.limit_check( + context, context.tenant_id, + security_rules_per_group=len(group.get("rules", [])) + 1) + + net_driver.create_security_group_rule(context, group_id, rule) + + return v._make_security_group_rule_dict( + db_api.security_group_rule_create(context, **rule)) + + +def delete_security_group(context, id): + LOG.info("delete_security_group %s for tenant %s" % + (id, context.tenant_id)) + + group = db_api.security_group_find(context, id=id, scope=db_api.ONE) + + #TODO(anyone): name and ports are lazy-loaded. Could be good op later + if not group: + raise sg_ext.SecurityGroupNotFound(group_id=id) + if id == DEFAULT_SG_UUID or group.name == "default": + raise sg_ext.SecurityGroupCannotRemoveDefault() + if group.ports: + raise sg_ext.SecurityGroupInUse(id=id) + net_driver.delete_security_group(context, id) + db_api.security_group_delete(context, group) + + +def delete_security_group_rule(context, id): + LOG.info("delete_security_group %s for tenant %s" % + (id, context.tenant_id)) + rule = db_api.security_group_rule_find(context, id=id, + scope=db_api.ONE) + if not rule: + raise sg_ext.SecurityGroupRuleNotFound(group_id=id) + + group = db_api.security_group_find(context, id=rule["group_id"], + scope=db_api.ONE) + if not group: + raise sg_ext.SecurityGroupNotFound(id=id) + + net_driver.delete_security_group_rule( + context, group.id, v._make_security_group_rule_dict(rule)) + + rule["id"] = id + db_api.security_group_rule_delete(context, rule) + + +def get_security_group(context, id, fields=None): + LOG.info("get_security_group %s for tenant %s" % + (id, context.tenant_id)) + group = db_api.security_group_find(context, id=id, scope=db_api.ONE) + if not group: + raise sg_ext.SecurityGroupNotFound(group_id=id) + return v._make_security_group_dict(group, fields) + + +def get_security_group_rule(context, id, fields=None): + LOG.info("get_security_group_rule %s for tenant %s" % + (id, context.tenant_id)) + rule = db_api.security_group_rule_find(context, id=id, + scope=db_api.ONE) + if not rule: + raise sg_ext.SecurityGroupRuleNotFound(rule_id=id) + return v._make_security_group_rule_dict(rule, fields) + + +def get_security_groups(context, filters=None, fields=None, + sorts=None, limit=None, marker=None, + page_reverse=False): + LOG.info("get_security_groups for tenant %s" % + (context.tenant_id)) + groups = db_api.security_group_find(context, **filters) + return [v._make_security_group_dict(group) for group in groups] + + +def get_security_group_rules(context, filters=None, fields=None, + sorts=None, limit=None, marker=None, + page_reverse=False): + LOG.info("get_security_group_rules for tenant %s" % + (context.tenant_id)) + rules = db_api.security_group_rule_find(context, filters=filters) + return [v._make_security_group_rule_dict(rule) for rule in rules] + + +def update_security_group(context, id, security_group): + new_group = security_group["security_group"] + group = db_api.security_group_find(context, id=id, scope=db_api.ONE) + net_driver.update_security_group(context, id, **new_group) + + db_group = db_api.security_group_update(context, group, **new_group) + return v._make_security_group_dict(db_group) diff --git a/quark/tests/plugin_modules/test_security_groups.py b/quark/tests/plugin_modules/test_security_groups.py new file mode 100644 index 0000000..9af2436 --- /dev/null +++ b/quark/tests/plugin_modules/test_security_groups.py @@ -0,0 +1,272 @@ +# Copyright 2013 Openstack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib + +import mock +from neutron.common import exceptions +from neutron.extensions import securitygroup as sg_ext +from oslo.config import cfg + +from quark.db import models +from quark.tests import test_quark_plugin + + +class TestQuarkCreateSecurityGroup(test_quark_plugin.TestQuarkPlugin): + def setUp(self, *args, **kwargs): + super(TestQuarkCreateSecurityGroup, self).setUp(*args, **kwargs) + cfg.CONF.set_override('quota_security_group', 1, 'QUOTAS') + + @contextlib.contextmanager + def _stubs(self, security_group, other=0): + dbgroup = models.SecurityGroup() + dbgroup.update(security_group) + + with contextlib.nested( + mock.patch("quark.db.api.security_group_find"), + mock.patch("quark.db.api.security_group_create"), + ) as (db_find, db_create): + db_find.return_value.count.return_value = other + db_create.return_value = dbgroup + yield db_create + + def test_create_security_group(self): + group = {'name': 'foo', 'description': 'bar', + 'tenant_id': self.context.tenant_id} + expected = {'name': 'foo', 'description': 'bar', + 'tenant_id': self.context.tenant_id, + 'security_group_rules': []} + with self._stubs(group) as group_create: + result = self.plugin.create_security_group( + self.context, {'security_group': group}) + self.assertTrue(group_create.called) + for key in expected.keys(): + self.assertEqual(result[key], expected[key]) + + def test_create_default_security_group(self): + group = {'name': 'default', 'description': 'bar', + 'tenant_id': self.context.tenant_id} + with self._stubs(group) as group_create: + with self.assertRaises(sg_ext.SecurityGroupDefaultAlreadyExists): + self.plugin.create_security_group( + self.context, {'security_group': group}) + self.assertTrue(group_create.called) + + +class TestQuarkDeleteSecurityGroup(test_quark_plugin.TestQuarkPlugin): + @contextlib.contextmanager + def _stubs(self, security_group=None): + dbgroup = None + if security_group: + dbgroup = models.SecurityGroup() + dbgroup.update(security_group) + + with contextlib.nested( + mock.patch("quark.db.api.security_group_find"), + mock.patch("quark.db.api.security_group_delete"), + mock.patch( + "quark.drivers.base.BaseDriver.delete_security_group") + ) as (group_find, db_group_delete, driver_group_delete): + group_find.return_value = dbgroup + db_group_delete.return_value = dbgroup + yield db_group_delete, driver_group_delete + + def test_delete_security_group(self): + group = {'name': 'foo', 'description': 'bar', 'id': 1, + 'tenant_id': self.context.tenant_id} + with self._stubs(group) as (db_delete, driver_delete): + self.plugin.delete_security_group(self.context, 1) + self.assertTrue(db_delete.called) + driver_delete.assert_called_once_with(self.context, 1) + + def test_delete_default_security_group(self): + group = {'name': 'default', 'id': 1, + 'tenant_id': self.context.tenant_id} + with self._stubs(group) as (db_delete, driver_delete): + with self.assertRaises(sg_ext.SecurityGroupCannotRemoveDefault): + self.plugin.delete_security_group(self.context, 1) + + def test_delete_security_group_with_ports(self): + port = models.Port() + group = {'name': 'foo', 'description': 'bar', 'id': 1, + 'tenant_id': self.context.tenant_id, 'ports': [port]} + with self._stubs(group) as (db_delete, driver_delete): + with self.assertRaises(sg_ext.SecurityGroupInUse): + self.plugin.delete_security_group(self.context, 1) + + def test_delete_security_group_not_found(self): + with self._stubs() as (db_delete, driver_delete): + with self.assertRaises(sg_ext.SecurityGroupNotFound): + self.plugin.delete_security_group(self.context, 1) + + +class TestQuarkCreateSecurityGroupRule(test_quark_plugin.TestQuarkPlugin): + def setUp(self, *args, **kwargs): + super(TestQuarkCreateSecurityGroupRule, self).setUp(*args, **kwargs) + cfg.CONF.set_override('quota_security_group_rule', 1, 'QUOTAS') + cfg.CONF.set_override('quota_security_rules_per_group', 1, 'QUOTAS') + self.rule = {'id': 1, 'ethertype': 'IPv4', + 'security_group_id': 1, 'group': {'id': 1}, + 'protocol': None, 'port_range_min': None, + 'port_range_max': None} + self.expected = { + 'id': 1, + 'remote_group_id': None, + 'direction': None, + 'port_range_min': None, + 'port_range_max': None, + 'remote_ip_prefix': None, + 'ethertype': 'IPv4', + 'tenant_id': None, + 'protocol': None, + 'security_group_id': 1} + + @contextlib.contextmanager + def _stubs(self, rule, group): + dbrule = models.SecurityGroupRule() + dbrule.update(rule) + dbrule.group_id = rule['security_group_id'] + dbgroup = None + if group: + dbgroup = models.SecurityGroup() + dbgroup.update(group) + + with contextlib.nested( + mock.patch("quark.db.api.security_group_find"), + mock.patch("quark.db.api.security_group_rule_find"), + mock.patch("quark.db.api.security_group_rule_create") + ) as (group_find, rule_find, rule_create): + group_find.return_value = dbgroup + rule_find.return_value.count.return_value = group.get( + 'port_rules', None) if group else 0 + rule_create.return_value = dbrule + yield rule_create + + def _test_create_security_rule(self, **ruleset): + ruleset['tenant_id'] = self.context.tenant_id + rule = dict(self.rule, **ruleset) + group = rule.pop('group') + expected = dict(self.expected, **ruleset) + expected.pop('group', None) + with self._stubs(rule, group) as rule_create: + result = self.plugin.create_security_group_rule( + self.context, {'security_group_rule': rule}) + self.assertTrue(rule_create.called) + for key in expected.keys(): + self.assertEqual(expected[key], result[key]) + + def test_create_security_rule_IPv6(self): + self._test_create_security_rule(ethertype='IPv6') + + def test_create_security_rule_UDP(self): + self._test_create_security_rule(protocol=17) + + def test_create_security_rule_UDP_string(self): + self._test_create_security_rule(protocol="UDP") + + def test_create_security_rule_bad_string_fail(self): + self.assertRaises(sg_ext.SecurityGroupRuleInvalidProtocol, + self._test_create_security_rule, protocol="DERP") + + def test_create_security_rule_TCP(self): + self._test_create_security_rule(protocol=6) + + def test_create_security_rule_remote_ip(self): + self._test_create_security_rule(remote_ip_prefix='192.168.0.1') + + def test_create_security_rule_remote_group(self): + self._test_create_security_rule(remote_group_id=2) + + def test_create_security_rule_port_range_invalid_ranges_fails(self): + with self.assertRaises(exceptions.InvalidInput): + self._test_create_security_rule(protocol=6, port_range_min=0) + + def test_create_security_group_no_proto_with_ranges_fails(self): + with self.assertRaises(sg_ext.SecurityGroupProtocolRequiredWithPorts): + self._test_create_security_rule(protocol=None, port_range_min=0) + with self.assertRaises(Exception): + self._test_create_security_rule( + protocol=6, port_range_min=1, port_range_max=0) + + def test_create_security_rule_remote_conflicts(self): + with self.assertRaises(Exception): + self._test_create_security_rule(remote_ip_prefix='192.168.0.1', + remote_group_id='0') + + def test_create_security_rule_min_greater_than_max_fails(self): + with self.assertRaises(sg_ext.SecurityGroupInvalidPortRange): + self._test_create_security_rule(protocol=6, port_range_min=10, + port_range_max=9) + + def test_create_security_rule_no_group(self): + with self.assertRaises(sg_ext.SecurityGroupNotFound): + self._test_create_security_rule(group=None) + + def test_create_security_rule_group_at_max(self): + with self.assertRaises(exceptions.OverQuota): + self._test_create_security_rule( + group={'id': 1, 'rules': [models.SecurityGroupRule()]}) + + +class TestQuarkDeleteSecurityGroupRule(test_quark_plugin.TestQuarkPlugin): + @contextlib.contextmanager + def _stubs(self, rule={}, group={'id': 1}): + dbrule = None + dbgroup = None + if group: + dbgroup = models.SecurityGroup() + dbgroup.update(group) + if rule: + dbrule = models.SecurityGroupRule() + dbrule.update(dict(rule, group=dbgroup)) + + with contextlib.nested( + mock.patch("quark.db.api.security_group_find"), + mock.patch("quark.db.api.security_group_rule_find"), + mock.patch("quark.db.api.security_group_rule_delete"), + mock.patch( + "quark.drivers.base.BaseDriver.delete_security_group_rule") + ) as (group_find, rule_find, db_group_delete, driver_group_delete): + group_find.return_value = dbgroup + rule_find.return_value = dbrule + yield db_group_delete, driver_group_delete + + def test_delete_security_group_rule(self): + rule = {'id': 1, 'security_group_id': 1, 'ethertype': 'IPv4', + 'protocol': 6, 'port_range_min': 0, 'port_range_max': 10, + 'direction': 'ingress', 'tenant_id': self.context.tenant_id} + expected = { + 'id': 1, 'ethertype': 'IPv4', 'security_group_id': 1, + 'direction': 'ingress', 'port_range_min': 0, 'port_range_max': 10, + 'remote_group_id': None, 'remote_ip_prefix': None, + 'tenant_id': self.context.tenant_id, 'protocol': 6} + + with self._stubs(dict(rule, group_id=1)) as (db_delete, driver_delete): + self.plugin.delete_security_group_rule(self.context, 1) + self.assertTrue(db_delete.called) + driver_delete.assert_called_once_with(self.context, 1, + expected) + + def test_delete_security_group_rule_rule_not_found(self): + with self._stubs() as (db_delete, driver_delete): + with self.assertRaises(sg_ext.SecurityGroupRuleNotFound): + self.plugin.delete_security_group_rule(self.context, 1) + + def test_delete_security_group_rule_group_not_found(self): + rule = {'id': 1, 'security_group_id': 1, 'ethertype': 'IPv4'} + with self._stubs(dict(rule, group_id=1), + None) as (db_delete, driver_delete): + with self.assertRaises(sg_ext.SecurityGroupNotFound): + self.plugin.delete_security_group_rule(self.context, 1) diff --git a/quark/tests/test_quark_plugin.py b/quark/tests/test_quark_plugin.py index 4e321a7..a2ab414 100644 --- a/quark/tests/test_quark_plugin.py +++ b/quark/tests/test_quark_plugin.py @@ -1741,254 +1741,6 @@ class TestQuarkGetIpAddresses(TestQuarkPlugin): self.plugin.get_ip_address(self.context, 1) -class TestQuarkCreateSecurityGroup(TestQuarkPlugin): - def setUp(self, *args, **kwargs): - super(TestQuarkCreateSecurityGroup, self).setUp(*args, **kwargs) - cfg.CONF.set_override('quota_security_group', 1, 'QUOTAS') - - @contextlib.contextmanager - def _stubs(self, security_group, other=0): - dbgroup = models.SecurityGroup() - dbgroup.update(security_group) - - with contextlib.nested( - mock.patch("quark.db.api.security_group_find"), - mock.patch("quark.db.api.security_group_create"), - ) as (db_find, db_create): - db_find.return_value.count.return_value = other - db_create.return_value = dbgroup - yield db_create - - def test_create_security_group(self): - group = {'name': 'foo', 'description': 'bar', - 'tenant_id': self.context.tenant_id} - expected = {'name': 'foo', 'description': 'bar', - 'tenant_id': self.context.tenant_id, - 'security_group_rules': []} - with self._stubs(group) as group_create: - result = self.plugin.create_security_group( - self.context, {'security_group': group}) - self.assertTrue(group_create.called) - for key in expected.keys(): - self.assertEqual(result[key], expected[key]) - - def test_create_default_security_group(self): - group = {'name': 'default', 'description': 'bar', - 'tenant_id': self.context.tenant_id} - with self._stubs(group) as group_create: - with self.assertRaises(sg_ext.SecurityGroupDefaultAlreadyExists): - self.plugin.create_security_group( - self.context, {'security_group': group}) - self.assertTrue(group_create.called) - - -class TestQuarkDeleteSecurityGroup(TestQuarkPlugin): - @contextlib.contextmanager - def _stubs(self, security_group=None): - dbgroup = None - if security_group: - dbgroup = models.SecurityGroup() - dbgroup.update(security_group) - - with contextlib.nested( - mock.patch("quark.db.api.security_group_find"), - mock.patch("quark.db.api.security_group_delete"), - mock.patch( - "quark.drivers.base.BaseDriver.delete_security_group") - ) as (group_find, db_group_delete, driver_group_delete): - group_find.return_value = dbgroup - db_group_delete.return_value = dbgroup - yield db_group_delete, driver_group_delete - - def test_delete_security_group(self): - group = {'name': 'foo', 'description': 'bar', 'id': 1, - 'tenant_id': self.context.tenant_id} - with self._stubs(group) as (db_delete, driver_delete): - self.plugin.delete_security_group(self.context, 1) - self.assertTrue(db_delete.called) - driver_delete.assert_called_once_with(self.context, 1) - - def test_delete_default_security_group(self): - group = {'name': 'default', 'id': 1, - 'tenant_id': self.context.tenant_id} - with self._stubs(group) as (db_delete, driver_delete): - with self.assertRaises(sg_ext.SecurityGroupCannotRemoveDefault): - self.plugin.delete_security_group(self.context, 1) - - def test_delete_security_group_with_ports(self): - port = models.Port() - group = {'name': 'foo', 'description': 'bar', 'id': 1, - 'tenant_id': self.context.tenant_id, 'ports': [port]} - with self._stubs(group) as (db_delete, driver_delete): - with self.assertRaises(sg_ext.SecurityGroupInUse): - self.plugin.delete_security_group(self.context, 1) - - def test_delete_security_group_not_found(self): - with self._stubs() as (db_delete, driver_delete): - with self.assertRaises(sg_ext.SecurityGroupNotFound): - self.plugin.delete_security_group(self.context, 1) - - -class TestQuarkCreateSecurityGroupRule(TestQuarkPlugin): - def setUp(self, *args, **kwargs): - super(TestQuarkCreateSecurityGroupRule, self).setUp(*args, **kwargs) - cfg.CONF.set_override('quota_security_group_rule', 1, 'QUOTAS') - cfg.CONF.set_override('quota_security_rules_per_group', 1, 'QUOTAS') - self.rule = {'id': 1, 'ethertype': 'IPv4', - 'security_group_id': 1, 'group': {'id': 1}, - 'protocol': None, 'port_range_min': None, - 'port_range_max': None} - self.expected = { - 'id': 1, - 'remote_group_id': None, - 'direction': None, - 'port_range_min': None, - 'port_range_max': None, - 'remote_ip_prefix': None, - 'ethertype': 'IPv4', - 'tenant_id': None, - 'protocol': None, - 'security_group_id': 1} - - @contextlib.contextmanager - def _stubs(self, rule, group): - dbrule = models.SecurityGroupRule() - dbrule.update(rule) - dbrule.group_id = rule['security_group_id'] - dbgroup = None - if group: - dbgroup = models.SecurityGroup() - dbgroup.update(group) - - with contextlib.nested( - mock.patch("quark.db.api.security_group_find"), - mock.patch("quark.db.api.security_group_rule_find"), - mock.patch("quark.db.api.security_group_rule_create") - ) as (group_find, rule_find, rule_create): - group_find.return_value = dbgroup - rule_find.return_value.count.return_value = group.get( - 'port_rules', None) if group else 0 - rule_create.return_value = dbrule - yield rule_create - - def _test_create_security_rule(self, **ruleset): - ruleset['tenant_id'] = self.context.tenant_id - rule = dict(self.rule, **ruleset) - group = rule.pop('group') - expected = dict(self.expected, **ruleset) - expected.pop('group', None) - with self._stubs(rule, group) as rule_create: - result = self.plugin.create_security_group_rule( - self.context, {'security_group_rule': rule}) - self.assertTrue(rule_create.called) - for key in expected.keys(): - self.assertEqual(expected[key], result[key]) - - def test_create_security_rule_IPv6(self): - self._test_create_security_rule(ethertype='IPv6') - - def test_create_security_rule_UDP(self): - self._test_create_security_rule(protocol=17) - - def test_create_security_rule_UDP_string(self): - self._test_create_security_rule(protocol="UDP") - - def test_create_security_rule_bad_string_fail(self): - self.assertRaises(sg_ext.SecurityGroupRuleInvalidProtocol, - self._test_create_security_rule, protocol="DERP") - - def test_create_security_rule_TCP(self): - self._test_create_security_rule(protocol=6) - - def test_create_security_rule_remote_ip(self): - self._test_create_security_rule(remote_ip_prefix='192.168.0.1') - - def test_create_security_rule_remote_group(self): - self._test_create_security_rule(remote_group_id=2) - - def test_create_security_rule_port_range_invalid_ranges_fails(self): - with self.assertRaises(exceptions.InvalidInput): - self._test_create_security_rule(protocol=6, port_range_min=0) - - def test_create_security_group_no_proto_with_ranges_fails(self): - with self.assertRaises(sg_ext.SecurityGroupProtocolRequiredWithPorts): - self._test_create_security_rule(protocol=None, port_range_min=0) - with self.assertRaises(Exception): - self._test_create_security_rule( - protocol=6, port_range_min=1, port_range_max=0) - - def test_create_security_rule_remote_conflicts(self): - with self.assertRaises(Exception): - self._test_create_security_rule(remote_ip_prefix='192.168.0.1', - remote_group_id='0') - - def test_create_security_rule_min_greater_than_max_fails(self): - with self.assertRaises(sg_ext.SecurityGroupInvalidPortRange): - self._test_create_security_rule(protocol=6, port_range_min=10, - port_range_max=9) - - def test_create_security_rule_no_group(self): - with self.assertRaises(sg_ext.SecurityGroupNotFound): - self._test_create_security_rule(group=None) - - def test_create_security_rule_group_at_max(self): - with self.assertRaises(exceptions.OverQuota): - self._test_create_security_rule( - group={'id': 1, 'rules': [models.SecurityGroupRule()]}) - - -class TestQuarkDeleteSecurityGroupRule(TestQuarkPlugin): - @contextlib.contextmanager - def _stubs(self, rule={}, group={'id': 1}): - dbrule = None - dbgroup = None - if group: - dbgroup = models.SecurityGroup() - dbgroup.update(group) - if rule: - dbrule = models.SecurityGroupRule() - dbrule.update(dict(rule, group=dbgroup)) - - with contextlib.nested( - mock.patch("quark.db.api.security_group_find"), - mock.patch("quark.db.api.security_group_rule_find"), - mock.patch("quark.db.api.security_group_rule_delete"), - mock.patch( - "quark.drivers.base.BaseDriver.delete_security_group_rule") - ) as (group_find, rule_find, db_group_delete, driver_group_delete): - group_find.return_value = dbgroup - rule_find.return_value = dbrule - yield db_group_delete, driver_group_delete - - def test_delete_security_group_rule(self): - rule = {'id': 1, 'security_group_id': 1, 'ethertype': 'IPv4', - 'protocol': 6, 'port_range_min': 0, 'port_range_max': 10, - 'direction': 'ingress', 'tenant_id': self.context.tenant_id} - expected = { - 'id': 1, 'ethertype': 'IPv4', 'security_group_id': 1, - 'direction': 'ingress', 'port_range_min': 0, 'port_range_max': 10, - 'remote_group_id': None, 'remote_ip_prefix': None, - 'tenant_id': self.context.tenant_id, 'protocol': 6} - - with self._stubs(dict(rule, group_id=1)) as (db_delete, driver_delete): - self.plugin.delete_security_group_rule(self.context, 1) - self.assertTrue(db_delete.called) - driver_delete.assert_called_once_with(self.context, 1, - expected) - - def test_delete_security_group_rule_rule_not_found(self): - with self._stubs() as (db_delete, driver_delete): - with self.assertRaises(sg_ext.SecurityGroupRuleNotFound): - self.plugin.delete_security_group_rule(self.context, 1) - - def test_delete_security_group_rule_group_not_found(self): - rule = {'id': 1, 'security_group_id': 1, 'ethertype': 'IPv4'} - with self._stubs(dict(rule, group_id=1), - None) as (db_delete, driver_delete): - with self.assertRaises(sg_ext.SecurityGroupNotFound): - self.plugin.delete_security_group_rule(self.context, 1) - - class TestQuarkGetIpPolicies(TestQuarkPlugin): @contextlib.contextmanager def _stubs(self, ip_policy):