Browse Source

Merge "Filter by owner SGs when retrieving the SG rules" into stable/rocky

changes/91/720691/1
Zuul 2 months ago
committed by Gerrit Code Review
parent
commit
79be12b706
5 changed files with 127 additions and 1 deletions
  1. +5
    -0
      neutron/db/securitygroups_db.py
  2. +20
    -0
      neutron/objects/securitygroup.py
  3. +57
    -0
      neutron/tests/unit/db/test_securitygroups_db.py
  4. +5
    -1
      neutron/tests/unit/objects/test_base.py
  5. +40
    -0
      neutron/tests/unit/objects/test_securitygroup.py

+ 5
- 0
neutron/db/securitygroups_db.py View File

@@ -680,6 +680,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
if project_id:
self._ensure_default_security_group(context, project_id)

if not filters and context.project_id and not context.is_admin:
rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids(
context.project_id)
filters = {'id': rule_ids}

# NOTE(slaweq): use admin context here to be able to get all rules
# which fits filters' criteria. Later in policy engine rules will be
# filtered and only those which are allowed according to policy will


+ 20
- 0
neutron/objects/securitygroup.py View File

@@ -10,7 +10,9 @@
# License for the specific language governing permissions and limitations
# under the License.

from neutron_lib import context as context_lib
from oslo_versionedobjects import fields as obj_fields
from sqlalchemy import or_

from neutron.common import utils
from neutron.db.models import securitygroup as sg_models
@@ -127,3 +129,21 @@ class SecurityGroupRule(base.NeutronDbObject):
fields['remote_ip_prefix'] = (
utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
return fields

@classmethod
def get_security_group_rule_ids(cls, project_id):
"""Retrieve all SG rules related to this project_id

This method returns the SG rule IDs that meet these conditions:
- The rule belongs to this project_id
- The rule belongs to a security group that belongs to the project_id
"""
context = context_lib.get_admin_context()
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,
cls.db_model.security_group_id == SecurityGroup.db_model.id)
clauses = or_(SecurityGroup.db_model.project_id == project_id,
cls.db_model.project_id == project_id)
rule_ids = query.filter(clauses).all()
return [rule_id[0] for rule_id in rule_ids]

+ 57
- 0
neutron/tests/unit/db/test_securitygroups_db.py View File

@@ -460,3 +460,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
{'port_range_min': 100,
'port_range_max': 200,
'protocol': '111'})

def _create_environment(self):
self.sg = copy.deepcopy(FAKE_SECGROUP)
self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1',
is_admin=False, overwrite=False)
self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2',
is_admin=True, overwrite=False)
self.sg_user = self.mixin.create_security_group(
self.user_ctx, {'security_group': {'name': 'name',
'tenant_id': 'tenant_1',
'description': 'fake'}})

def test_get_security_group_rules(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.user_ctx)

rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_2'
self.mixin.create_security_group_rule(self.admin_ctx, rule)

rules_after = self.mixin.get_security_group_rules(self.user_ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_2', rule['tenant_id'])

def test_get_security_group_rules_filters_passed(self):
self._create_environment()
filters = {'security_group_id': self.sg_user['id']}
rules_before = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)

default_sg = self.mixin.get_security_groups(
self.user_ctx, filters={'name': 'default'})[0]
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = default_sg['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)

rules_after = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)
self.assertEqual(rules_before, rules_after)

def test_get_security_group_rules_admin_context(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.ctx)

rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)

rules_after = self.mixin.get_security_group_rules(self.ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_1', rule['tenant_id'])
self.assertEqual(self.sg_user['id'], rule['security_group_id'])

+ 5
- 1
neutron/tests/unit/objects/test_base.py View File

@@ -1555,8 +1555,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
self._router.create()
return self._router['id']

def _create_test_security_group_id(self):
def _create_test_security_group_id(self, fields=None):
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
fields = fields or {}
for field, value in ((f, v) for (f, v) in fields.items() if
f in sg_fields):
sg_fields[field] = value
_securitygroup = securitygroup.SecurityGroup(
self.context, **sg_fields)
_securitygroup.create()


+ 40
- 0
neutron/tests/unit/objects/test_securitygroup.py View File

@@ -10,6 +10,11 @@
# License for the specific language governing permissions and limitations
# under the License.

import collections
import itertools

from oslo_utils import uuidutils

from neutron.objects import securitygroup
from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
@@ -168,3 +173,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
'remote_group_id':
lambda: self._create_test_security_group_id()
})

def test_get_security_group_rule_ids(self):
"""Retrieve the SG rules associated to a project (see method desc.)

SG1 (PROJECT1) SG2 (PROJECT2)
rule1a (PROJECT1) rule2a (PROJECT1)
rule1b (PROJECT2) rule2b (PROJECT2)

query PROJECT1: rule1a, rule1b, rule2a
query PROJECT2: rule1b, rule2a, rule2b
"""
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
sgs = [
self._create_test_security_group_id({'project_id': projects[0]}),
self._create_test_security_group_id({'project_id': projects[1]})]

rules_per_project = collections.defaultdict(list)
rules_per_sg = collections.defaultdict(list)
for project, sg in itertools.product(projects, sgs):
sgrule_fields = self.get_random_object_fields(
securitygroup.SecurityGroupRule)
sgrule_fields['project_id'] = project
sgrule_fields['security_group_id'] = sg
rule = securitygroup.SecurityGroupRule(self.context,
**sgrule_fields)
rule.create()
rules_per_project[project].append(rule.id)
rules_per_sg[sg].append(rule.id)

for idx in range(2):
rule_ids = securitygroup.SecurityGroupRule.\
get_security_group_rule_ids(projects[idx])
rule_ids_ref = set(rules_per_project[projects[idx]])
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
self.assertEqual(rule_ids_ref, set(rule_ids))

Loading…
Cancel
Save