diff --git a/manila/exception.py b/manila/exception.py index 936541b493..eeeff6aca4 100644 --- a/manila/exception.py +++ b/manila/exception.py @@ -477,3 +477,7 @@ class GaneshaCommandFailure(ProcessExecutionError): class InvalidSqliteDB(Invalid): message = _("Invalid Sqlite database.") + + +class SSHException(ManilaException): + message = _("Exception in SSH protocol negotiation or logic.") diff --git a/manila/tests/test_utils.py b/manila/tests/test_utils.py index b53c654442..452a16f27c 100644 --- a/manila/tests/test_utils.py +++ b/manila/tests/test_utils.py @@ -415,7 +415,7 @@ class FakeSSHClient(object): pass def connect(self, ip, port=22, username=None, password=None, - pkey=None, timeout=10): + key_filename=None, look_for_keys=None, timeout=10): pass def get_transport(self): @@ -473,28 +473,44 @@ class SSHPoolTestCase(test.TestCase): fake_ssh_client.connect.assert_called_once_with( "127.0.0.1", port=22, username="test", - password="test", pkey=None, timeout=10) + password="test", key_filename=None, look_for_keys=False, + timeout=10) def test_create_ssh_with_key(self): - key = os.path.expanduser("fake_key") + path_to_private_key = "/fakepath/to/privatekey" fake_ssh_client = mock.Mock() ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test", - privatekey="fake_key") + privatekey="/fakepath/to/privatekey") with mock.patch.object(paramiko, "SSHClient", return_value=fake_ssh_client): - with mock.patch.object(paramiko.RSAKey, "from_private_key_file", - return_value=key) as from_private_key_mock: - - ssh_pool.create() - from_private_key_mock.assert_called_once_with(key) - fake_ssh_client.connect.assert_called_once_with( - "127.0.0.1", port=22, username="test", - password=None, pkey=key, timeout=10) + ssh_pool.create() + fake_ssh_client.connect.assert_called_once_with( + "127.0.0.1", port=22, username="test", password=None, + key_filename=path_to_private_key, look_for_keys=False, + timeout=10) def test_create_ssh_with_nothing(self): + fake_ssh_client = mock.Mock() ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test") - with mock.patch.object(paramiko, "SSHClient"): - self.assertRaises(paramiko.SSHException, ssh_pool.create) + with mock.patch.object(paramiko, "SSHClient", + return_value=fake_ssh_client): + ssh_pool.create() + fake_ssh_client.connect.assert_called_once_with( + "127.0.0.1", port=22, username="test", password=None, + key_filename=None, look_for_keys=True, + timeout=10) + + def test_create_ssh_error_connecting(self): + attrs = {'connect.side_effect': paramiko.SSHException, } + fake_ssh_client = mock.Mock(**attrs) + ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test") + with mock.patch.object(paramiko, "SSHClient", + return_value=fake_ssh_client): + self.assertRaises(exception.SSHException, ssh_pool.create) + fake_ssh_client.connect.assert_called_once_with( + "127.0.0.1", port=22, username="test", password=None, + key_filename=None, look_for_keys=True, + timeout=10) def test_closed_reopend_ssh_connections(self): with mock.patch.object(paramiko, "SSHClient", diff --git a/manila/utils.py b/manila/utils.py index 3b8070f9ef..5c2f2fb698 100644 --- a/manila/utils.py +++ b/manila/utils.py @@ -80,26 +80,27 @@ class SSHPool(pools.Pool): self.login = login self.password = password self.conn_timeout = conn_timeout if conn_timeout else None - self.privatekey = privatekey + self.path_to_private_key = privatekey super(SSHPool, self).__init__(*args, **kwargs) def create(self): + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + look_for_keys = True + if self.path_to_private_key: + self.path_to_private_key = os.path.expanduser( + self.path_to_private_key) + look_for_keys = False + elif self.password: + look_for_keys = False try: - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - if self.privatekey: - pkfile = os.path.expanduser(self.privatekey) - self.privatekey = paramiko.RSAKey.from_private_key_file(pkfile) - elif not self.password: - msg = _("Specify a password or private_key") - raise exception.ManilaException(msg) ssh.connect(self.ip, port=self.port, username=self.login, password=self.password, - pkey=self.privatekey, + key_filename=self.path_to_private_key, + look_for_keys=look_for_keys, timeout=self.conn_timeout) - # Paramiko by default sets the socket timeout to 0.1 seconds, # ignoring what we set thru the sshclient. This doesn't help for # keeping long lived connections. Hence we have to bypass it, by @@ -113,9 +114,10 @@ class SSHPool(pools.Pool): transport.set_keepalive(self.conn_timeout) return ssh except Exception as e: - msg = _("Error connecting via ssh: %s") % e + msg = _("Check whether private key or password are correctly " + "set. Error connecting via ssh: %s") % e LOG.error(msg) - raise paramiko.SSHException(msg) + raise exception.SSHException(msg) def get(self): """Return an item from the pool, when one is available.