diff --git a/os_xenapi/tests/utils/test_common_function.py b/os_xenapi/tests/utils/test_common_function.py index 11726dd..f1f5378 100644 --- a/os_xenapi/tests/utils/test_common_function.py +++ b/os_xenapi/tests/utils/test_common_function.py @@ -18,12 +18,15 @@ from os_xenapi.tests import base from os_xenapi.utils import common_function +class ScpCmdFailure(Exception): + msg_fmt = ("scp failure") + + class CommonUtilFuncTestCase(base.TestCase): def test_get_remote_hostname(self): mock_client = mock.Mock() out = ' \nFake_host_name\n ' - err = '' - mock_client.ssh.return_value = (out, err) + mock_client.ssh.return_value = (0, out, '') hostname = common_function.get_remote_hostname(mock_client) @@ -34,8 +37,7 @@ class CommonUtilFuncTestCase(base.TestCase): mock_client = mock.Mock() out = u'xenbr0 10.71.64.118/20\n' out += 'xenapi 169.254.0.1/16\n' - err = '' - mock_client.ssh.return_value = (out, err) + mock_client.ssh.return_value = (0, out, '') ipv4s = common_function.get_host_ipv4s(mock_client) @@ -153,3 +155,48 @@ class CommonUtilFuncTestCase(base.TestCase): format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s') mock_mkdir.assert_called_once_with('fake_folder/') mock_exists.assert_called_once_with('fake_folder/') + + @mock.patch.object(os.path, 'dirname') + def test_scp_and_execute(self, mock_dirname): + fake_util_dir = 'fake_util_dir' + fake_tmp_sh_dir = 'fake_sh_dir' + fake_script_name = 'fake_script_name' + mock_dirname.return_value = fake_util_dir + mock_client = mock.Mock() + mock_client.ssh.return_value = (0, fake_tmp_sh_dir, '') + fake_tmp_sh_path = fake_tmp_sh_dir + '/' + fake_script_name + fake_sh_shell_dir = fake_util_dir + '/sh_tools/' + + expect_ssh_calls = [mock.call("mktemp -d /tmp/domu_sh.XXXXXX"), + mock.call("mkdir -p " + fake_tmp_sh_dir), + mock.call("chmod +x " + fake_tmp_sh_path), + mock.call(fake_tmp_sh_path), + mock.call("rm -rf " + fake_tmp_sh_dir)] + + common_function.scp_and_execute(mock_client, fake_script_name) + mock_client.ssh.assert_has_calls(expect_ssh_calls) + mock_client.scp.assert_called_once_with( + fake_sh_shell_dir + fake_script_name, fake_tmp_sh_path) + + @mock.patch.object(os.path, 'dirname') + def test_scp_and_execute_exception(self, mock_dirname): + fake_util_dir = 'fake_util_dir' + fake_tmp_sh_dir = 'fake_sh_dir' + fake_script_name = 'fake_script_name' + mock_dirname.return_value = fake_util_dir + mock_client = mock.Mock() + mock_client.ssh.return_value = (0, fake_tmp_sh_dir, '') + fake_tmp_sh_path = fake_tmp_sh_dir + '/' + fake_script_name + fake_sh_shell_dir = fake_util_dir + '/sh_tools/' + mock_client.scp.side_effect = [ScpCmdFailure()] + + expect_ssh_calls = [mock.call("mktemp -d /tmp/domu_sh.XXXXXX"), + mock.call("mkdir -p " + fake_tmp_sh_dir), + mock.call("rm -rf " + fake_tmp_sh_dir)] + + self.assertRaises(ScpCmdFailure, + common_function.scp_and_execute, + mock_client, fake_script_name) + mock_client.ssh.assert_has_calls(expect_ssh_calls) + mock_client.scp.assert_called_once_with( + fake_sh_shell_dir + fake_script_name, fake_tmp_sh_path) diff --git a/os_xenapi/tests/utils/test_conntrack_service.py b/os_xenapi/tests/utils/test_conntrack_service.py index 1f2c532..c705cb5 100644 --- a/os_xenapi/tests/utils/test_conntrack_service.py +++ b/os_xenapi/tests/utils/test_conntrack_service.py @@ -20,7 +20,7 @@ class XenapiConntrackServiceTestCase(base.TestCase): @mock.patch.object(os.path, 'dirname') def test_ensure_conntrack_packages(self, mock_dirname): client = mock.Mock() - client.ssh.return_value = '/tmp/domu_sh.fake' + client.ssh.return_value = (0, '/tmp/domu_sh.fake', '') mock_dirname.return_value = '/fake_dir' ssh_expect_call = [mock.call("mkdir -p /tmp/domu_sh.fake"), mock.call("chmod +x /tmp/domu_sh.fake/" @@ -39,7 +39,7 @@ class XenapiConntrackServiceTestCase(base.TestCase): def test_enable_conntrack_service(self, mock_ensure_conntrack, mock_dir_name): client = mock.Mock() - client.ssh.return_value = '/tmp/domu_sh.fake' + client.ssh.return_value = (0, '/tmp/domu_sh.fake', '') mock_dir_name.return_value = '/fake_dir' ssh_expect_call = [mock.call("mkdir -p /tmp/domu_sh.fake"), mock.call("chmod +x /tmp/domu_sh.fake/" diff --git a/os_xenapi/tests/utils/test_sshclient.py b/os_xenapi/tests/utils/test_sshclient.py index 522a4a4..0ec2f4c 100644 --- a/os_xenapi/tests/utils/test_sshclient.py +++ b/os_xenapi/tests/utils/test_sshclient.py @@ -52,7 +52,7 @@ class SshClientTestCase(base.TestCase): client = sshclient.SSHClient('ip', 'username', password='password', log=mock_log) - return_code, out, err = client.ssh('fake_command', output=True) + return_code, out, err = client.ssh('fake_command') mock_log.debug.assert_called() mock_exec.assert_called() @@ -75,8 +75,9 @@ class SshClientTestCase(base.TestCase): client = sshclient.SSHClient('ip', 'username', password='password', log=mock_log) - self.assertRaises(sshclient.SshExecCmdFailure, client.ssh, - 'fake_command', output=True) + self.assertRaises(sshclient.SshExecCmdFailure, + client.ssh, + 'fake_command') @mock.patch.object(paramiko.SSHClient, 'set_missing_host_key_policy') @mock.patch.object(paramiko.SSHClient, 'connect') @@ -91,7 +92,7 @@ class SshClientTestCase(base.TestCase): client = sshclient.SSHClient('ip', 'username', password='password', log=mock_log) - return_code, out, err = client.ssh('fake_command', output=True, + return_code, out, err = client.ssh('fake_command', allowed_return_codes=[0, 1]) mock_exec.assert_called_once_with('fake_command', get_pty=True) mock_channel.recv_exit_status.assert_called_once() diff --git a/os_xenapi/utils/common_function.py b/os_xenapi/utils/common_function.py index 4dc0ce3..9d4129c 100644 --- a/os_xenapi/utils/common_function.py +++ b/os_xenapi/utils/common_function.py @@ -97,7 +97,7 @@ def get_eth_mac(eth): def get_remote_hostname(host_client): # Get remote host's hostname via the host_client connected to the host. - out, _ = host_client.ssh('hostname') + _, out, _ = host_client.ssh('hostname') hostname = out.strip() return hostname @@ -106,7 +106,7 @@ def get_host_ipv4s(host_client): # Get host's IPs (v4 only) via the host_client connected to the host. ipv4s = [] command = "ip -4 -o addr show scope global | awk '{print $2, $4}'" - out, _ = host_client.ssh(command) + _, out, _ = host_client.ssh(command) for line in out.split('\n'): line = line.strip() if line: @@ -181,7 +181,7 @@ def get_domu_vifs_by_eth(xenserver_client): def scp_and_execute(dom0_client, script_name): # copy script to remote host and execute it - TMP_SH_DIR = dom0_client.ssh("mktemp -d /tmp/domu_sh.XXXXXX", output=True) + _, TMP_SH_DIR, _ = dom0_client.ssh("mktemp -d /tmp/domu_sh.XXXXXX") TMP_SH_PATH = TMP_SH_DIR + '/' + script_name Util_DIR = os.path.dirname( os.path.abspath(inspect.getfile(inspect.currentframe()))) diff --git a/os_xenapi/utils/sshclient.py b/os_xenapi/utils/sshclient.py index 41ea83c..81dae12 100644 --- a/os_xenapi/utils/sshclient.py +++ b/os_xenapi/utils/sshclient.py @@ -45,8 +45,7 @@ class SSHClient(object): def __del__(self): self.client.close() - def ssh(self, command, get_pty=True, output=False, - allowed_return_codes=[0]): + def ssh(self, command, get_pty=True, allowed_return_codes=[0]): if self.log: self.log.debug("Executing command: [%s]" % command) stdin, stdout, stderr = self.client.exec_command(