diff --git a/zun/container/docker/driver.py b/zun/container/docker/driver.py index e11f353e4..9f2db7d68 100644 --- a/zun/container/docker/driver.py +++ b/zun/container/docker/driver.py @@ -746,16 +746,15 @@ class DockerDriver(driver.ContainerDriver): cpu_used += float(nanocpus) / 1e9 return cpu_used - def add_security_group(self, context, container, security_group, - sandbox_id=None): + def add_security_group(self, context, container, security_group): security_group_ids = utils.get_security_group_ids( context, [security_group]) + with docker_utils.docker_client() as docker: - network_api = zun_network.api(context=context, docker_api=docker) - sandbox = docker.inspect_container(sandbox_id) - for network in sandbox["NetworkSettings"]["Networks"]: - network_api.add_security_groups_to_ports( - container, security_group_ids) + network_api = zun_network.api(context=context, + docker_api=docker) + network_api.add_security_groups_to_ports(container, + security_group_ids) def get_available_nodes(self): return [self._host.get_hostname()] diff --git a/zun/network/kuryr_network.py b/zun/network/kuryr_network.py index 243e36ec1..d05a7aacc 100644 --- a/zun/network/kuryr_network.py +++ b/zun/network/kuryr_network.py @@ -202,7 +202,9 @@ class KuryrNetwork(network.Network): port_id = addr['port'] port_ids.add(port_id) - neutron_ports = self.neutron.list_ports().get('ports', []) + search_opts = {'tenant_id': self.context.project_id} + neutron_ports = self.neutron.list_ports( + **search_opts).get('ports', []) neutron_ports = [p for p in neutron_ports if p['id'] in port_ids] for port in neutron_ports: if 'security_groups' not in port: diff --git a/zun/tests/tempest/api/clients.py b/zun/tests/tempest/api/clients.py index c0d2f47a0..04741385a 100644 --- a/zun/tests/tempest/api/clients.py +++ b/zun/tests/tempest/api/clients.py @@ -183,6 +183,11 @@ class ZunClient(rest_client.RestClient): return self.get( self.container_uri(container_id, action='stats'), None, **kwargs) + def add_security_group(self, container_id, model, **kwargs): + return self.post( + self.container_uri(container_id, action='add_security_group'), + body=model.to_json(), **kwargs) + def list_services(self, **kwargs): resp, body = self.get(self.services_uri(), **kwargs) return self.deserialize(resp, body, diff --git a/zun/tests/tempest/api/common/datagen.py b/zun/tests/tempest/api/common/datagen.py index 718a54b93..e71d08038 100644 --- a/zun/tests/tempest/api/common/datagen.py +++ b/zun/tests/tempest/api/common/datagen.py @@ -92,3 +92,14 @@ def container_rename_data(**kwargs): model = container_model.ContainerPatchEntity.from_dict(data) return model + + +def container_add_sg_data(**kwargs): + data = { + 'name': 'sg_name', + } + + data.update(kwargs) + model = container_model.ContainerPatchEntity.from_dict(data) + + return model diff --git a/zun/tests/tempest/api/test_containers.py b/zun/tests/tempest/api/test_containers.py index c22b0aac2..771a05281 100644 --- a/zun/tests/tempest/api/test_containers.py +++ b/zun/tests/tempest/api/test_containers.py @@ -158,42 +158,18 @@ class TestContainer(base.BaseZunTest): gen_model = datagen.container_data() delattr(gen_model, 'security_groups') _, model = self._run_container(gen_model=gen_model) - - # find the neutron port of this container - port_ids = set() - for addrs_list in model.addresses.values(): - for addr in addrs_list: - port_id = addr['port'] - port_ids.add(port_id) - self.assertEqual(1, len(port_ids)) - # verify default security_group is applied - port_id = port_ids.pop() - port = self.ports_client.show_port(port_id) - sg_ids = port['port']['security_groups'] - self.assertEqual(1, len(sg_ids)) - sg = self.sgs_client.show_security_group(sg_ids[0]) - self.assertEqual('default', sg['security_group']['name']) + sgs = self._get_all_security_groups(model) + self.assertEqual(1, len(sgs)) + self.assertEqual('default', sgs[0]) @decorators.idempotent_id('f181eeda-a9d1-4b2e-9746-d6634ca81e2f') def test_run_container_with_security_groups(self): sg_name = 'test_sg' self.sgs_client.create_security_group(name=sg_name) _, model = self._run_container(security_groups=[sg_name]) - - # find the neutron port of this container - port_ids = set() - for addrs_list in model.addresses.values(): - for addr in addrs_list: - port_id = addr['port'] - port_ids.add(port_id) - self.assertEqual(1, len(port_ids)) - # verify default security_group is applied - port_id = port_ids.pop() - port = self.ports_client.show_port(port_id) - sg_ids = port['port']['security_groups'] - self.assertEqual(1, len(sg_ids)) - sg = self.sgs_client.show_security_group(sg_ids[0]) - self.assertEqual(sg_name, sg['security_group']['name']) + sgs = self._get_all_security_groups(model) + self.assertEqual(1, len(sgs)) + self.assertEqual(sg_name, sgs[0]) @decorators.idempotent_id('c3f02fa0-fdfb-49fc-95e2-6e4dc982f9be') def test_commit_container(self): @@ -366,6 +342,31 @@ class TestContainer(base.BaseZunTest): self.assertTrue('MEM %' in body) self.assertTrue('BLOCK I/O(B)' in body) + @decorators.idempotent_id('b3b9cf17-82ad-4c1b-a4af-8210a778a33e') + def test_add_sg_to_container(self): + _, model = self._run_container() + sgs = self._get_all_security_groups(model) + self.assertEqual(1, len(sgs)) + self.assertEqual('default', sgs[0]) + + sg_name = 'test_add_sg' + self.sgs_client.create_security_group(name=sg_name) + gen_model = datagen.container_add_sg_data(name=sg_name) + resp, body = self.container_client.add_security_group( + model.uuid, gen_model) + self.assertEqual(202, resp.status) + + def assert_security_group_is_added(): + sgs = self._get_all_security_groups(model) + if len(sgs) == 2: + self.assertTrue('default' in sgs) + self.assertTrue(sg_name in sgs) + return True + else: + return False + + utils.wait_for_condition(assert_security_group_is_added) + def _assert_resource_constraints(self, container, cpu=None, memory=None): if cpu is not None: cpu_quota = container.get('HostConfig').get('CpuQuota') @@ -426,3 +427,25 @@ class TestContainer(base.BaseZunTest): return 'Created' else: return 'Stopped' + + def _get_all_security_groups(self, container): + # find all neutron ports of this container + port_ids = set() + for addrs_list in container.addresses.values(): + for addr in addrs_list: + port_id = addr['port'] + port_ids.add(port_id) + + # find all security groups of this container + sg_ids = set() + for port_id in port_ids: + port = self.ports_client.show_port(port_id) + for sg in port['port']['security_groups']: + sg_ids.add(sg) + + sg_names = [] + for sg_id in sg_ids: + sg = self.sgs_client.show_security_group(sg_id) + sg_names.append(sg['security_group']['name']) + + return sg_names