diff --git a/rally/common/sshutils.py b/rally/common/sshutils.py index 0c4eccd296..1056059753 100644 --- a/rally/common/sshutils.py +++ b/rally/common/sshutils.py @@ -271,9 +271,8 @@ class SSH(object): client = self._get_client() - sftp = client.open_sftp() - sftp.put(localpath, remotepath) - if mode is None: - mode = 0o777 & os.stat(localpath).st_mode - sftp.chmod(remotepath, mode) - sftp.close() + with client.open_sftp() as sftp: + sftp.put(localpath, remotepath) + if mode is None: + mode = 0o777 & os.stat(localpath).st_mode + sftp.chmod(remotepath, mode) diff --git a/tests/unit/common/test_sshutils.py b/tests/unit/common/test_sshutils.py index 27e230af0b..9fa4b4f43c 100644 --- a/tests/unit/common/test_sshutils.py +++ b/tests/unit/common/test_sshutils.py @@ -274,7 +274,8 @@ class SSHRunTestCase(test.TestCase): @mock.patch("rally.common.sshutils.os.stat") def test_put_file(self, mock_stat): - sftp = self.fake_client.open_sftp.return_value = mock.Mock() + sftp = self.fake_client.open_sftp.return_value = mock.MagicMock() + sftp.__enter__.return_value = sftp mock_stat.return_value = os.stat_result([0o753] + [0] * 9) @@ -283,13 +284,14 @@ class SSHRunTestCase(test.TestCase): sftp.put.assert_called_once_with("localfile", "remotefile") mock_stat.assert_called_once_with("localfile") sftp.chmod.assert_called_once_with("remotefile", 0o753) - sftp.close.assert_called_once_with() + sftp.__exit__.assert_called_once_with(None, None, None) def test_put_file_mode(self): - sftp = self.fake_client.open_sftp.return_value = mock.Mock() + sftp = self.fake_client.open_sftp.return_value = mock.MagicMock() + sftp.__enter__.return_value = sftp self.ssh.put_file("localfile", "remotefile", mode=0o753) sftp.put.assert_called_once_with("localfile", "remotefile") sftp.chmod.assert_called_once_with("remotefile", 0o753) - sftp.close.assert_called_once_with() + sftp.__exit__.assert_called_once_with(None, None, None)