Merge "Filter by owner SGs when retrieving the SG rules"
This commit is contained in:
commit
f97ae3d6f8
|
@ -719,6 +719,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase,
|
||||||
pager = base_obj.Pager(
|
pager = base_obj.Pager(
|
||||||
sorts=sorts, marker=marker, limit=limit, page_reverse=page_reverse)
|
sorts=sorts, marker=marker, limit=limit, page_reverse=page_reverse)
|
||||||
|
|
||||||
|
if not filters and context.project_id and not context.is_admin:
|
||||||
|
rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids(
|
||||||
|
context.project_id)
|
||||||
|
filters = {'id': rule_ids}
|
||||||
|
|
||||||
# NOTE(slaweq): use admin context here to be able to get all rules
|
# NOTE(slaweq): use admin context here to be able to get all rules
|
||||||
# which fits filters' criteria. Later in policy engine rules will be
|
# which fits filters' criteria. Later in policy engine rules will be
|
||||||
# filtered and only those which are allowed according to policy will
|
# filtered and only those which are allowed according to policy will
|
||||||
|
|
|
@ -10,10 +10,12 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from neutron_lib import context as context_lib
|
||||||
from neutron_lib.objects import common_types
|
from neutron_lib.objects import common_types
|
||||||
from neutron_lib.utils import net as net_utils
|
from neutron_lib.utils import net as net_utils
|
||||||
from oslo_utils import versionutils
|
from oslo_utils import versionutils
|
||||||
from oslo_versionedobjects import fields as obj_fields
|
from oslo_versionedobjects import fields as obj_fields
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
from neutron.db.models import securitygroup as sg_models
|
from neutron.db.models import securitygroup as sg_models
|
||||||
from neutron.db import rbac_db_models
|
from neutron.db import rbac_db_models
|
||||||
|
@ -155,3 +157,21 @@ class SecurityGroupRule(base.NeutronDbObject):
|
||||||
fields['remote_ip_prefix'] = (
|
fields['remote_ip_prefix'] = (
|
||||||
net_utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
|
net_utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_security_group_rule_ids(cls, project_id):
|
||||||
|
"""Retrieve all SG rules related to this project_id
|
||||||
|
|
||||||
|
This method returns the SG rule IDs that meet these conditions:
|
||||||
|
- The rule belongs to this project_id
|
||||||
|
- The rule belongs to a security group that belongs to the project_id
|
||||||
|
"""
|
||||||
|
context = context_lib.get_admin_context()
|
||||||
|
query = context.session.query(cls.db_model.id)
|
||||||
|
query = query.join(
|
||||||
|
SecurityGroup.db_model,
|
||||||
|
cls.db_model.security_group_id == SecurityGroup.db_model.id)
|
||||||
|
clauses = or_(SecurityGroup.db_model.project_id == project_id,
|
||||||
|
cls.db_model.project_id == project_id)
|
||||||
|
rule_ids = query.filter(clauses).all()
|
||||||
|
return [rule_id[0] for rule_id in rule_ids]
|
||||||
|
|
|
@ -479,3 +479,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
||||||
{'port_range_min': None,
|
{'port_range_min': None,
|
||||||
'port_range_max': 200,
|
'port_range_max': 200,
|
||||||
'protocol': constants.PROTO_NAME_VRRP})
|
'protocol': constants.PROTO_NAME_VRRP})
|
||||||
|
|
||||||
|
def _create_environment(self):
|
||||||
|
self.sg = copy.deepcopy(FAKE_SECGROUP)
|
||||||
|
self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1',
|
||||||
|
is_admin=False, overwrite=False)
|
||||||
|
self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2',
|
||||||
|
is_admin=True, overwrite=False)
|
||||||
|
self.sg_user = self.mixin.create_security_group(
|
||||||
|
self.user_ctx, {'security_group': {'name': 'name',
|
||||||
|
'tenant_id': 'tenant_1',
|
||||||
|
'description': 'fake'}})
|
||||||
|
|
||||||
|
def test_get_security_group_rules(self):
|
||||||
|
self._create_environment()
|
||||||
|
rules_before = self.mixin.get_security_group_rules(self.user_ctx)
|
||||||
|
|
||||||
|
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
|
||||||
|
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
|
||||||
|
rule['security_group_rule']['tenant_id'] = 'tenant_2'
|
||||||
|
self.mixin.create_security_group_rule(self.admin_ctx, rule)
|
||||||
|
|
||||||
|
rules_after = self.mixin.get_security_group_rules(self.user_ctx)
|
||||||
|
self.assertEqual(len(rules_before) + 1, len(rules_after))
|
||||||
|
for rule in (rule for rule in rules_after if rule not in rules_before):
|
||||||
|
self.assertEqual('tenant_2', rule['tenant_id'])
|
||||||
|
|
||||||
|
def test_get_security_group_rules_filters_passed(self):
|
||||||
|
self._create_environment()
|
||||||
|
filters = {'security_group_id': self.sg_user['id']}
|
||||||
|
rules_before = self.mixin.get_security_group_rules(self.user_ctx,
|
||||||
|
filters=filters)
|
||||||
|
|
||||||
|
default_sg = self.mixin.get_security_groups(
|
||||||
|
self.user_ctx, filters={'name': 'default'})[0]
|
||||||
|
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
|
||||||
|
rule['security_group_rule']['security_group_id'] = default_sg['id']
|
||||||
|
rule['security_group_rule']['tenant_id'] = 'tenant_1'
|
||||||
|
self.mixin.create_security_group_rule(self.user_ctx, rule)
|
||||||
|
|
||||||
|
rules_after = self.mixin.get_security_group_rules(self.user_ctx,
|
||||||
|
filters=filters)
|
||||||
|
self.assertEqual(rules_before, rules_after)
|
||||||
|
|
||||||
|
def test_get_security_group_rules_admin_context(self):
|
||||||
|
self._create_environment()
|
||||||
|
rules_before = self.mixin.get_security_group_rules(self.ctx)
|
||||||
|
|
||||||
|
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
|
||||||
|
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
|
||||||
|
rule['security_group_rule']['tenant_id'] = 'tenant_1'
|
||||||
|
self.mixin.create_security_group_rule(self.user_ctx, rule)
|
||||||
|
|
||||||
|
rules_after = self.mixin.get_security_group_rules(self.ctx)
|
||||||
|
self.assertEqual(len(rules_before) + 1, len(rules_after))
|
||||||
|
for rule in (rule for rule in rules_after if rule not in rules_before):
|
||||||
|
self.assertEqual('tenant_1', rule['tenant_id'])
|
||||||
|
self.assertEqual(self.sg_user['id'], rule['security_group_id'])
|
||||||
|
|
|
@ -1630,8 +1630,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
|
||||||
self._router.create()
|
self._router.create()
|
||||||
return self._router['id']
|
return self._router['id']
|
||||||
|
|
||||||
def _create_test_security_group_id(self):
|
def _create_test_security_group_id(self, fields=None):
|
||||||
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
|
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
|
||||||
|
fields = fields or {}
|
||||||
|
for field, value in ((f, v) for (f, v) in fields.items() if
|
||||||
|
f in sg_fields):
|
||||||
|
sg_fields[field] = value
|
||||||
_securitygroup = securitygroup.SecurityGroup(
|
_securitygroup = securitygroup.SecurityGroup(
|
||||||
self.context, **sg_fields)
|
self.context, **sg_fields)
|
||||||
_securitygroup.create()
|
_securitygroup.create()
|
||||||
|
|
|
@ -10,8 +10,12 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import itertools
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from oslo_utils import uuidutils
|
||||||
|
|
||||||
from neutron.objects import securitygroup
|
from neutron.objects import securitygroup
|
||||||
from neutron.tests.unit.objects import test_base
|
from neutron.tests.unit.objects import test_base
|
||||||
from neutron.tests.unit.objects import test_rbac
|
from neutron.tests.unit.objects import test_rbac
|
||||||
|
@ -213,3 +217,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
|
||||||
'remote_group_id':
|
'remote_group_id':
|
||||||
lambda: self._create_test_security_group_id()
|
lambda: self._create_test_security_group_id()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def test_get_security_group_rule_ids(self):
|
||||||
|
"""Retrieve the SG rules associated to a project (see method desc.)
|
||||||
|
|
||||||
|
SG1 (PROJECT1) SG2 (PROJECT2)
|
||||||
|
rule1a (PROJECT1) rule2a (PROJECT1)
|
||||||
|
rule1b (PROJECT2) rule2b (PROJECT2)
|
||||||
|
|
||||||
|
query PROJECT1: rule1a, rule1b, rule2a
|
||||||
|
query PROJECT2: rule1b, rule2a, rule2b
|
||||||
|
"""
|
||||||
|
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
|
||||||
|
sgs = [
|
||||||
|
self._create_test_security_group_id({'project_id': projects[0]}),
|
||||||
|
self._create_test_security_group_id({'project_id': projects[1]})]
|
||||||
|
|
||||||
|
rules_per_project = collections.defaultdict(list)
|
||||||
|
rules_per_sg = collections.defaultdict(list)
|
||||||
|
for project, sg in itertools.product(projects, sgs):
|
||||||
|
sgrule_fields = self.get_random_object_fields(
|
||||||
|
securitygroup.SecurityGroupRule)
|
||||||
|
sgrule_fields['project_id'] = project
|
||||||
|
sgrule_fields['security_group_id'] = sg
|
||||||
|
rule = securitygroup.SecurityGroupRule(self.context,
|
||||||
|
**sgrule_fields)
|
||||||
|
rule.create()
|
||||||
|
rules_per_project[project].append(rule.id)
|
||||||
|
rules_per_sg[sg].append(rule.id)
|
||||||
|
|
||||||
|
for idx in range(2):
|
||||||
|
rule_ids = securitygroup.SecurityGroupRule.\
|
||||||
|
get_security_group_rule_ids(projects[idx])
|
||||||
|
rule_ids_ref = set(rules_per_project[projects[idx]])
|
||||||
|
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
|
||||||
|
self.assertEqual(rule_ids_ref, set(rule_ids))
|
||||||
|
|
Loading…
Reference in New Issue