diff --git a/kmip/services/kmip_client.py b/kmip/services/kmip_client.py index 8acad02..e1168a9 100644 --- a/kmip/services/kmip_client.py +++ b/kmip/services/kmip_client.py @@ -246,8 +246,13 @@ class KMIPProxy(KMIP): def close(self): # Shutdown and close the socket. if self.socket: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() + try: + self.socket.shutdown(socket.SHUT_RDWR) + self.socket.close() + except OSError: + # Can be thrown if the socket is not actually connected to + # anything. In this case, ignore the error. + pass self.socket = None def create(self, object_type, template_attribute, credential=None): diff --git a/kmip/tests/unit/services/test_kmip_client.py b/kmip/tests/unit/services/test_kmip_client.py index 2d30b8d..f4cdecb 100644 --- a/kmip/tests/unit/services/test_kmip_client.py +++ b/kmip/tests/unit/services/test_kmip_client.py @@ -97,6 +97,50 @@ class TestKMIPClient(TestCase): def tearDown(self): super(TestKMIPClient, self).tearDown() + def test_close(self): + """ + Test that calling close on the client works as expected. + """ + c = KMIPProxy( + host="IP_ADDR_1, IP_ADDR_2", + port=9090, + ca_certs=None + ) + c.socket = mock.MagicMock() + c_socket = c.socket + + c.socket.shutdown.assert_not_called() + c.socket.close.assert_not_called() + + c.close() + + self.assertEqual(None, c.socket) + c_socket.shutdown.assert_called_once_with(socket.SHUT_RDWR) + c_socket.close.assert_called_once() + + def test_close_with_shutdown_error(self): + """ + Test that calling close on an unconnected client does not trigger an + exception. + """ + c = KMIPProxy( + host="IP_ADDR_1, IP_ADDR_2", + port=9090, + ca_certs=None + ) + c.socket = mock.MagicMock() + c_socket = c.socket + c.socket.shutdown.side_effect = OSError + + c.socket.shutdown.assert_not_called() + c.socket.close.assert_not_called() + + c.close() + + self.assertEqual(None, c.socket) + c_socket.shutdown.assert_called_once_with(socket.SHUT_RDWR) + c_socket.close.assert_not_called() + # TODO (peter-hamilton) Modify for credential type and/or add new test def test_build_credential(self): username = 'username'