Merge "Filter by owner SGs when retrieving the SG rules"

changes/70/709070/2
Zuul 3 years ago committed by Gerrit Code Review
commit f97ae3d6f8
  1. 5
      neutron/db/securitygroups_db.py
  2. 20
      neutron/objects/securitygroup.py
  3. 57
      neutron/tests/unit/db/test_securitygroups_db.py
  4. 6
      neutron/tests/unit/objects/test_base.py
  5. 39
      neutron/tests/unit/objects/test_securitygroup.py

@ -719,6 +719,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase,
pager = base_obj.Pager(
sorts=sorts, marker=marker, limit=limit, page_reverse=page_reverse)
if not filters and context.project_id and not context.is_admin:
rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids(
context.project_id)
filters = {'id': rule_ids}
# NOTE(slaweq): use admin context here to be able to get all rules
# which fits filters' criteria. Later in policy engine rules will be
# filtered and only those which are allowed according to policy will

@ -10,10 +10,12 @@
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib import context as context_lib
from neutron_lib.objects import common_types
from neutron_lib.utils import net as net_utils
from oslo_utils import versionutils
from oslo_versionedobjects import fields as obj_fields
from sqlalchemy import or_
from neutron.db.models import securitygroup as sg_models
from neutron.db import rbac_db_models
@ -155,3 +157,21 @@ class SecurityGroupRule(base.NeutronDbObject):
fields['remote_ip_prefix'] = (
net_utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
return fields
@classmethod
def get_security_group_rule_ids(cls, project_id):
"""Retrieve all SG rules related to this project_id
This method returns the SG rule IDs that meet these conditions:
- The rule belongs to this project_id
- The rule belongs to a security group that belongs to the project_id
"""
context = context_lib.get_admin_context()
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,
cls.db_model.security_group_id == SecurityGroup.db_model.id)
clauses = or_(SecurityGroup.db_model.project_id == project_id,
cls.db_model.project_id == project_id)
rule_ids = query.filter(clauses).all()
return [rule_id[0] for rule_id in rule_ids]

@ -479,3 +479,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
{'port_range_min': None,
'port_range_max': 200,
'protocol': constants.PROTO_NAME_VRRP})
def _create_environment(self):
self.sg = copy.deepcopy(FAKE_SECGROUP)
self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1',
is_admin=False, overwrite=False)
self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2',
is_admin=True, overwrite=False)
self.sg_user = self.mixin.create_security_group(
self.user_ctx, {'security_group': {'name': 'name',
'tenant_id': 'tenant_1',
'description': 'fake'}})
def test_get_security_group_rules(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.user_ctx)
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_2'
self.mixin.create_security_group_rule(self.admin_ctx, rule)
rules_after = self.mixin.get_security_group_rules(self.user_ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_2', rule['tenant_id'])
def test_get_security_group_rules_filters_passed(self):
self._create_environment()
filters = {'security_group_id': self.sg_user['id']}
rules_before = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)
default_sg = self.mixin.get_security_groups(
self.user_ctx, filters={'name': 'default'})[0]
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = default_sg['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)
rules_after = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)
self.assertEqual(rules_before, rules_after)
def test_get_security_group_rules_admin_context(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.ctx)
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)
rules_after = self.mixin.get_security_group_rules(self.ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_1', rule['tenant_id'])
self.assertEqual(self.sg_user['id'], rule['security_group_id'])

@ -1630,8 +1630,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
self._router.create()
return self._router['id']
def _create_test_security_group_id(self):
def _create_test_security_group_id(self, fields=None):
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
fields = fields or {}
for field, value in ((f, v) for (f, v) in fields.items() if
f in sg_fields):
sg_fields[field] = value
_securitygroup = securitygroup.SecurityGroup(
self.context, **sg_fields)
_securitygroup.create()

@ -10,8 +10,12 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
import itertools
import random
from oslo_utils import uuidutils
from neutron.objects import securitygroup
from neutron.tests.unit.objects import test_base
from neutron.tests.unit.objects import test_rbac
@ -213,3 +217,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
'remote_group_id':
lambda: self._create_test_security_group_id()
})
def test_get_security_group_rule_ids(self):
"""Retrieve the SG rules associated to a project (see method desc.)
SG1 (PROJECT1) SG2 (PROJECT2)
rule1a (PROJECT1) rule2a (PROJECT1)
rule1b (PROJECT2) rule2b (PROJECT2)
query PROJECT1: rule1a, rule1b, rule2a
query PROJECT2: rule1b, rule2a, rule2b
"""
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
sgs = [
self._create_test_security_group_id({'project_id': projects[0]}),
self._create_test_security_group_id({'project_id': projects[1]})]
rules_per_project = collections.defaultdict(list)
rules_per_sg = collections.defaultdict(list)
for project, sg in itertools.product(projects, sgs):
sgrule_fields = self.get_random_object_fields(
securitygroup.SecurityGroupRule)
sgrule_fields['project_id'] = project
sgrule_fields['security_group_id'] = sg
rule = securitygroup.SecurityGroupRule(self.context,
**sgrule_fields)
rule.create()
rules_per_project[project].append(rule.id)
rules_per_sg[sg].append(rule.id)
for idx in range(2):
rule_ids = securitygroup.SecurityGroupRule.\
get_security_group_rule_ids(projects[idx])
rule_ids_ref = set(rules_per_project[projects[idx]])
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
self.assertEqual(rule_ids_ref, set(rule_ids))

Loading…
Cancel
Save