neutron/neutron/tests/unit/objects/test_securitygroup.py
Rodolfo Alonso Hernandez d3905264b7 Filter by owner SGs when retrieving the SG rules
Retrieving the SG rules now is used the admin context. This allows to
get all possible rules, independently of the user calling. The filters
passed and the RBAC policies filter those results, returning only:
- The SG rules belonging to the user.
- The SG rules belonging to a SG owned by the user.

However, if the SG list is too long, the query can take a lot of time.
Instead of this, the filtering is done in the DB query. If no filters
are passed to "get_security_group_rules" and the context is not the
admin context, only the rules specified in the first paragraph will
be retrieved.

Because overwriting the method "get_objects" is too complex, an
intermediate query is done to retrieve the SG rule IDs. Those IDs
will be used as a filter in the "get_objects" call.

Conflicts:
      neutron/db/securitygroups_db.py
      neutron/objects/securitygroup.py

Closes-Bug: #1863201

Change-Id: I25d3da929f8d0b6ee15d7b90ec59b9d58a4ae6a5
(cherry picked from commit d874c46bff)
2020-04-14 18:37:12 +00:00

255 lines
9.1 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import itertools
import random
from oslo_utils import uuidutils
from neutron.objects import securitygroup
from neutron.tests.unit.objects import test_base
from neutron.tests.unit.objects import test_rbac
from neutron.tests.unit import testlib_api
class _SecurityGroupRBACBase(object):
def get_random_object_fields(self, obj_cls=None):
fields = (super(_SecurityGroupRBACBase, self).
get_random_object_fields(obj_cls))
rnd_actions = self._test_class.db_model.get_valid_actions()
idx = random.randint(0, len(rnd_actions) - 1)
fields['action'] = rnd_actions[idx]
return fields
class SecurityGroupRBACDbObjectTestCase(_SecurityGroupRBACBase,
test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):
_test_class = securitygroup.SecurityGroupRBAC
def setUp(self):
super(SecurityGroupRBACDbObjectTestCase, self).setUp()
for obj in self.db_objs:
sg_obj = securitygroup.SecurityGroup(self.context,
id=obj['object_id'],
project_id=obj['project_id'])
sg_obj.create()
def _create_test_security_group_rbac(self):
self.objs[0].create()
return self.objs[0]
def test_object_version_degradation_1_1_to_1_0_no_shared(self):
security_group_rbac_obj = self._create_test_security_group_rbac()
x = security_group_rbac_obj.obj_to_primitive('1.0')
security_group_rbac_dict = x
self.assertNotIn('shared',
security_group_rbac_dict['versioned_object.data'])
class SecurityGroupRBACIfaceObjectTestCase(_SecurityGroupRBACBase,
test_base.BaseObjectIfaceTestCase):
_test_class = securitygroup.SecurityGroupRBAC
class SecurityGroupIfaceObjTestCase(test_rbac.RBACBaseObjectIfaceTestCase):
_test_class = securitygroup.SecurityGroup
class SecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):
_test_class = securitygroup.SecurityGroup
def setUp(self):
super(SecurityGroupDbObjTestCase, self).setUp()
# TODO(ihrachys): consider refactoring base test class to set None for
# all nullable fields
for db_obj in self.db_objs:
for rule in db_obj['rules']:
# we either make it null, or create remote groups for each rule
# generated; we picked the former here
rule['remote_group_id'] = None
def test_is_default_True(self):
fields = self.obj_fields[0].copy()
sg_obj = self._make_object(fields)
sg_obj.is_default = True
sg_obj.create()
default_sg_obj = securitygroup.DefaultSecurityGroup.get_object(
self.context,
project_id=sg_obj.project_id,
security_group_id=sg_obj.id)
self.assertIsNotNone(default_sg_obj)
sg_obj = securitygroup.SecurityGroup.get_object(
self.context,
id=sg_obj.id,
project_id=sg_obj.project_id
)
self.assertTrue(sg_obj.is_default)
def test_is_default_False(self):
fields = self.obj_fields[0].copy()
sg_obj = self._make_object(fields)
sg_obj.is_default = False
sg_obj.create()
default_sg_obj = securitygroup.DefaultSecurityGroup.get_object(
self.context,
project_id=sg_obj.project_id,
security_group_id=sg_obj.id)
self.assertIsNone(default_sg_obj)
sg_obj = securitygroup.SecurityGroup.get_object(
self.context,
id=sg_obj.id,
project_id=sg_obj.project_id
)
self.assertFalse(sg_obj.is_default)
def test_get_object_filter_by_is_default(self):
fields = self.obj_fields[0].copy()
sg_obj = self._make_object(fields)
sg_obj.is_default = True
sg_obj.create()
listed_obj = securitygroup.SecurityGroup.get_object(
self.context,
id=sg_obj.id,
project_id=sg_obj.project_id,
is_default=True
)
self.assertIsNotNone(listed_obj)
self.assertEqual(sg_obj, listed_obj)
def test_get_objects_queries_constant(self):
# TODO(electrocucaracha) SecurityGroup is using SecurityGroupRule
# object to reload rules, which costs extra SQL query each time
# is_default field is loaded as part of get_object(s). SecurityGroup
# has defined relationship for SecurityGroupRules, so it should be
# possible to reuse side loaded values fo this. To be reworked in
# follow-up patch.
pass
def test_get_object_no_synth(self):
fields = self.obj_fields[0].copy()
sg_obj = self._make_object(fields)
sg_obj.is_default = True
sg_obj.create()
listed_obj = securitygroup.SecurityGroup.get_object(
self.context,
fields=['id', 'name'],
id=sg_obj.id,
project_id=sg_obj.project_id
)
self.assertIsNotNone(listed_obj)
self.assertEqual(len(sg_obj.rules), 0)
self.assertIsNone(listed_obj.rules)
def test_get_objects_no_synth(self):
fields = self.obj_fields[0].copy()
sg_obj = self._make_object(fields)
sg_obj.is_default = True
sg_obj.create()
listed_objs = securitygroup.SecurityGroup.get_objects(
self.context,
fields=['id', 'name'],
id=sg_obj.id,
project_id=sg_obj.project_id
)
self.assertEqual(len(listed_objs), 1)
self.assertEqual(len(sg_obj.rules), 0)
self.assertIsNone(listed_objs[0].rules)
class DefaultSecurityGroupIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):
_test_class = securitygroup.DefaultSecurityGroup
class DefaultSecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):
_test_class = securitygroup.DefaultSecurityGroup
def setUp(self):
super(DefaultSecurityGroupDbObjTestCase, self).setUp()
self.update_obj_fields(
{
'security_group_id':
lambda: self._create_test_security_group_id()
})
class SecurityGroupRuleIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):
_test_class = securitygroup.SecurityGroupRule
class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):
_test_class = securitygroup.SecurityGroupRule
def setUp(self):
super(SecurityGroupRuleDbObjTestCase, self).setUp()
self.update_obj_fields(
{
'security_group_id':
lambda: self._create_test_security_group_id(),
'remote_group_id':
lambda: self._create_test_security_group_id()
})
def test_get_security_group_rule_ids(self):
"""Retrieve the SG rules associated to a project (see method desc.)
SG1 (PROJECT1) SG2 (PROJECT2)
rule1a (PROJECT1) rule2a (PROJECT1)
rule1b (PROJECT2) rule2b (PROJECT2)
query PROJECT1: rule1a, rule1b, rule2a
query PROJECT2: rule1b, rule2a, rule2b
"""
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
sgs = [
self._create_test_security_group_id({'project_id': projects[0]}),
self._create_test_security_group_id({'project_id': projects[1]})]
rules_per_project = collections.defaultdict(list)
rules_per_sg = collections.defaultdict(list)
for project, sg in itertools.product(projects, sgs):
sgrule_fields = self.get_random_object_fields(
securitygroup.SecurityGroupRule)
sgrule_fields['project_id'] = project
sgrule_fields['security_group_id'] = sg
rule = securitygroup.SecurityGroupRule(self.context,
**sgrule_fields)
rule.create()
rules_per_project[project].append(rule.id)
rules_per_sg[sg].append(rule.id)
for idx in range(2):
rule_ids = securitygroup.SecurityGroupRule.\
get_security_group_rule_ids(projects[idx])
rule_ids_ref = set(rules_per_project[projects[idx]])
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
self.assertEqual(rule_ids_ref, set(rule_ids))