diff --git a/cinder/ssh_utils.py b/cinder/ssh_utils.py index ee6827284d4..ed389ae8622 100644 --- a/cinder/ssh_utils.py +++ b/cinder/ssh_utils.py @@ -64,6 +64,7 @@ class SSHPool(pools.Pool): self.conn_timeout = conn_timeout if conn_timeout else None self.privatekey = privatekey self.hosts_key_file = None + self.current_size = 0 # Validate good config setting here. # Paramiko handles the case where the file is inaccessible. @@ -99,6 +100,23 @@ class SSHPool(pools.Pool): super(SSHPool, self).__init__(*args, **kwargs) + def __del__(self): + # just return if nothing todo + if not self.current_size: + return + # change the size of the pool to reduce the number + # of elements on the pool via puts. + self.resize(1) + # release all but the last connection using + # get and put to allow any get waiters to complete. + while(self.waiting() or self.current_size > 1): + conn = self.get() + self.put(conn) + # Now free everthing that is left + while(self.free_items): + self.free_items.popleft().close() + self.current_size -= 1 + def create(self): try: ssh = paramiko.SSHClient() @@ -168,6 +186,14 @@ class SSHPool(pools.Pool): self.current_size -= 1 return new_conn + def put(self, conn): + # If we are have more connections than we should just close it + if self.current_size > self.max_size: + conn.close() + self.current_size -= 1 + return + super(SSHPool, self).put(conn) + def remove(self, ssh): """Close an ssh client and remove it from free_items.""" ssh.close() diff --git a/cinder/tests/unit/test_ssh_utils.py b/cinder/tests/unit/test_ssh_utils.py index 393139b0152..75578deb292 100644 --- a/cinder/tests/unit/test_ssh_utils.py +++ b/cinder/tests/unit/test_ssh_utils.py @@ -347,3 +347,72 @@ class SSHPoolTestCase(test.TestCase): self.assertRaises(paramiko.SSHException, sshpool.get) self.assertEqual(0, sshpool.current_size) + + @mock.patch('six.moves.builtins.open') + @mock.patch('os.path.isfile', return_value=True) + @mock.patch('paramiko.RSAKey.from_private_key_file') + @mock.patch('paramiko.SSHClient') + def test_ssh_put(self, mock_sshclient, mock_pkey, mock_isfile, + mock_open): + self.override_config( + 'ssh_hosts_key_file', '/var/lib/cinder/ssh_known_hosts') + + fake_close = mock.MagicMock() + fake = FakeSSHClient() + fake.close = fake_close + mock_sshclient.return_value = fake + + sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, + "test", + password="test", + min_size=5, + max_size=5) + self.assertEqual(5, sshpool.current_size) + with sshpool.item(): + pass + self.assertEqual(5, sshpool.current_size) + sshpool.resize(4) + with sshpool.item(): + pass + self.assertEqual(4, sshpool.current_size) + fake_close.asssert_called_once_with(mock.call()) + fake_close.reset_mock() + sshpool.resize(3) + with sshpool.item(): + pass + self.assertEqual(3, sshpool.current_size) + fake_close.asssert_called_once_with(mock.call()) + + @mock.patch('six.moves.builtins.open') + @mock.patch('os.path.isfile', return_value=True) + @mock.patch('paramiko.RSAKey.from_private_key_file') + @mock.patch('paramiko.SSHClient') + def test_ssh_destructor(self, mock_sshclient, mock_pkey, mock_isfile, + mock_open): + self.override_config( + 'ssh_hosts_key_file', '/var/lib/cinder/ssh_known_hosts') + + fake_close = mock.MagicMock() + fake = FakeSSHClient() + fake.close = fake_close + mock_sshclient.return_value = fake + + # create with password + sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, + "test", + password="test", + min_size=5, + max_size=5) + self.assertEqual(5, sshpool.current_size) + close_expect_calls = [mock.call(), mock.call(), mock.call(), + mock.call(), mock.call()] + + sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, + "test", + password="test", + min_size=5, + max_size=5) + self.assertEqual(fake_close.mock_calls, close_expect_calls) + sshpool = None + self.assertEqual(fake_close.mock_calls, close_expect_calls + + close_expect_calls)