Fix Port OVO filtering based on security groups

Filtering of port OVO based on ids of security groups which
are used by ports is now available.

Closes-Bug: 1744447

Change-Id: Ie5a3effe668db119d40728be5357f0851bdcebbe
This commit is contained in:
Sławek Kapłoński 2018-01-22 10:38:59 +01:00
parent a12630ee85
commit e7c0ec17df
2 changed files with 56 additions and 20 deletions

View File

@ -384,6 +384,20 @@ class Port(base.NeutronDbObject):
{'port_id': self.id, 'security_group_id': sg_id}
)
@classmethod
def get_objects(cls, context, _pager=None, validate_filters=True,
security_group_ids=None, **kwargs):
if security_group_ids:
ports_with_sg = cls.get_ports_ids_by_security_groups(
context, security_group_ids)
port_ids = kwargs.get("id", [])
if port_ids:
kwargs['id'] = list(set(port_ids) & set(ports_with_sg))
else:
kwargs['id'] = ports_with_sg
return super(Port, cls).get_objects(context, _pager, validate_filters,
**kwargs)
# TODO(rossella_s): get rid of it once we switch the db model to using
# custom types.
@classmethod
@ -444,3 +458,11 @@ class Port(base.NeutronDbObject):
models_v2.Port.network_id == subnet['network_id']
)
return [cls._load_object(context, db_obj) for db_obj in ports.all()]
@classmethod
def get_ports_ids_by_security_groups(cls, context, security_group_ids):
query = context.session.query(sg_models.SecurityGroupPortBinding)
query = query.filter(
sg_models.SecurityGroupPortBinding.security_group_id.in_(
security_group_ids))
return [port_binding['port_id'] for port_binding in query.all()]

View File

@ -225,34 +225,48 @@ class PortDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
'fixed_ips': {'subnet_id': subnet_id, 'network_id': network_id}})
def test_security_group_ids(self):
sg1_id = self._create_test_security_group_id()
sg2_id = self._create_test_security_group_id()
groups = {sg1_id, sg2_id}
obj = self._make_object(self.obj_fields[0])
obj.security_group_ids = groups
obj.create()
groups = []
objs = []
for i in range(2):
groups.append(self._create_test_security_group_id())
objs.append(self._make_object(self.obj_fields[i]))
objs[i].security_group_ids = {groups[i]}
objs[i].create()
obj = ports.Port.get_object(self.context, id=obj.id)
self.assertEqual(groups, obj.security_group_ids)
self.assertEqual([obj],
self.assertEqual([objs[0]],
ports.Port.get_objects(
self.context, security_group_ids=(sg1_id, )))
self.assertEqual([obj],
self.context, security_group_ids=(groups[0], )))
self.assertEqual([objs[1]],
ports.Port.get_objects(
self.context, security_group_ids=(sg2_id, )))
self.context, security_group_ids=(groups[1], )))
sg3_id = self._create_test_security_group_id()
obj.security_group_ids = {sg3_id}
obj.update()
objs[0].security_group_ids = {sg3_id}
objs[0].update()
obj = ports.Port.get_object(self.context, id=obj.id)
self.assertEqual({sg3_id}, obj.security_group_ids)
objs[0] = ports.Port.get_object(self.context, id=objs[0].id)
self.assertEqual({sg3_id}, objs[0].security_group_ids)
obj.security_group_ids = set()
obj.update()
objs[0].security_group_ids = set()
objs[0].update()
obj = ports.Port.get_object(self.context, id=obj.id)
self.assertFalse(obj.security_group_ids)
objs[0] = ports.Port.get_object(self.context, id=objs[0].id)
self.assertFalse(objs[0].security_group_ids)
def test_security_group_ids_and_port_id(self):
objs = []
group = self._create_test_security_group_id()
for i in range(2):
objs.append(self._make_object(self.obj_fields[i]))
objs[i].security_group_ids = {group}
objs[i].create()
for i in range(2):
self.assertEqual(
[objs[i]],
ports.Port.get_objects(
self.context, id=(objs[i].id, ),
security_group_ids=(group, )))
def test__attach_security_group(self):
obj = self._make_object(self.obj_fields[0])