diff --git a/satori/bash.py b/satori/bash.py index 5744b7f..0f0dd48 100644 --- a/satori/bash.py +++ b/satori/bash.py @@ -222,6 +222,19 @@ class RemoteShell(ShellMixin): """Return distro, version, architecture.""" return self._client.platform_info + def __del__(self): + """Destructor which should close the connection.""" + self.close() + + def __enter__(self): + """Context manager establish connection.""" + self.connect() + return self + + def __exit__(self, *exc_info): + """Context manager close connection.""" + self.close() + def connect(self): """Connect to the remote host.""" return self._client.connect() diff --git a/satori/ssh.py b/satori/ssh.py index 5a468e4..75d286b 100644 --- a/satori/ssh.py +++ b/satori/ssh.py @@ -148,6 +148,10 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 super(SSH, self).__init__() + def __del__(self): + """Destructor to close the connection.""" + self.close() + @classmethod def get_client(cls, *args, **kwargs): """Return an ssh client object from this module.""" @@ -369,7 +373,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 return False def remote_execute(self, command, with_exit_code=False, - get_pty=False, cwd=None, **kwargs): + get_pty=False, cwd=None, keepalive=True, **kwargs): """Execute an ssh command on a remote host. Tries cert auth first and falls back @@ -425,7 +429,8 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 if with_exit_code: results.update({'exit_code': exit_code}) - chan.close() + if not keepalive: + chan.close() if self._handle_tty_required(results, get_pty): return self.remote_execute( @@ -438,7 +443,8 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 self.port, exc) raise finally: - self.close() + if not keepalive: + self.close() # Share SSH.__init__'s docstring diff --git a/satori/sysinfo/ohai_solo.py b/satori/sysinfo/ohai_solo.py index 5dc62a4..353cc29 100644 --- a/satori/sysinfo/ohai_solo.py +++ b/satori/sysinfo/ohai_solo.py @@ -39,14 +39,16 @@ def get_systeminfo(ipaddress, config, interactive=False): client = bash.LocalShell() client.host = "localhost" client.port = 0 + perform_install(client) + return system_info(client) else: - client = bash.RemoteShell(ipaddress, username=config['host_username'], - private_key=config['host_key'], - interactive=interactive) - - install_remote(client) - return system_info(client) + with bash.RemoteShell( + ipaddress, username=config['host_username'], + private_key=config['host_key'], + interactive=interactive) as client: + perform_install(client) + return system_info(client) def system_info(client): @@ -88,7 +90,7 @@ def system_info(client): return results -def install_remote(client): +def perform_install(client): """Install ohai-solo on remote system.""" LOG.info("Installing (or updating) ohai-solo on device %s at %s:%d", client.host, client.host, client.port) diff --git a/satori/sysinfo/posh_ohai.py b/satori/sysinfo/posh_ohai.py index 605414e..3a6ee8e 100644 --- a/satori/sysinfo/posh_ohai.py +++ b/satori/sysinfo/posh_ohai.py @@ -39,14 +39,16 @@ def get_systeminfo(ipaddress, config, interactive=False): client = bash.LocalShell() client.host = "localhost" client.port = 0 + perform_install(client) + return system_info(client) else: - client = bash.RemoteShell(ipaddress, username=config['host_username'], - private_key=config['host_key'], - interactive=interactive) - - install_remote(client) - return system_info(client) + with bash.RemoteShell( + ipaddress, username=config['host_username'], + private_key=config['host_key'], + interactive=interactive) as client: + perform_install(client) + return system_info(client) def system_info(client): @@ -81,7 +83,7 @@ def system_info(client): "Target platform was %s", client.platform_info['dist']) -def install_remote(client): +def perform_install(client): """Install PoSh-Ohai on remote system.""" LOG.info("Installing (or updating) PoSh-Ohai on device %s at %s:%d", client.host, client.host, client.port) diff --git a/satori/tests/test_bash.py b/satori/tests/test_bash.py index 5bacb8c..1f41a40 100644 --- a/satori/tests/test_bash.py +++ b/satori/tests/test_bash.py @@ -138,6 +138,25 @@ class TestRemoteShell(TestBashModule): self.assertEqual(self.resultdict, resultdict) +class TestContextManager(utils.TestCase): + + def setUp(self): + super(TestContextManager, self).setUp() + connect_patcher = mock.patch.object(bash.RemoteShell, 'connect') + close_patcher = mock.patch.object(bash.RemoteShell, 'close') + self.mock_connect = connect_patcher.start() + self.mock_close = close_patcher.start() + self.addCleanup(connect_patcher.stop) + self.addCleanup(close_patcher.stop) + + def test_context_manager(self): + with bash.RemoteShell('192.168.2.10') as client: + pass + self.assertTrue(self.mock_connect.call_count == 1) + # >=1 because __del__ (in most python implementations) + # calls close() + self.assertTrue(self.mock_close.call_count >= 1) + class TestIsDistro(TestRemoteShell): def setUp(self): diff --git a/satori/tests/test_sysinfo_ohai_solo.py b/satori/tests/test_sysinfo_ohai_solo.py index fa7884e..50e1e74 100644 --- a/satori/tests/test_sysinfo_ohai_solo.py +++ b/satori/tests/test_sysinfo_ohai_solo.py @@ -25,7 +25,7 @@ class TestOhaiSolo(utils.TestCase): @mock.patch.object(ohai_solo, 'bash') @mock.patch.object(ohai_solo, 'system_info') - @mock.patch.object(ohai_solo, 'install_remote') + @mock.patch.object(ohai_solo, 'perform_install') def test_connect_and_run(self, mock_install, mock_sysinfo, mock_bash): address = "192.0.2.2" config = { @@ -37,13 +37,14 @@ class TestOhaiSolo(utils.TestCase): self.assertTrue(result is mock_sysinfo.return_value) mock_install.assert_called_once_with( - mock_bash.RemoteShell.return_value) + mock_bash.RemoteShell().__enter__.return_value) - mock_bash.RemoteShell.assert_called_with( + mock_bash.RemoteShell.assert_any_call( address, username="bar", private_key="foo", interactive=False) - mock_sysinfo.assert_called_with(mock_bash.RemoteShell.return_value) + mock_sysinfo.assert_called_with( + mock_bash.RemoteShell().__enter__.return_value) class TestOhaiInstall(utils.TestCase): @@ -53,10 +54,10 @@ class TestOhaiInstall(utils.TestCase): self.mock_remotesshclient = mock.MagicMock() self.mock_remotesshclient.is_windows.return_value = False - def test_install_remote_fedora(self): + def test_perform_install_fedora(self): response = {'exit_code': 0, 'foo': 'bar'} self.mock_remotesshclient.execute.return_value = response - result = ohai_solo.install_remote(self.mock_remotesshclient) + result = ohai_solo.perform_install(self.mock_remotesshclient) self.assertEqual(result, response) self.assertEqual(self.mock_remotesshclient.execute.call_count, 3) self.mock_remotesshclient.execute.assert_has_calls([ @@ -68,7 +69,7 @@ class TestOhaiInstall(utils.TestCase): response = {'exit_code': 1, 'stdout': "", "stderr": "FAIL"} self.mock_remotesshclient.execute.return_value = response self.assertRaises(errors.SystemInfoCommandInstallFailed, - ohai_solo.install_remote, self.mock_remotesshclient) + ohai_solo.perform_install, self.mock_remotesshclient) class TestOhaiRemove(utils.TestCase):