diff --git a/cinder/tests/test_utils.py b/cinder/tests/test_utils.py index 64d9140365d..5ba73cfaec6 100644 --- a/cinder/tests/test_utils.py +++ b/cinder/tests/test_utils.py @@ -767,7 +767,13 @@ class FakeSSHClient(object): self.transport = FakeTransport() def set_missing_host_key_policy(self, policy): - pass + self.policy = policy + + def load_system_host_keys(self): + self.system_host_keys = 'system_host_keys' + + def load_host_keys(self, hosts_key_file): + self.hosts_key_file = hosts_key_file def connect(self, ip, port=22, username=None, password=None, pkey=None, timeout=10): @@ -776,6 +782,9 @@ class FakeSSHClient(object): def get_transport(self): return self.transport + def get_policy(self): + return self.policy + def close(self): pass @@ -803,6 +812,33 @@ class FakeTransport(object): class SSHPoolTestCase(test.TestCase): """Unit test for SSH Connection Pool.""" + @mock.patch('paramiko.SSHClient') + def test_ssh_key_policy(self, mock_sshclient): + mock_sshclient.return_value = FakeSSHClient() + + # create with customized setting + sshpool = utils.SSHPool("127.0.0.1", 22, 10, + "test", + password="test", + min_size=1, + max_size=1, + missing_key_policy=paramiko.RejectPolicy(), + hosts_key_file='dummy_host_keyfile') + with sshpool.item() as ssh: + self.assertTrue(isinstance(ssh.get_policy(), + paramiko.RejectPolicy)) + self.assertEqual(ssh.hosts_key_file, 'dummy_host_keyfile') + + # create with default setting + sshpool = utils.SSHPool("127.0.0.1", 22, 10, + "test", + password="test", + min_size=1, + max_size=1) + with sshpool.item() as ssh: + self.assertTrue(isinstance(ssh.get_policy(), + paramiko.AutoAddPolicy)) + self.assertEqual(ssh.system_host_keys, 'system_host_keys') @mock.patch('paramiko.RSAKey.from_private_key_file') @mock.patch('paramiko.SSHClient') diff --git a/cinder/utils.py b/cinder/utils.py index a0a3e086f7e..3736163306e 100644 --- a/cinder/utils.py +++ b/cinder/utils.py @@ -189,12 +189,24 @@ class SSHPool(pools.Pool): self.password = password self.conn_timeout = conn_timeout if conn_timeout else None self.privatekey = privatekey + if 'missing_key_policy' in kwargs.keys(): + self.missing_key_policy = kwargs.pop('missing_key_policy') + else: + self.missing_key_policy = paramiko.AutoAddPolicy() + if 'hosts_key_file' in kwargs.keys(): + self.hosts_key_file = kwargs.pop('hosts_key_file') + else: + self.hosts_key_file = None super(SSHPool, self).__init__(*args, **kwargs) def create(self): try: ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.set_missing_host_key_policy(self.missing_key_policy) + if not self.hosts_key_file: + ssh.load_system_host_keys() + else: + ssh.load_host_keys(self.hosts_key_file) if self.password: ssh.connect(self.ip, port=self.port,