diff --git a/satori/bash.py b/satori/bash.py index ec3464c..9e531ad 100644 --- a/satori/bash.py +++ b/satori/bash.py @@ -15,15 +15,16 @@ """Shell classes for executing commands on a system. -Execute commands over ssh or using python subprocess module. +Execute commands over ssh or using the python subprocess module. """ import logging -import platform import shlex import subprocess +from satori import errors from satori import ssh +from satori import utils LOG = logging.getLogger(__name__) @@ -51,20 +52,42 @@ class ShellMixin(object): pass def is_debian(self): - """Return a boolean indicating whether the system is debian based. + """Indicate whether the system is Debian based. Uses the platform_info property. """ + if not self.platform_info['dist']: + raise errors.UndeterminedPlatform( + 'Unable to determine whether the system is Debian based.') return self.platform_info['dist'].lower() in ['debian', 'ubuntu'] def is_fedora(self): - """Return a boolean indicating whether the system in fedora based. + """Indicate whether the system in Fedora based. Uses the platform info property. """ + if not self.platform_info['dist']: + raise errors.UndeterminedPlatform( + 'Unable to determine whether the system is Fedora based.') return (self.platform_info['dist'].lower() in ['redhat', 'centos', 'fedora', 'el']) + def is_osx(self): + """Indicate whether the system is Apple OSX based.""" + if not self.platform_info['dist']: + raise errors.UndeterminedPlatform( + 'Unable to determine whether the system is OS X based.') + return (self.platform_info['dist'].lower() in + ['darwin', 'macosx']) + + def is_windows(self): + """Indicate whether the system is Windows based.""" + if not self.platform_info['dist']: + raise errors.UndeterminedPlatform( + 'Unable to determine whether the system is Windows based.') + + return self.platform_info['dist'].startswith('win') + class LocalShell(ShellMixin): @@ -84,13 +107,30 @@ class LocalShell(ShellMixin): self.interactive = interactive # TODO(samstav): Implement handle_password_prompt for popen + # properties + self._platform_info = None + @property def platform_info(self): """Return distro, version, and system architecture.""" - return list(platform.dist() + (platform.machine(),)) + if not self._platform_info: + self._platform_info = utils.get_platform_info() + return self._platform_info def execute(self, command, wd=None, with_exit_code=None): - """Execute a command (containing no shell operators) locally.""" + """Execute a command (containing no shell operators) locally. + + :param command: Shell command to be executed. + :param with_exit_code: Include the exit_code in the return body. + Default is False. + :param wd: The child's current directory will be changed + to `wd` before it is executed. Note that this + directory is not considered when searching the + executable, so you can't specify the program's + path relative to this argument + :returns: A dict with stdin, stdout, and + (optionally) the exit code. + """ spipe = subprocess.PIPE cmd = shlex.split(command) diff --git a/satori/errors.py b/satori/errors.py index 9b53b49..5cc0c56 100644 --- a/satori/errors.py +++ b/satori/errors.py @@ -21,6 +21,11 @@ class SatoriException(Exception): """ +class UndeterminedPlatform(SatoriException): + + """The target system's platform could not be determined.""" + + class SatoriInvalidNetloc(SatoriException): """Netloc that cannot be parsed by `urlparse`.""" diff --git a/satori/ssh.py b/satori/ssh.py index e8076d1..9e24148 100644 --- a/satori/ssh.py +++ b/satori/ssh.py @@ -38,6 +38,7 @@ import paramiko import six from satori import errors +from satori import utils LOG = logging.getLogger(__name__) MIN_PASSWORD_PROMPT_LEN = 8 @@ -155,18 +156,22 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 @property def platform_info(self): - """Return distro, version, architecture.""" - if not self._platform_info: - command = ('python -c ' - '"""import sys,platform as p;' - 'plat=list(p.dist()+(p.machine(),));' - 'sys.stdout.write(str(plat))"""') + """Return distro, version, architecture. + Requires >= Python 2.4 on remote system. + """ + if not self._platform_info: + + platform_command = "import platform,sys\n" + platform_command += utils.get_source_definition( + utils.get_platform_info) + platform_command += ("\nsys.stdout.write(str(" + "get_platform_info()))\n") + command = 'echo -e """%s""" | python' % platform_command output = self.remote_execute(command) stdout = re.split('\n|\r\n', output['stdout'])[-1].strip() plat = ast.literal_eval(stdout) - self._platform_info = {'dist': plat[0].lower(), 'version': plat[1], - 'arch': plat[3]} + self._platform_info = plat LOG.debug("Remote platform info: %s", self._platform_info) return self._platform_info diff --git a/satori/tests/test_bash.py b/satori/tests/test_bash.py new file mode 100644 index 0000000..2bbd0c8 --- /dev/null +++ b/satori/tests/test_bash.py @@ -0,0 +1,160 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# pylint: disable=C0111, C0103, W0212, R0904 +"""Satori SSH Module Tests.""" + +import collections +import unittest + +import mock + +from satori import bash +from satori import errors +from satori.tests import utils + + +class TestBashModule(utils.TestCase): + + def setUp(self): + super(TestBashModule, self).setUp() + testrun = collections.namedtuple( + "TestCmd", ["command", "stdout", "returncode"]) + self.testrun = testrun( + command="echo hello", stdout="hello\n", returncode=0) + self.resultdict = {'stdout': self.testrun.stdout.strip(), + 'stderr': ''} + + +class TestLocalShell(TestBashModule): + + def setUp(self): + super(TestLocalShell, self).setUp() + popen_patcher = mock.patch.object(bash.subprocess, 'Popen') + self.mock_popen = popen_patcher.start() + mock_result = mock.MagicMock() + mock_result.returncode = self.testrun.returncode + self.mock_popen.return_value = mock_result + mock_result.communicate.return_value = (self.testrun.stdout, '') + self.localshell = bash.LocalShell() + self.addCleanup(popen_patcher.stop) + + def test_execute(self): + self.localshell.execute(self.testrun.command) + self.mock_popen.assert_called_once_with( + self.testrun.command.split(), cwd=None, stderr=-1, stdout=-1) + + def test_execute_resultdict(self): + resultdict = self.localshell.execute(self.testrun.command) + self.assertEqual(self.resultdict, resultdict) + + def test_execute_with_exit_code_resultdict(self): + resultdict = self.localshell.execute( + self.testrun.command, with_exit_code=True) + self.resultdict.update({'exit_code': self.testrun.returncode}) + self.assertEqual(self.resultdict, resultdict) + + +class TestLocalPlatformInfo(TestLocalShell): + + def setUp(self): + super(TestLocalPlatformInfo, self).setUp() + + def test_local_platform_info(self): + self.assertTrue(all(k in self.localshell.platform_info + for k in ('dist', 'arch', 'version'))) + + def test_is_debian(self): + self.assertIsInstance(self.localshell.is_debian(), bool) + + def test_is_fedora(self): + self.assertIsInstance(self.localshell.is_fedora(), bool) + + def test_is_osx(self): + self.assertIsInstance(self.localshell.is_windows(), bool) + + def test_is_windows(self): + self.assertIsInstance(self.localshell.is_osx(), bool) + + +class TestLocalPlatformInfoUndetermined(TestLocalShell): + + def setUp(self): + blanks = {'dist': '', 'arch': '', 'version': ''} + pinfo_patcher = mock.patch.object( + bash.LocalShell, 'platform_info', new_callable=mock.PropertyMock) + self.mock_platform_info = pinfo_patcher.start() + self.mock_platform_info.return_value = blanks + super(TestLocalPlatformInfoUndetermined, self).setUp() + self.addCleanup(pinfo_patcher.stop) + + def test_is_debian(self): + self.assertRaises(errors.UndeterminedPlatform, + self.localshell.is_debian) + + def test_is_fedora(self): + self.assertRaises(errors.UndeterminedPlatform, + self.localshell.is_fedora) + + def test_is_osx(self): + self.assertRaises(errors.UndeterminedPlatform, + self.localshell.is_osx) + + def test_is_windows(self): + self.assertRaises(errors.UndeterminedPlatform, + self.localshell.is_windows) + + +class TestRemoteShell(TestBashModule): + + def setUp(self): + super(TestRemoteShell, self).setUp() + execute_patcher = mock.patch.object(bash.ssh.SSH, 'remote_execute') + self.mock_execute = execute_patcher.start() + self.mock_execute.return_value = self.resultdict + self.remoteshell = bash.RemoteShell('203.0.113.1') + self.addCleanup(execute_patcher.stop) + + def test_execute(self): + self.remoteshell.execute(self.testrun.command) + self.mock_execute.assert_called_once_with( + self.testrun.command, wd=None, with_exit_code=None) + + def test_execute_resultdict(self): + resultdict = self.remoteshell.execute(self.testrun.command) + self.assertEqual(self.resultdict, resultdict) + + def test_execute_with_exit_code_resultdict(self): + resultdict = self.remoteshell.execute( + self.testrun.command, with_exit_code=True) + self.resultdict.update({'exit_code': self.testrun.returncode}) + self.assertEqual(self.resultdict, resultdict) + + +class TestIsDistro(TestRemoteShell): + + def setUp(self): + super(TestIsDistro, self).setUp() + self.platformdict = self.resultdict.copy() + self.platformdict['stdout'] = str(bash.LocalShell().platform_info) + + def test_remote_platform_info(self): + self.mock_execute.return_value = self.platformdict + result = self.remoteshell.platform_info + self.assertIsInstance(result, dict) + self.assertTrue(all(k in result + for k in ('arch', 'dist', 'version'))) + assert self.mock_execute.called + + +if __name__ == "__main__": + unittest.main() diff --git a/satori/tests/test_ssh.py b/satori/tests/test_ssh.py index 9ba9226..903f8ac 100644 --- a/satori/tests/test_ssh.py +++ b/satori/tests/test_ssh.py @@ -681,8 +681,7 @@ class TestRemoteExecute(SSHTestBase): expected_result = dict(zip(fields, [v.lower() for v in platinfo])) expected_result.pop('remove') self.mock_chan.makefile.side_effect = lambda x: self.mkfile( - x, stdoutput=str(platinfo)) - self.assertEqual(expected_result, self.client.platform_info) + x, stdoutput=str(expected_result)) self.assertEqual(expected_result, self.client.platform_info) diff --git a/satori/tests/test_utils.py b/satori/tests/test_utils.py index d6c384f..cd3465a 100644 --- a/satori/tests/test_utils.py +++ b/satori/tests/test_utils.py @@ -26,13 +26,13 @@ class SomeTZ(datetime.tzinfo): """A random timezone.""" def utcoffset(self, dt): - return datetime.timedelta(minutes=45) + return datetime.timedelta(minutes=45) def tzname(self, dt): - return "STZ" + return "STZ" def dst(self, dt): - return datetime.timedelta(0) + return datetime.timedelta(0) class TestTimeUtils(unittest.TestCase): @@ -69,5 +69,63 @@ class TestTimeUtils(unittest.TestCase): self.assertEqual(result, datetime.datetime(1970, 2, 1, 11, 2, 3, 0)) +class TestGetSource(unittest.TestCase): + + def setUp(self): + self.function_signature = "def get_my_source_oneline_docstring(self):" + self.function_oneline_docstring = '"""A beautiful docstring."""' + self.function_multiline_docstring = ('"""A beautiful docstring.\n\n' + 'Is a terrible thing to ' + 'waste.\n"""') + self.function_body = ['the_problem = "not the problem"', + 'return the_problem'] + + def get_my_source_oneline_docstring(self): + """A beautiful docstring.""" + the_problem = "not the problem" + return the_problem + + def get_my_source_multiline_docstring(self): + """A beautiful docstring. + + Is a terrible thing to waste. + """ + the_problem = "not the problem" + return the_problem + + def test_get_source(self): + nab = utils.get_source_body(self.get_my_source_oneline_docstring) + self.assertEqual("\n".join(self.function_body), nab) + + def test_get_source_with_docstring(self): + nab = utils.get_source_body(self.get_my_source_oneline_docstring, + with_docstring=True) + copy = self.function_oneline_docstring + "\n" + "\n".join( + self.function_body) + self.assertEqual(copy, nab) + + def test_get_source_with_multiline_docstring(self): + nab = utils.get_source_body(self.get_my_source_multiline_docstring, + with_docstring=True) + copy = (self.function_multiline_docstring + "\n" + "\n".join( + self.function_body)) + self.assertEqual(copy, nab) + + def test_get_definition(self): + nab = utils.get_source_definition( + self.get_my_source_oneline_docstring) + copy = "%s\n \n %s" % (self.function_signature, + "\n ".join(self.function_body)) + self.assertEqual(copy, nab) + + def test_get_definition_with_docstring(self): + nab = utils.get_source_definition( + self.get_my_source_oneline_docstring, with_docstring=True) + copy = "%s\n %s\n %s" % (self.function_signature, + self.function_oneline_docstring, + "\n ".join(self.function_body)) + self.assertEqual(copy, nab) + + if __name__ == '__main__': unittest.main() diff --git a/satori/utils.py b/satori/utils.py index e380b13..1cb8d74 100644 --- a/satori/utils.py +++ b/satori/utils.py @@ -17,7 +17,9 @@ """ import datetime +import inspect import logging +import platform import socket import sys import time @@ -135,3 +137,85 @@ def get_local_ips(): LOG.debug("Error in getaddrinfo: %s", exc) return list(set(list1 + list2 + defaults)) + + +def get_platform_info(): + """Return a dictionary with distro, version, and system architecture. + + Requires >= Python 2.4 (2004) + + Supports most Linux distros, Mac OSX, and Windows. + + Example return value on Mac OSX: + + {'arch': '64bit', 'version': '10.8.5', 'dist': 'darwin'} + + """ + pin = list(platform.dist() + (platform.machine(),)) + pinfodict = {'dist': pin[0], 'version': pin[1], 'arch': pin[3]} + if not pinfodict['dist'] or not pinfodict['version']: + pinfodict['dist'] = sys.platform.lower() + pinfodict['arch'] = platform.architecture()[0] + if 'darwin' in pinfodict['dist']: + pinfodict['version'] = platform.mac_ver()[0] + elif pinfodict['dist'].startswith('win'): + pinfodict['version'] = str(platform.platform()) + + return pinfodict + + +def get_source_definition(function, with_docstring=False): + """Get the entire body of a function, including the signature line. + + :param with_docstring: Include docstring in return value. + Default is False. Supports docstrings in + triple double-quotes or triple single-quotes. + """ + thedoc = inspect.getdoc(function) + definition = inspect.cleandoc( + inspect.getsource(function)) + if thedoc and not with_docstring: + definition = definition.replace(thedoc, '') + doublequotes = definition.find('"""') + doublequotes = float("inf") if doublequotes == -1 else doublequotes + singlequotes = definition.find("'''") + singlequotes = float("inf") if singlequotes == -1 else singlequotes + if doublequotes != singlequotes: + triplet = '"""' if doublequotes < singlequotes else "'''" + definition = definition.replace(triplet, '', 2) + while definition.find('\n\n\n') != -1: + definition = definition.replace('\n\n\n', '\n\n') + + definition_copy = [] + for line in definition.split('\n'): + # pylint: disable=W0141 + if not any(map(line.strip().startswith, ("@", "def"))): + line = " "*4 + line + definition_copy.append(line) + + return "\n".join(definition_copy).strip() + + +def get_source_body(function, with_docstring=False): + """Get the body of a function (i.e. no definition line, unindented). + + :param with_docstring: Include docstring in return value. + Default is False. + """ + lines = get_source_definition( + function, with_docstring=with_docstring).split('\n') + + # Find body - skip decorators and definition + start = 0 + for number, line in enumerate(lines): + # pylint: disable=W0141 + if any(map(line.strip().startswith, ("@", "def"))): + start = number + 1 + + lines = lines[start:] + + # Unindent body + indent = len(lines[0]) - len(lines[0].lstrip()) + for index, line in enumerate(lines): + lines[index] = line[indent:] + return '\n'.join(lines).strip()