From 093b861bb4edcd2af0cdd31351158ba4e2fa2435 Mon Sep 17 00:00:00 2001 From: Rodolfo Alonso Hernandez Date: Tue, 18 Feb 2020 17:08:22 +0000 Subject: [PATCH] Filter by owner SGs when retrieving the SG rules Retrieving the SG rules now is used the admin context. This allows to get all possible rules, independently of the user calling. The filters passed and the RBAC policies filter those results, returning only: - The SG rules belonging to the user. - The SG rules belonging to a SG owned by the user. However, if the SG list is too long, the query can take a lot of time. Instead of this, the filtering is done in the DB query. If no filters are passed to "get_security_group_rules" and the context is not the admin context, only the rules specified in the first paragraph will be retrieved. Because overwriting the method "get_objects" is too complex, an intermediate query is done to retrieve the SG rule IDs. Those IDs will be used as a filter in the "get_objects" call. Conflicts: neutron/objects/securitygroup.py neutron/tests/unit/db/test_securitygroups_db.py neutron/tests/unit/objects/test_securitygroup.py Closes-Bug: #1863201 Change-Id: I25d3da929f8d0b6ee15d7b90ec59b9d58a4ae6a5 (cherry picked from commit d874c46bff7045ba25f5dd6e790f7ddb209cb224) (cherry picked from commit d3905264b7659b1d10a68e3629861d5f0ba13568) (cherry picked from commit 61dc621c1ba40efcedabdfb9f3a1854cea227d2c) --- neutron/db/securitygroups_db.py | 5 ++ neutron/objects/securitygroup.py | 20 +++++++ .../tests/unit/db/test_securitygroups_db.py | 57 +++++++++++++++++++ neutron/tests/unit/objects/test_base.py | 6 +- .../tests/unit/objects/test_securitygroup.py | 40 +++++++++++++ 5 files changed, 127 insertions(+), 1 deletion(-) diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 664d300b45c..16d04aceaa7 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -687,6 +687,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): if project_id: self._ensure_default_security_group(context, project_id) + if not filters and context.project_id and not context.is_admin: + rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids( + context.project_id) + filters = {'id': rule_ids} + # NOTE(slaweq): use admin context here to be able to get all rules # which fits filters' criteria. Later in policy engine rules will be # filtered and only those which are allowed according to policy will diff --git a/neutron/objects/securitygroup.py b/neutron/objects/securitygroup.py index f76093f7539..836e96c688c 100644 --- a/neutron/objects/securitygroup.py +++ b/neutron/objects/securitygroup.py @@ -10,7 +10,9 @@ # License for the specific language governing permissions and limitations # under the License. +from neutron_lib import context as context_lib from oslo_versionedobjects import fields as obj_fields +from sqlalchemy import or_ from neutron.common import utils from neutron.db.models import securitygroup as sg_models @@ -127,3 +129,21 @@ class SecurityGroupRule(base.NeutronDbObject): fields['remote_ip_prefix'] = ( utils.AuthenticIPNetwork(fields['remote_ip_prefix'])) return fields + + @classmethod + def get_security_group_rule_ids(cls, project_id): + """Retrieve all SG rules related to this project_id + + This method returns the SG rule IDs that meet these conditions: + - The rule belongs to this project_id + - The rule belongs to a security group that belongs to the project_id + """ + context = context_lib.get_admin_context() + query = context.session.query(cls.db_model.id) + query = query.join( + SecurityGroup.db_model, + cls.db_model.security_group_id == SecurityGroup.db_model.id) + clauses = or_(SecurityGroup.db_model.project_id == project_id, + cls.db_model.project_id == project_id) + rule_ids = query.filter(clauses).all() + return [rule_id[0] for rule_id in rule_ids] diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index a5254acfe54..92b2f0259e7 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -460,3 +460,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): {'port_range_min': 100, 'port_range_max': 200, 'protocol': '111'}) + + def _create_environment(self): + self.sg = copy.deepcopy(FAKE_SECGROUP) + self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1', + is_admin=False, overwrite=False) + self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2', + is_admin=True, overwrite=False) + self.sg_user = self.mixin.create_security_group( + self.user_ctx, {'security_group': {'name': 'name', + 'tenant_id': 'tenant_1', + 'description': 'fake'}}) + + def test_get_security_group_rules(self): + self._create_environment() + rules_before = self.mixin.get_security_group_rules(self.user_ctx) + + rule = copy.deepcopy(FAKE_SECGROUP_RULE) + rule['security_group_rule']['security_group_id'] = self.sg_user['id'] + rule['security_group_rule']['tenant_id'] = 'tenant_2' + self.mixin.create_security_group_rule(self.admin_ctx, rule) + + rules_after = self.mixin.get_security_group_rules(self.user_ctx) + self.assertEqual(len(rules_before) + 1, len(rules_after)) + for rule in (rule for rule in rules_after if rule not in rules_before): + self.assertEqual('tenant_2', rule['tenant_id']) + + def test_get_security_group_rules_filters_passed(self): + self._create_environment() + filters = {'security_group_id': self.sg_user['id']} + rules_before = self.mixin.get_security_group_rules(self.user_ctx, + filters=filters) + + default_sg = self.mixin.get_security_groups( + self.user_ctx, filters={'name': 'default'})[0] + rule = copy.deepcopy(FAKE_SECGROUP_RULE) + rule['security_group_rule']['security_group_id'] = default_sg['id'] + rule['security_group_rule']['tenant_id'] = 'tenant_1' + self.mixin.create_security_group_rule(self.user_ctx, rule) + + rules_after = self.mixin.get_security_group_rules(self.user_ctx, + filters=filters) + self.assertEqual(rules_before, rules_after) + + def test_get_security_group_rules_admin_context(self): + self._create_environment() + rules_before = self.mixin.get_security_group_rules(self.ctx) + + rule = copy.deepcopy(FAKE_SECGROUP_RULE) + rule['security_group_rule']['security_group_id'] = self.sg_user['id'] + rule['security_group_rule']['tenant_id'] = 'tenant_1' + self.mixin.create_security_group_rule(self.user_ctx, rule) + + rules_after = self.mixin.get_security_group_rules(self.ctx) + self.assertEqual(len(rules_before) + 1, len(rules_after)) + for rule in (rule for rule in rules_after if rule not in rules_before): + self.assertEqual('tenant_1', rule['tenant_id']) + self.assertEqual(self.sg_user['id'], rule['security_group_id']) diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 716c7d8d3d5..50c696d8016 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -1533,8 +1533,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, self._router.create() return self._router['id'] - def _create_test_security_group_id(self): + def _create_test_security_group_id(self, fields=None): sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup) + fields = fields or {} + for field, value in ((f, v) for (f, v) in fields.items() if + f in sg_fields): + sg_fields[field] = value _securitygroup = securitygroup.SecurityGroup( self.context, **sg_fields) _securitygroup.create() diff --git a/neutron/tests/unit/objects/test_securitygroup.py b/neutron/tests/unit/objects/test_securitygroup.py index 54b8a04d43b..63c57cd850c 100644 --- a/neutron/tests/unit/objects/test_securitygroup.py +++ b/neutron/tests/unit/objects/test_securitygroup.py @@ -10,6 +10,11 @@ # License for the specific language governing permissions and limitations # under the License. +import collections +import itertools + +from oslo_utils import uuidutils + from neutron.objects import securitygroup from neutron.tests.unit.objects import test_base from neutron.tests.unit import testlib_api @@ -168,3 +173,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase, 'remote_group_id': lambda: self._create_test_security_group_id() }) + + def test_get_security_group_rule_ids(self): + """Retrieve the SG rules associated to a project (see method desc.) + + SG1 (PROJECT1) SG2 (PROJECT2) + rule1a (PROJECT1) rule2a (PROJECT1) + rule1b (PROJECT2) rule2b (PROJECT2) + + query PROJECT1: rule1a, rule1b, rule2a + query PROJECT2: rule1b, rule2a, rule2b + """ + projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()] + sgs = [ + self._create_test_security_group_id({'project_id': projects[0]}), + self._create_test_security_group_id({'project_id': projects[1]})] + + rules_per_project = collections.defaultdict(list) + rules_per_sg = collections.defaultdict(list) + for project, sg in itertools.product(projects, sgs): + sgrule_fields = self.get_random_object_fields( + securitygroup.SecurityGroupRule) + sgrule_fields['project_id'] = project + sgrule_fields['security_group_id'] = sg + rule = securitygroup.SecurityGroupRule(self.context, + **sgrule_fields) + rule.create() + rules_per_project[project].append(rule.id) + rules_per_sg[sg].append(rule.id) + + for idx in range(2): + rule_ids = securitygroup.SecurityGroupRule.\ + get_security_group_rule_ids(projects[idx]) + rule_ids_ref = set(rules_per_project[projects[idx]]) + rule_ids_ref.update(set(rules_per_sg[sgs[idx]])) + self.assertEqual(rule_ids_ref, set(rule_ids))