Browse Source

Filter by owner SGs when retrieving the SG rules

Retrieving the SG rules now is used the admin context. This allows to
get all possible rules, independently of the user calling. The filters
passed and the RBAC policies filter those results, returning only:
- The SG rules belonging to the user.
- The SG rules belonging to a SG owned by the user.

However, if the SG list is too long, the query can take a lot of time.
Instead of this, the filtering is done in the DB query. If no filters
are passed to "get_security_group_rules" and the context is not the
admin context, only the rules specified in the first paragraph will
be retrieved.

Because overwriting the method "get_objects" is too complex, an
intermediate query is done to retrieve the SG rule IDs. Those IDs
will be used as a filter in the "get_objects" call.

Conflicts:
      neutron/objects/securitygroup.py
      neutron/tests/unit/db/test_securitygroups_db.py
      neutron/tests/unit/objects/test_securitygroup.py

Closes-Bug: #1863201

Change-Id: I25d3da929f8d0b6ee15d7b90ec59b9d58a4ae6a5
(cherry picked from commit d874c46bff)
(cherry picked from commit d3905264b7)
(cherry picked from commit 61dc621c1b)
changes/86/720686/1
Rodolfo Alonso Hernandez 4 months ago
parent
commit
093b861bb4
5 changed files with 127 additions and 1 deletions
  1. +5
    -0
      neutron/db/securitygroups_db.py
  2. +20
    -0
      neutron/objects/securitygroup.py
  3. +57
    -0
      neutron/tests/unit/db/test_securitygroups_db.py
  4. +5
    -1
      neutron/tests/unit/objects/test_base.py
  5. +40
    -0
      neutron/tests/unit/objects/test_securitygroup.py

+ 5
- 0
neutron/db/securitygroups_db.py View File

@@ -687,6 +687,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
if project_id:
self._ensure_default_security_group(context, project_id)

if not filters and context.project_id and not context.is_admin:
rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids(
context.project_id)
filters = {'id': rule_ids}

# NOTE(slaweq): use admin context here to be able to get all rules
# which fits filters' criteria. Later in policy engine rules will be
# filtered and only those which are allowed according to policy will


+ 20
- 0
neutron/objects/securitygroup.py View File

@@ -10,7 +10,9 @@
# License for the specific language governing permissions and limitations
# under the License.

from neutron_lib import context as context_lib
from oslo_versionedobjects import fields as obj_fields
from sqlalchemy import or_

from neutron.common import utils
from neutron.db.models import securitygroup as sg_models
@@ -127,3 +129,21 @@ class SecurityGroupRule(base.NeutronDbObject):
fields['remote_ip_prefix'] = (
utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
return fields

@classmethod
def get_security_group_rule_ids(cls, project_id):
"""Retrieve all SG rules related to this project_id

This method returns the SG rule IDs that meet these conditions:
- The rule belongs to this project_id
- The rule belongs to a security group that belongs to the project_id
"""
context = context_lib.get_admin_context()
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,
cls.db_model.security_group_id == SecurityGroup.db_model.id)
clauses = or_(SecurityGroup.db_model.project_id == project_id,
cls.db_model.project_id == project_id)
rule_ids = query.filter(clauses).all()
return [rule_id[0] for rule_id in rule_ids]

+ 57
- 0
neutron/tests/unit/db/test_securitygroups_db.py View File

@@ -460,3 +460,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
{'port_range_min': 100,
'port_range_max': 200,
'protocol': '111'})

def _create_environment(self):
self.sg = copy.deepcopy(FAKE_SECGROUP)
self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1',
is_admin=False, overwrite=False)
self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2',
is_admin=True, overwrite=False)
self.sg_user = self.mixin.create_security_group(
self.user_ctx, {'security_group': {'name': 'name',
'tenant_id': 'tenant_1',
'description': 'fake'}})

def test_get_security_group_rules(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.user_ctx)

rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_2'
self.mixin.create_security_group_rule(self.admin_ctx, rule)

rules_after = self.mixin.get_security_group_rules(self.user_ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_2', rule['tenant_id'])

def test_get_security_group_rules_filters_passed(self):
self._create_environment()
filters = {'security_group_id': self.sg_user['id']}
rules_before = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)

default_sg = self.mixin.get_security_groups(
self.user_ctx, filters={'name': 'default'})[0]
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = default_sg['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)

rules_after = self.mixin.get_security_group_rules(self.user_ctx,
filters=filters)
self.assertEqual(rules_before, rules_after)

def test_get_security_group_rules_admin_context(self):
self._create_environment()
rules_before = self.mixin.get_security_group_rules(self.ctx)

rule = copy.deepcopy(FAKE_SECGROUP_RULE)
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
rule['security_group_rule']['tenant_id'] = 'tenant_1'
self.mixin.create_security_group_rule(self.user_ctx, rule)

rules_after = self.mixin.get_security_group_rules(self.ctx)
self.assertEqual(len(rules_before) + 1, len(rules_after))
for rule in (rule for rule in rules_after if rule not in rules_before):
self.assertEqual('tenant_1', rule['tenant_id'])
self.assertEqual(self.sg_user['id'], rule['security_group_id'])

+ 5
- 1
neutron/tests/unit/objects/test_base.py View File

@@ -1533,8 +1533,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
self._router.create()
return self._router['id']

def _create_test_security_group_id(self):
def _create_test_security_group_id(self, fields=None):
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
fields = fields or {}
for field, value in ((f, v) for (f, v) in fields.items() if
f in sg_fields):
sg_fields[field] = value
_securitygroup = securitygroup.SecurityGroup(
self.context, **sg_fields)
_securitygroup.create()


+ 40
- 0
neutron/tests/unit/objects/test_securitygroup.py View File

@@ -10,6 +10,11 @@
# License for the specific language governing permissions and limitations
# under the License.

import collections
import itertools

from oslo_utils import uuidutils

from neutron.objects import securitygroup
from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
@@ -168,3 +173,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
'remote_group_id':
lambda: self._create_test_security_group_id()
})

def test_get_security_group_rule_ids(self):
"""Retrieve the SG rules associated to a project (see method desc.)

SG1 (PROJECT1) SG2 (PROJECT2)
rule1a (PROJECT1) rule2a (PROJECT1)
rule1b (PROJECT2) rule2b (PROJECT2)

query PROJECT1: rule1a, rule1b, rule2a
query PROJECT2: rule1b, rule2a, rule2b
"""
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
sgs = [
self._create_test_security_group_id({'project_id': projects[0]}),
self._create_test_security_group_id({'project_id': projects[1]})]

rules_per_project = collections.defaultdict(list)
rules_per_sg = collections.defaultdict(list)
for project, sg in itertools.product(projects, sgs):
sgrule_fields = self.get_random_object_fields(
securitygroup.SecurityGroupRule)
sgrule_fields['project_id'] = project
sgrule_fields['security_group_id'] = sg
rule = securitygroup.SecurityGroupRule(self.context,
**sgrule_fields)
rule.create()
rules_per_project[project].append(rule.id)
rules_per_sg[sg].append(rule.id)

for idx in range(2):
rule_ids = securitygroup.SecurityGroupRule.\
get_security_group_rule_ids(projects[idx])
rule_ids_ref = set(rules_per_project[projects[idx]])
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
self.assertEqual(rule_ids_ref, set(rule_ids))

Loading…
Cancel
Save