diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 664d300b45c..16d04aceaa7 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -687,6 +687,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): if project_id: self._ensure_default_security_group(context, project_id) + if not filters and context.project_id and not context.is_admin: + rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids( + context.project_id) + filters = {'id': rule_ids} + # NOTE(slaweq): use admin context here to be able to get all rules # which fits filters' criteria. Later in policy engine rules will be # filtered and only those which are allowed according to policy will diff --git a/neutron/objects/securitygroup.py b/neutron/objects/securitygroup.py index f76093f7539..836e96c688c 100644 --- a/neutron/objects/securitygroup.py +++ b/neutron/objects/securitygroup.py @@ -10,7 +10,9 @@ # License for the specific language governing permissions and limitations # under the License. +from neutron_lib import context as context_lib from oslo_versionedobjects import fields as obj_fields +from sqlalchemy import or_ from neutron.common import utils from neutron.db.models import securitygroup as sg_models @@ -127,3 +129,21 @@ class SecurityGroupRule(base.NeutronDbObject): fields['remote_ip_prefix'] = ( utils.AuthenticIPNetwork(fields['remote_ip_prefix'])) return fields + + @classmethod + def get_security_group_rule_ids(cls, project_id): + """Retrieve all SG rules related to this project_id + + This method returns the SG rule IDs that meet these conditions: + - The rule belongs to this project_id + - The rule belongs to a security group that belongs to the project_id + """ + context = context_lib.get_admin_context() + query = context.session.query(cls.db_model.id) + query = query.join( + SecurityGroup.db_model, + cls.db_model.security_group_id == SecurityGroup.db_model.id) + clauses = or_(SecurityGroup.db_model.project_id == project_id, + cls.db_model.project_id == project_id) + rule_ids = query.filter(clauses).all() + return [rule_id[0] for rule_id in rule_ids] diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index a5254acfe54..92b2f0259e7 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -460,3 +460,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): {'port_range_min': 100, 'port_range_max': 200, 'protocol': '111'}) + + def _create_environment(self): + self.sg = copy.deepcopy(FAKE_SECGROUP) + self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1', + is_admin=False, overwrite=False) + self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2', + is_admin=True, overwrite=False) + self.sg_user = self.mixin.create_security_group( + self.user_ctx, {'security_group': {'name': 'name', + 'tenant_id': 'tenant_1', + 'description': 'fake'}}) + + def test_get_security_group_rules(self): + self._create_environment() + rules_before = self.mixin.get_security_group_rules(self.user_ctx) + + rule = copy.deepcopy(FAKE_SECGROUP_RULE) + rule['security_group_rule']['security_group_id'] = self.sg_user['id'] + rule['security_group_rule']['tenant_id'] = 'tenant_2' + self.mixin.create_security_group_rule(self.admin_ctx, rule) + + rules_after = self.mixin.get_security_group_rules(self.user_ctx) + self.assertEqual(len(rules_before) + 1, len(rules_after)) + for rule in (rule for rule in rules_after if rule not in rules_before): + self.assertEqual('tenant_2', rule['tenant_id']) + + def test_get_security_group_rules_filters_passed(self): + self._create_environment() + filters = {'security_group_id': self.sg_user['id']} + rules_before = self.mixin.get_security_group_rules(self.user_ctx, + filters=filters) + + default_sg = self.mixin.get_security_groups( + self.user_ctx, filters={'name': 'default'})[0] + rule = copy.deepcopy(FAKE_SECGROUP_RULE) + rule['security_group_rule']['security_group_id'] = default_sg['id'] + rule['security_group_rule']['tenant_id'] = 'tenant_1' + self.mixin.create_security_group_rule(self.user_ctx, rule) + + rules_after = self.mixin.get_security_group_rules(self.user_ctx, + filters=filters) + self.assertEqual(rules_before, rules_after) + + def test_get_security_group_rules_admin_context(self): + self._create_environment() + rules_before = self.mixin.get_security_group_rules(self.ctx) + + rule = copy.deepcopy(FAKE_SECGROUP_RULE) + rule['security_group_rule']['security_group_id'] = self.sg_user['id'] + rule['security_group_rule']['tenant_id'] = 'tenant_1' + self.mixin.create_security_group_rule(self.user_ctx, rule) + + rules_after = self.mixin.get_security_group_rules(self.ctx) + self.assertEqual(len(rules_before) + 1, len(rules_after)) + for rule in (rule for rule in rules_after if rule not in rules_before): + self.assertEqual('tenant_1', rule['tenant_id']) + self.assertEqual(self.sg_user['id'], rule['security_group_id']) diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 716c7d8d3d5..50c696d8016 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -1533,8 +1533,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, self._router.create() return self._router['id'] - def _create_test_security_group_id(self): + def _create_test_security_group_id(self, fields=None): sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup) + fields = fields or {} + for field, value in ((f, v) for (f, v) in fields.items() if + f in sg_fields): + sg_fields[field] = value _securitygroup = securitygroup.SecurityGroup( self.context, **sg_fields) _securitygroup.create() diff --git a/neutron/tests/unit/objects/test_securitygroup.py b/neutron/tests/unit/objects/test_securitygroup.py index 54b8a04d43b..63c57cd850c 100644 --- a/neutron/tests/unit/objects/test_securitygroup.py +++ b/neutron/tests/unit/objects/test_securitygroup.py @@ -10,6 +10,11 @@ # License for the specific language governing permissions and limitations # under the License. +import collections +import itertools + +from oslo_utils import uuidutils + from neutron.objects import securitygroup from neutron.tests.unit.objects import test_base from neutron.tests.unit import testlib_api @@ -168,3 +173,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase, 'remote_group_id': lambda: self._create_test_security_group_id() }) + + def test_get_security_group_rule_ids(self): + """Retrieve the SG rules associated to a project (see method desc.) + + SG1 (PROJECT1) SG2 (PROJECT2) + rule1a (PROJECT1) rule2a (PROJECT1) + rule1b (PROJECT2) rule2b (PROJECT2) + + query PROJECT1: rule1a, rule1b, rule2a + query PROJECT2: rule1b, rule2a, rule2b + """ + projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()] + sgs = [ + self._create_test_security_group_id({'project_id': projects[0]}), + self._create_test_security_group_id({'project_id': projects[1]})] + + rules_per_project = collections.defaultdict(list) + rules_per_sg = collections.defaultdict(list) + for project, sg in itertools.product(projects, sgs): + sgrule_fields = self.get_random_object_fields( + securitygroup.SecurityGroupRule) + sgrule_fields['project_id'] = project + sgrule_fields['security_group_id'] = sg + rule = securitygroup.SecurityGroupRule(self.context, + **sgrule_fields) + rule.create() + rules_per_project[project].append(rule.id) + rules_per_sg[sg].append(rule.id) + + for idx in range(2): + rule_ids = securitygroup.SecurityGroupRule.\ + get_security_group_rule_ids(projects[idx]) + rule_ids_ref = set(rules_per_project[projects[idx]]) + rule_ids_ref.update(set(rules_per_sg[sgs[idx]])) + self.assertEqual(rule_ids_ref, set(rule_ids))