diff --git a/octane/tests/test_util_subprocess.py b/octane/tests/test_util_subprocess.py new file mode 100644 index 00000000..f313f2be --- /dev/null +++ b/octane/tests/test_util_subprocess.py @@ -0,0 +1,68 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock +import pytest + +from octane.util import subprocess + + +class _TestException(Exception): + pass + + +@pytest.mark.parametrize(("exception", "reraise", "calls"), [ + (None, False, [ + mock.call.stat("/fake/filename"), + mock.call.chmod("/temp/filename", 0o640), + mock.call.chown("/temp/filename", 2, 3), + mock.call.rename("/fake/filename", "/fake/filename.bak"), + mock.call.rename("/temp/filename", "/fake/filename"), + mock.call.unlink("/fake/filename.bak"), + ]), + (subprocess.DontUpdateException, False, [ + mock.call.unlink("/temp/filename"), + ]), + (_TestException, True, [ + mock.call.unlink("/temp/filename"), + ]), +]) +def test_update_file(mocker, mock_open, exception, reraise, calls): + mock_tempfile = mocker.patch("octane.util.tempfile.get_tempname") + mock_tempfile.return_value = "/temp/filename" + + mock_old = mock.MagicMock() + mock_new = mock.MagicMock() + + mock_open.side_effect = [mock_old, mock_new] + + mock_os = mock.Mock() + os_methods = ["unlink", "stat", "chmod", "chown", "rename"] + for method in os_methods: + mocker.patch("os." + method, new=getattr(mock_os, method)) + + mock_os.stat.return_value.configure_mock( + st_mode=0o640, + st_uid=2, + st_gid=3, + ) + + if reraise: + with pytest.raises(exception): + with subprocess.update_file("/fake/filename"): + raise exception + else: + with subprocess.update_file("/fake/filename"): + if exception is not None: + raise exception + + assert mock_os.mock_calls == calls diff --git a/octane/util/maintenance.py b/octane/util/maintenance.py index 07a7d6f5..5e306464 100644 --- a/octane/util/maintenance.py +++ b/octane/util/maintenance.py @@ -44,7 +44,7 @@ def disable_apis(env): with ssh.update_file(sftp, f) as (old, new): contents = old.read() if not mode_tcp_re.search(contents): - raise ssh.DontUpdateException + raise subprocess.DontUpdateException new.write(contents) if not contents.endswith('\n'): new.write('\n') diff --git a/octane/util/ssh.py b/octane/util/ssh.py index 9ff5fc38..0851195c 100644 --- a/octane/util/ssh.py +++ b/octane/util/ssh.py @@ -191,10 +191,6 @@ def sftp(node): return _get_sftp(node) -class DontUpdateException(Exception): - pass - - @contextlib.contextmanager def update_file(sftp, filename): old = sftp.open(filename, 'r') @@ -209,7 +205,7 @@ def update_file(sftp, filename): with contextlib.nested(old, new): try: yield old, new - except DontUpdateException: + except subprocess.DontUpdateException: sftp.unlink(temp_filename) return except Exception: diff --git a/octane/util/subprocess.py b/octane/util/subprocess.py index 425f9e61..3e5efa54 100644 --- a/octane/util/subprocess.py +++ b/octane/util/subprocess.py @@ -22,6 +22,8 @@ import re import subprocess import threading +from octane.util import tempfile + LOG = logging.getLogger(__name__) PIPE = subprocess.PIPE CalledProcessError = subprocess.CalledProcessError @@ -202,3 +204,32 @@ def call(cmd, **kwargs): def call_output(cmd, **kwargs): return call(cmd, stdout=PIPE, **kwargs)[0] + + +class DontUpdateException(Exception): + pass + + +@contextlib.contextmanager +def update_file(filename): + old = open(filename, 'r') + dirname = os.path.dirname(filename) + prefix = ".{0}.".format(os.path.basename(filename)) + temp_filename = tempfile.get_tempname(dir=dirname, prefix=prefix) + new = open(temp_filename, 'w') + with contextlib.nested(old, new): + try: + yield old, new + except DontUpdateException: + os.unlink(temp_filename) + return + except Exception: + os.unlink(temp_filename) + raise + stat = os.stat(filename) + os.chmod(temp_filename, stat.st_mode) + os.chown(temp_filename, stat.st_uid, stat.st_gid) + bak_filename = filename + '.bak' + os.rename(filename, bak_filename) + os.rename(temp_filename, filename) + os.unlink(bak_filename)