diff --git a/quark/drivers/nvp_driver.py b/quark/drivers/nvp_driver.py index 455e9bb..045b8c9 100644 --- a/quark/drivers/nvp_driver.py +++ b/quark/drivers/nvp_driver.py @@ -64,7 +64,9 @@ class NVPDriver(base.BaseDriver): def __init__(self): self.nvp_connections = [] self.conn_index = 0 - self.max_ports_per_switch = 0 + self.limits = {'max_ports_per_switch': 0, + 'max_rules_per_group': 0, + 'max_rules_per_port': 0} def load_config(self, path): #NOTE(mdietz): What does default_tz actually mean? @@ -72,7 +74,10 @@ class NVPDriver(base.BaseDriver): default_tz = CONF.NVP.default_tz LOG.info("Loading NVP settings " + str(default_tz)) connections = CONF.NVP.controller_connection - self.max_ports_per_switch = CONF.NVP.max_ports_per_switch + self.limits.update({ + 'max_ports_per_switch': CONF.NVP.max_ports_per_switch, + 'max_rules_per_group': CONF.NVP.max_rules_per_group, + 'max_rules_per_port': CONF.NVP.max_rules_per_port}) LOG.info("Loading NVP settings " + str(connections)) for conn in connections: (ip, port, user, pw, req_timeout, @@ -178,10 +183,11 @@ class NVPDriver(base.BaseDriver): profile.display_name(group_name) ingress_rules = group.get('port_ingress_rules', []) egress_rules = group.get('port_egress_rules', []) + if (len(ingress_rules) + len(egress_rules) > - CONF.NVP.max_rules_per_group): - raise sg_ext.qexception.InvalidInput( - error_message="Max rules for group %s" % group_id) + self.limits['max_rules_per_group']): + raise exceptions.DriverLimitReached(limit="rules per group") + if egress_rules: profile.port_egress_rules(egress_rules) if ingress_rules: @@ -207,10 +213,10 @@ class NVPDriver(base.BaseDriver): query.get('logical_port_ingress_rules')) egress_rules = group.get('port_egress_rules', query.get('logical_port_egress_rules')) + if (len(ingress_rules) + len(egress_rules) > - CONF.NVP.max_rules_per_group): - raise sg_ext.qexception.InvalidInput( - error_message="Max rules for group %s" % group_id) + self.limits['max_rules_per_group']): + raise exceptions.DriverLimitReached(limit="rules per group") if group.get('name', None): profile.display_name(group['name']) @@ -221,15 +227,15 @@ class NVPDriver(base.BaseDriver): return profile.update() def _update_security_group_rules(self, context, group_id, rule, operation, - check, raises): + checks): groupd = self._get_security_group(context, group_id) direction, secrule = self._get_security_group_rule_object(context, rule) rulelist = groupd['logical_port_%s_rules' % direction] - if not check(secrule, rulelist): - raise raises - else: - getattr(rulelist, operation)(secrule) + for check in checks: + if not check(secrule, rulelist): + raise checks[check] + getattr(rulelist, operation)(secrule) LOG.debug("%s rule on security group %s" % (operation, groupd['uuid'])) group = {'port_%s_rules' % direction: rulelist} @@ -238,15 +244,19 @@ class NVPDriver(base.BaseDriver): def create_security_group_rule(self, context, group_id, rule): return self._update_security_group_rules( context, group_id, rule, 'append', - lambda x, y: x not in y, - sg_ext.SecurityGroupRuleExists(id=group_id)) + {(lambda x, y: x not in y): + sg_ext.SecurityGroupRuleExists(id=group_id), + (lambda x, y: + self._check_rule_count_per_port(context, group_id) < + self.limits['max_rules_per_port']): + exceptions.DriverLimitReached(limit="rules per port")}) def delete_security_group_rule(self, context, group_id, rule): return self._update_security_group_rules( context, group_id, rule, 'remove', - lambda x, y: x in y, - sg_ext.SecurityGroupRuleNotFound(id="with group_id %s" % - group_id)) + {(lambda x, y: x in y): + sg_ext.SecurityGroupRuleNotFound(id="with group_id %s" % + group_id)}) def _create_or_choose_lswitch(self, context, network_id): switches = self._lswitch_status_query(context, network_id) @@ -278,8 +288,8 @@ class NVPDriver(base.BaseDriver): if switches is not None: for res in switches["results"]: count = res["_relations"]["LogicalSwitchStatus"]["lport_count"] - if self.max_ports_per_switch == 0 or \ - count < self.max_ports_per_switch: + if self.limits['max_ports_per_switch'] == 0 or \ + count < self.limits['max_ports_per_switch']: return res["uuid"] return None @@ -405,14 +415,27 @@ class NVPDriver(base.BaseDriver): "Direction not specified as 'ingress' or 'egress'.") return (direction, secrule) + def _check_rule_count_per_port(self, context, group_id): + connection = self.get_connection() + ports = connection.lswitch_port("*").query().security_profile_uuid( + self._get_security_group_id( + context, group_id)).results().get('results', []) + groups = (set(port.get('security_profiles', [])) for port in ports) + return max(self._check_rule_count_for_groups( + context, (connection.securityprofile(gp).read() for gp in group)) + for group in groups) + + def _check_rule_count_for_groups(self, context, groups): + return sum(len(group['logical_port_ingress_rules']) + + len(group['logical_port_egress_rules']) + for group in groups) + def _get_security_groups_for_port(self, context, groups): - rulecount = 0 - nvp_group_ids = [] - for group in groups: - nvp_group = self._get_security_group(context, group) - rulecount += (len(nvp_group['logical_port_ingress_rules']) + - len(nvp_group['logical_port_egress_rules'])) - nvp_group_ids.append(nvp_group['uuid']) - if rulecount > CONF.NVP.max_rules_per_port: - raise sg_ext.qexception.OverQuota(overs='security rules per port') - return nvp_group_ids + if (self._check_rule_count_for_groups( + context, + (self._get_security_group(context, g) for g in groups)) + > self.limits['max_rules_per_port']): + raise exceptions.DriverLimitReached(limit="rules per port") + + return [self._get_security_group(context, group)['uuid'] + for group in groups] diff --git a/quark/drivers/optimized_nvp_driver.py b/quark/drivers/optimized_nvp_driver.py index e672242..723d6af 100644 --- a/quark/drivers/optimized_nvp_driver.py +++ b/quark/drivers/optimized_nvp_driver.py @@ -115,7 +115,8 @@ class OptimizedNVPDriver(NVPDriver): def _lswitch_select_free(self, context, network_id): query = context.session.query(LSwitch) - query = query.filter(LSwitch.port_count < self.max_ports_per_switch) + query = query.filter(LSwitch.port_count < + self.limits['max_ports_per_switch']) query = query.filter(LSwitch.network_id == network_id) switch = query.order_by(LSwitch.port_count).first() return switch @@ -129,7 +130,7 @@ class OptimizedNVPDriver(NVPDriver): pass def _lswitch_select_open(self, context, network_id=None, **kwargs): - if self.max_ports_per_switch == 0: + if self.limits['max_ports_per_switch'] == 0: switch = self._lswitch_select_first(context, network_id) else: switch = self._lswitch_select_free(context, network_id) @@ -204,6 +205,15 @@ class OptimizedNVPDriver(NVPDriver): 'logical_port_ingress_rules': rulelist['ingress'], 'logical_port_egress_rules': rulelist['egress']} + def _check_rule_count_per_port(self, context, group_id): + ports = context.session.query(models.SecurityGroup).filter( + models.SecurityGroup.id == group_id).first().get('ports', []) + groups = (set(group.id for group in port.get('security_groups', [])) + for port in ports) + return max(self._check_rule_count_for_groups( + context, (self._get_security_group(context, id) for id in g)) + for g in groups) + class LSwitchPort(models.BASEV2, models.HasId): __tablename__ = "quark_nvp_driver_lswitchport" diff --git a/quark/exceptions.py b/quark/exceptions.py index 32c639d..638f42f 100644 --- a/quark/exceptions.py +++ b/quark/exceptions.py @@ -63,3 +63,7 @@ class IPPolicyNotFound(exceptions.NeutronException): class IPPolicyAlreadyExists(exceptions.NeutronException): message = _("IP Policy %(id)s already exists for %(n_id)s") + + +class DriverLimitReached(exceptions.InvalidInput): + message = _("Driver has reached limit on resource '%(limit)s'") diff --git a/quark/plugin.py b/quark/plugin.py index f7c0092..168905e 100644 --- a/quark/plugin.py +++ b/quark/plugin.py @@ -1146,8 +1146,6 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, scope=db_api.ONE) if not group: raise sg_ext.SecurityGroupNotFound(group_id=group_id) - if group.ports: - raise sg_ext.SecurityGroupInUse(id=group_id) self.net_driver.create_security_group_rule( context, @@ -1183,8 +1181,6 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, scope=db_api.ONE) if not group: raise sg_ext.SecurityGroupNotFound(id=id) - if group.ports: - raise sg_ext.SecurityGroupInUse(id=id) self.net_driver.delete_security_group_rule( context, diff --git a/quark/tests/test_nvp_driver.py b/quark/tests/test_nvp_driver.py index 385116e..57ca4d1 100644 --- a/quark/tests/test_nvp_driver.py +++ b/quark/tests/test_nvp_driver.py @@ -49,6 +49,8 @@ class TestNVPDriver(test_base.TestBase): self.profile_id = "12345678-0000-0000-0000-000000000000" self.d_pkg = "quark.drivers.nvp_driver.NVPDriver" self.max_spanning = 3 + self.driver.limits.update({'max_rules_per_group': 3, + 'max_rules_per_port': 2}) def _create_connection(self, switch_count=1, has_switches=False, maxed_ports=False): @@ -68,13 +70,16 @@ class TestNVPDriver(test_base.TestBase): port.delete = mock.Mock(return_value=None) return port - def _create_lport_query(self, switch_count): + def _create_lport_query(self, switch_count, profiles=[]): query = mock.Mock() port_list = {"_relations": {"LogicalSwitchConfig": - {"uuid": self.lswitch_uuid}}} + {"uuid": self.lswitch_uuid, + "security_profiles": profiles}}} port_query = {"results": [port_list], "result_count": switch_count} query.results = mock.Mock(return_value=port_query) + query.security_profile_uuid().results.return_value = { + "results": [{"security_profiles": profiles}]} return query def _create_lswitch(self, switches_available, maxed_ports): @@ -106,12 +111,13 @@ class TestNVPDriver(test_base.TestBase): def _create_security_profile(self): profile = mock.Mock() query = mock.Mock() - query.results = mock.Mock(return_value={'results': [ - {'name': 'foo', 'uuid': self.profile_id, - 'logical_port_ingress_rules': [], - 'logical_port_egress_rules': []}], - 'result_count': 1}) + group = {'name': 'foo', 'uuid': self.profile_id, + 'logical_port_ingress_rules': [], + 'logical_port_egress_rules': []} + query.results = mock.Mock(return_value={'results': [group], + 'result_count': 1}) profile.query = mock.Mock(return_value=query) + profile.read = mock.Mock(return_value=group) return mock.Mock(return_value=profile) def _create_security_rule(self, rule={}): @@ -361,7 +367,7 @@ class TestNVPDriverCreatePort(TestNVPDriver): def test_create_port_switch_exists_spanning(self): with self._stubs(maxed_ports=True, net_details=dict(foo=3)) as (connection): - self.driver.max_ports_per_switch = self.max_spanning + self.driver.limits['max_ports_per_switch'] = self.max_spanning port = self.driver.create_port(self.context, self.net_id, self.port_id) self.assertTrue("uuid" in port) @@ -417,13 +423,12 @@ class TestNVPDriverCreatePort(TestNVPDriver): def test_create_port_with_security_groups_max_rules(self): with self._stubs() as connection: connection.securityprofile = self._create_security_profile() - connection.securityprofile().query().results()[ - 'results'][0].update( - {'logical_port_ingress_rules': [{'ethertype': 'IPv4'}, - {'ethertype': 'IPv6'}], - 'logical_port_egress_rules': [{'ethertype': 'IPv4'}, - {'ethertype': 'IPv6'}]}) - with self.assertRaises(sg_ext.qexception.OverQuota): + connection.securityprofile().read().update( + {'logical_port_ingress_rules': [{'ethertype': 'IPv4'}, + {'ethertype': 'IPv6'}], + 'logical_port_egress_rules': [{'ethertype': 'IPv4'}, + {'ethertype': 'IPv6'}]}) + with self.assertRaises(sg_ext.qexception.InvalidInput): self.driver.create_port( self.context, self.net_id, self.port_id, security_groups=[1], @@ -456,13 +461,12 @@ class TestNVPDriverUpdatePort(TestNVPDriver): def test_update_port_max_rules(self): with self._stubs() as connection: - connection.securityprofile().query().results()[ - 'results'][0].update( - {'logical_port_ingress_rules': [{'ethertype': 'IPv4'}, - {'ethertype': 'IPv6'}], - 'logical_port_egress_rules': [{'ethertype': 'IPv4'}, - {'ethertype': 'IPv6'}]}) - with self.assertRaises(sg_ext.qexception.OverQuota): + connection.securityprofile().read().update( + {'logical_port_ingress_rules': [{'ethertype': 'IPv4'}, + {'ethertype': 'IPv6'}], + 'logical_port_egress_rules': [{'ethertype': 'IPv4'}, + {'ethertype': 'IPv6'}]}) + with self.assertRaises(sg_ext.qexception.InvalidInput): self.driver.update_port( self.context, self.port_id, security_groups=[1], @@ -596,7 +600,7 @@ class TestNVPDriverCreateSecurityGroup(TestNVPDriver): 'tag': self.context.tenant_id}]), ], any_order=True) - def test_security_group_create_rules_over_quota(self): + def test_security_group_create_rules_at_max(self): ingress_rules = [{'ethertype': 'IPv4', 'protocol': 6}, {'ethertype': 'IPv6', 'remote_ip_prefix': '192.168.0.1'}] @@ -604,7 +608,7 @@ class TestNVPDriverCreateSecurityGroup(TestNVPDriver): 'port_range_min': 0, 'port_range_max': 100}, {'ethertype': 'IPv4', 'remote_group_id': 2}] with self._stubs(): - with self.assertRaises(Exception): + with self.assertRaises(sg_ext.qexception.InvalidInput): self.driver.create_security_group( self.context, 'foo', port_ingress_rules=ingress_rules, @@ -660,7 +664,14 @@ class TestNVPDriverUpdateSecurityGroup(TestNVPDriver): mock.call.update()], any_order=True) - def test_security_group_update_rules(self): + def test_security_group_update_not_found(self): + with self._stubs() as connection: + connection.securityprofile().query().results.return_value = \ + {'result_count': 0, 'results': []} + with self.assertRaises(sg_ext.SecurityGroupNotFound): + self.driver.update_security_group(self.context, 1) + + def test_security_group_update_with_rules(self): ingress_rules = [{'ethertype': 'IPv4', 'protocol': 6}, {'ethertype': 'IPv6', 'remote_ip_prefix': '192.168.0.1'}] @@ -678,14 +689,7 @@ class TestNVPDriverUpdateSecurityGroup(TestNVPDriver): mock.call.update(), ], any_order=True) - def test_security_group_update_not_found(self): - with self._stubs() as connection: - connection.securityprofile().query().results.return_value = \ - {'result_count': 0, 'results': []} - with self.assertRaises(sg_ext.SecurityGroupNotFound): - self.driver.update_security_group(self.context, 1) - - def test_security_group_update_rules_over_quota(self): + def test_security_group_update_rules_at_max(self): ingress_rules = [{'ethertype': 'IPv4', 'protocol': 6}, {'ethertype': 'IPv6', 'remote_ip_prefix': '192.168.0.1'}] @@ -693,7 +697,7 @@ class TestNVPDriverUpdateSecurityGroup(TestNVPDriver): 'port_range_min': 0, 'port_range_max': 100}, {'ethertype': 'IPv4', 'remote_group_id': 2}] with self._stubs(): - with self.assertRaises(Exception): + with self.assertRaises(sg_ext.qexception.InvalidInput): self.driver.update_security_group( self.context, 1, port_ingress_rules=ingress_rules, @@ -709,13 +713,16 @@ class TestNVPDriverCreateSecurityGroupRule(TestNVPDriver): connection = self._create_connection() connection.securityprofile = self._create_security_profile() connection.securityrule = self._create_security_rule() + connection.lswitch_port().query.return_value = \ + self._create_lport_query(1, [self.profile_id]) get_connection.return_value = connection yield connection def test_security_rule_create(self): with self._stubs() as connection: self.driver.create_security_group_rule( - self.context, 1, {'ethertype': 'IPv4', 'direction': 'ingress'}) + self.context, 1, + {'ethertype': 'IPv4', 'direction': 'ingress'}) connection.securityprofile.assert_any_calls(self.profile_id) connection.securityprofile().assert_has_calls([ mock.call.port_ingress_rules([{'ethertype': 'IPv4'}]), @@ -724,8 +731,9 @@ class TestNVPDriverCreateSecurityGroupRule(TestNVPDriver): def test_security_rule_create_duplicate(self): with self._stubs() as connection: - connection.securityprofile().query().results()['results'][0][ - 'logical_port_ingress_rules'] = [{'ethertype': 'IPv4'}] + connection.securityprofile().read().update({ + 'logical_port_ingress_rules': [{'ethertype': 'IPv4'}], + 'logical_port_egress_rules': []}) with self.assertRaises(sg_ext.SecurityGroupRuleExists): self.driver.create_security_group_rule( self.context, 1, @@ -740,6 +748,16 @@ class TestNVPDriverCreateSecurityGroupRule(TestNVPDriver): self.context, 1, {'ethertype': 'IPv4', 'direction': 'egress'}) + def test_security_rule_create_over_port(self): + with self._stubs() as connection: + connection.securityprofile().read().update( + {'logical_port_ingress_rules': [1, 2]}) + with self.assertRaises(sg_ext.qexception.InvalidInput): + self.driver.create_security_group_rule( + self.context, 1, + {'ethertype': 'IPv4', 'direction': 'egress'}) + self.assertTrue(connection.lswitch_port().query.called) + class TestNVPDriverDeleteSecurityGroupRule(TestNVPDriver): @contextlib.contextmanager @@ -755,8 +773,7 @@ class TestNVPDriverDeleteSecurityGroupRule(TestNVPDriver): connection = self._create_connection() connection.securityprofile = self._create_security_profile() connection.securityrule = self._create_security_rule() - connection.securityprofile().query().results()[ - 'results'][0].update(rulelist) + connection.securityprofile().read().update(rulelist) get_connection.return_value = connection yield connection @@ -767,7 +784,6 @@ class TestNVPDriverDeleteSecurityGroupRule(TestNVPDriver): ) as connection: self.driver.delete_security_group_rule( self.context, 1, {'ethertype': 'IPv6', 'direction': 'egress'}) - print connection.securityprofile().mock_calls connection.securityprofile.assert_any_call(self.profile_id) connection.securityprofile().assert_has_calls([ mock.call.port_egress_rules([]), diff --git a/quark/tests/test_optimized_nvp_driver.py b/quark/tests/test_optimized_nvp_driver.py index 57aff15..ec0b16b 100644 --- a/quark/tests/test_optimized_nvp_driver.py +++ b/quark/tests/test_optimized_nvp_driver.py @@ -161,7 +161,7 @@ class TestOptimizedNVPDriverCreatePort(TestOptimizedNVPDriver): '''Testing to ensure a switch is made when maxed.''' with self._stubs(maxed_ports=True) as ( connection, create_opt): - self.driver.max_ports_per_switch = self.max_spanning + self.driver.limits['max_ports_per_switch'] = self.max_spanning port = self.driver.create_port(self.context, self.net_id, self.port_id) self.assertTrue("uuid" in port) diff --git a/quark/tests/test_quark_plugin.py b/quark/tests/test_quark_plugin.py index 1bb301e..5e72f01 100644 --- a/quark/tests/test_quark_plugin.py +++ b/quark/tests/test_quark_plugin.py @@ -1836,15 +1836,6 @@ class TestQuarkCreateSecurityGroup(TestQuarkPlugin): self.context, {'security_group': group}) self.assertTrue(group_create.called) - def test_create_security_group_over_quota(self): - group = {'name': 'foo', 'description': 'bar', - 'tenant_id': self.context.tenant_id} - with self._stubs(group, other=1) as group_create: - with self.assertRaises(exceptions.OverQuota): - self.plugin.create_security_group( - self.context, {'security_group': group}) - self.assertTrue(group_create.called) - class TestQuarkDeleteSecurityGroup(TestQuarkPlugin): @contextlib.contextmanager @@ -1982,22 +1973,6 @@ class TestQuarkCreateSecurityGroupRule(TestQuarkPlugin): with self.assertRaises(sg_ext.SecurityGroupNotFound): self._test_create_security_rule(group=None) - def test_create_security_rule_over_quota(self): - rule = self.rule - rule.pop('group') - with self._stubs( - rule, - {'id': 1, 'port_rules': 1} - ) as rule_create: - with self.assertRaises(exceptions.OverQuota): - self.plugin.create_security_group_rule( - self.context, {'security_group_rule': self.rule}) - self.assertTrue(rule_create.called) - - def test_create_security_rule_group_in_use(self): - with self.assertRaises(sg_ext.SecurityGroupInUse): - self._test_create_security_rule(group={'ports': [models.Port()]}) - class TestQuarkDeleteSecurityGroupRule(TestQuarkPlugin): @contextlib.contextmanager @@ -2051,14 +2026,6 @@ class TestQuarkDeleteSecurityGroupRule(TestQuarkPlugin): with self.assertRaises(sg_ext.SecurityGroupNotFound): self.plugin.delete_security_group_rule(self.context, 1) - def test_delete_security_group_rule_group_in_use(self): - with self._stubs( - rule={'id': 1, 'security_group_id': 1}, - group={'id': 1, 'ports': [models.Port()]} - ) as (db_delete, driver_delete): - with self.assertRaises(sg_ext.SecurityGroupInUse): - self.plugin.delete_security_group_rule(self.context, 1) - class TestQuarkGetIpPolicies(TestQuarkPlugin): @contextlib.contextmanager