Merge "Filter by owner SGs when retrieving the SG rules" into stable/rocky

This commit is contained in:
Zuul 2020-04-16 05:08:40 +00:00 committed by Gerrit Code Review
commit 79be12b706
5 changed files with 127 additions and 1 deletions

View File

@ -680,6 +680,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
if project_id:
self._ensure_default_security_group(context, project_id)
if not filters and context.project_id and not context.is_admin:
rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids(
context.project_id)
filters = {'id': rule_ids}
# NOTE(slaweq): use admin context here to be able to get all rules
# which fits filters' criteria. Later in policy engine rules will be
# filtered and only those which are allowed according to policy will

View File

@ -10,7 +10,9 @@
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib import context as context_lib
from oslo_versionedobjects import fields as obj_fields
from sqlalchemy import or_
from neutron.common import utils
from neutron.db.models import securitygroup as sg_models
@ -127,3 +129,21 @@ class SecurityGroupRule(base.NeutronDbObject):
fields['remote_ip_prefix'] = (
utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
return fields
@classmethod
def get_security_group_rule_ids(cls, project_id):
"""Retrieve all SG rules related to this project_id
This method returns the SG rule IDs that meet these conditions:
- The rule belongs to this project_id
- The rule belongs to a security group that belongs to the project_id
"""
context = context_lib.get_admin_context()
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,
cls.db_model.security_group_id == SecurityGroup.db_model.id)
clauses = or_(SecurityGroup.db_model.project_id == project_id,
cls.db_model.project_id == project_id)
rule_ids = query.filter(clauses).all()
return [rule_id[0] for rule_id in rule_ids]

View File

@ -460,3 +460,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
{'port_range_min': 100,
'port_range_max': 200,
'protocol': '111'})
def _create_environment(self):
self.sg = copy.deepcopy(FAKE_SECGROUP)
self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1',
is_admin=False, overwrite=False)
self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2',
is_admin=True, overwrite=False)
self.sg_user = self.mixin.create_security_group(
self.user_ctx, {'security_group': {'name': 'name',
'tenant_id': 'tenant_1',
'description': 'fake'}})
def test_get_security_group_rules(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.user_ctx)
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_2'
self.mixin.create_security_group_rule(self.admin_ctx, rule)
rules_after = self.mixin.get_security_group_rules(self.user_ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_2', rule['tenant_id'])
def test_get_security_group_rules_filters_passed(self):
self._create_environment()
filters = {'security_group_id': self.sg_user['id']}
rules_before = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)
default_sg = self.mixin.get_security_groups(
self.user_ctx, filters={'name': 'default'})[0]
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = default_sg['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)
rules_after = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)
self.assertEqual(rules_before, rules_after)
def test_get_security_group_rules_admin_context(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.ctx)
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)
rules_after = self.mixin.get_security_group_rules(self.ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_1', rule['tenant_id'])
self.assertEqual(self.sg_user['id'], rule['security_group_id'])

View File

@ -1555,8 +1555,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
self._router.create()
return self._router['id']
def _create_test_security_group_id(self):
def _create_test_security_group_id(self, fields=None):
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
fields = fields or {}
for field, value in ((f, v) for (f, v) in fields.items() if
f in sg_fields):
sg_fields[field] = value
_securitygroup = securitygroup.SecurityGroup(
self.context, **sg_fields)
_securitygroup.create()

View File

@ -10,6 +10,11 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
import itertools
from oslo_utils import uuidutils
from neutron.objects import securitygroup
from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
@ -168,3 +173,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
'remote_group_id':
lambda: self._create_test_security_group_id()
})
def test_get_security_group_rule_ids(self):
"""Retrieve the SG rules associated to a project (see method desc.)
SG1 (PROJECT1) SG2 (PROJECT2)
rule1a (PROJECT1) rule2a (PROJECT1)
rule1b (PROJECT2) rule2b (PROJECT2)
query PROJECT1: rule1a, rule1b, rule2a
query PROJECT2: rule1b, rule2a, rule2b
"""
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
sgs = [
self._create_test_security_group_id({'project_id': projects[0]}),
self._create_test_security_group_id({'project_id': projects[1]})]
rules_per_project = collections.defaultdict(list)
rules_per_sg = collections.defaultdict(list)
for project, sg in itertools.product(projects, sgs):
sgrule_fields = self.get_random_object_fields(
securitygroup.SecurityGroupRule)
sgrule_fields['project_id'] = project
sgrule_fields['security_group_id'] = sg
rule = securitygroup.SecurityGroupRule(self.context,
**sgrule_fields)
rule.create()
rules_per_project[project].append(rule.id)
rules_per_sg[sg].append(rule.id)
for idx in range(2):
rule_ids = securitygroup.SecurityGroupRule.\
get_security_group_rule_ids(projects[idx])
rule_ids_ref = set(rules_per_project[projects[idx]])
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
self.assertEqual(rule_ids_ref, set(rule_ids))