diff --git a/os_win/tests/functional/test_mutex.py b/os_win/tests/functional/test_mutex.py new file mode 100644 index 00000000..019281b9 --- /dev/null +++ b/os_win/tests/functional/test_mutex.py @@ -0,0 +1,75 @@ +# Copyright 2019 Cloudbase Solutions Srl +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import threading +import uuid + +from os_win import exceptions +from os_win.tests.functional import test_base +from os_win.utils import processutils + + +class MutexTestCase(test_base.OsWinBaseFunctionalTestCase): + def setUp(self): + super(MutexTestCase, self).setUp() + + mutex_name = str(uuid.uuid4()) + self._mutex = processutils.Mutex(name=mutex_name) + + self.addCleanup(self._mutex.close) + + def acquire_mutex_in_separate_thread(self, mutex): + # We'll wait for a signal before releasing the mutex. + stop_event = threading.Event() + + def target(): + mutex.acquire() + + stop_event.wait() + + mutex.release() + + thread = threading.Thread(target=target) + thread.daemon = True + thread.start() + + return thread, stop_event + + def test_already_acquired_mutex(self): + thread, stop_event = self.acquire_mutex_in_separate_thread( + self._mutex) + + # We shouldn't be able to acquire a mutex held by a + # different thread. + self.assertFalse(self._mutex.acquire(timeout_ms=0)) + + stop_event.set() + + # We should now be able to acquire the mutex. + # We're using a timeout, giving the other thread some + # time to release it. + self.assertTrue(self._mutex.acquire(timeout_ms=2000)) + + def test_release_unacquired_mutex(self): + self.assertRaises(exceptions.Win32Exception, + self._mutex.release) + + def test_multiple_acquire(self): + # The mutex owner should be able to acquire it multiple times. + self._mutex.acquire(timeout_ms=0) + self._mutex.acquire(timeout_ms=0) + + self._mutex.release() + self._mutex.release() diff --git a/os_win/tests/unit/test_processutils.py b/os_win/tests/unit/test_processutils.py index 75329b2b..2a1de7df 100644 --- a/os_win/tests/unit/test_processutils.py +++ b/os_win/tests/unit/test_processutils.py @@ -196,3 +196,24 @@ class ProcessUtilsTestCase(test_base.OsWinBaseTestCase): mock_wait.assert_called_once_with(phandles, mock.sentinel.wait_all, mock.sentinel.milliseconds) + + def test_create_mutex(self): + handle = self._procutils.create_mutex( + mock.sentinel.name, mock.sentinel.owner, + mock.sentinel.sec_attr) + + self.assertEqual(self._mock_run.return_value, handle) + self._mock_run.assert_called_once_with( + self._mock_kernel32.CreateMutexW, + self._ctypes.byref(mock.sentinel.sec_attr), + mock.sentinel.owner, + mock.sentinel.name, + kernel32_lib_func=True) + + def test_release_mutex(self): + self._procutils.release_mutex(mock.sentinel.handle) + + self._mock_run.assert_called_once_with( + self._mock_kernel32.ReleaseMutex, + mock.sentinel.handle, + kernel32_lib_func=True) diff --git a/os_win/tests/unit/utils/test_win32utils.py b/os_win/tests/unit/utils/test_win32utils.py index 8964887b..57d60e95 100644 --- a/os_win/tests/unit/utils/test_win32utils.py +++ b/os_win/tests/unit/utils/test_win32utils.py @@ -241,3 +241,26 @@ class Win32UtilsTestCase(test_base.BaseTestCase): self._win32_utils.wait_for_multiple_objects, fake_handles, mock.sentinel.wait_all, mock.sentinel.milliseconds) + + @mock.patch.object(win32utils.Win32Utils, 'run_and_check_output') + def test_wait_for_single_object(self, mock_helper): + ret_val = self._win32_utils.wait_for_single_object( + mock.sentinel.handle, mock.sentinel.milliseconds) + + mock_helper.assert_called_once_with( + win32utils.kernel32.WaitForSingleObject, + mock.sentinel.handle, + mock.sentinel.milliseconds, + kernel32_lib_func=True, + error_ret_vals=[w_const.WAIT_FAILED]) + self.assertEqual(mock_helper.return_value, ret_val) + + @mock.patch.object(win32utils.Win32Utils, 'run_and_check_output') + def test_wait_for_single_object_timeout(self, mock_helper): + mock_helper.return_value = w_const.ERROR_WAIT_TIMEOUT + + self.assertRaises( + exceptions.Timeout, + self._win32_utils.wait_for_single_object, + mock.sentinel.timeout, + mock.sentinel.milliseconds) diff --git a/os_win/utils/processutils.py b/os_win/utils/processutils.py index fa3d5420..1e02f5ea 100644 --- a/os_win/utils/processutils.py +++ b/os_win/utils/processutils.py @@ -18,6 +18,7 @@ import ctypes from oslo_log import log as logging +from os_win import exceptions from os_win.utils import win32utils from os_win.utils.winapi import constants as w_const from os_win.utils.winapi import libs as w_lib @@ -128,3 +129,58 @@ class ProcessUtils(object): finally: for handle in handles: self._win32_utils.close_handle(handle) + + def create_mutex(self, name=None, initial_owner=False, + security_attributes=None): + sec_attr_ref = (ctypes.byref(security_attributes) + if security_attributes else None) + return self._run_and_check_output( + kernel32.CreateMutexW, + sec_attr_ref, + initial_owner, + name) + + def release_mutex(self, handle): + return self._run_and_check_output( + kernel32.ReleaseMutex, + handle) + + +class Mutex(object): + def __init__(self, name=None): + self.name = name + + self._processutils = ProcessUtils() + self._win32_utils = win32utils.Win32Utils() + + # This is supposed to be a simple interface. + # We're not exposing the "initial_owner" flag, + # nor are we informing the caller if the mutex + # already exists. + self._handle = self._processutils.create_mutex( + self.name) + + def acquire(self, timeout_ms=w_const.INFINITE): + try: + self._win32_utils.wait_for_single_object( + self._handle, timeout_ms) + return True + except exceptions.Timeout: + return False + + def release(self): + self._processutils.release_mutex(self._handle) + + def close(self): + if self._handle: + self._win32_utils.close_handle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() diff --git a/os_win/utils/win32utils.py b/os_win/utils/win32utils.py index 873f0aec..8389a322 100644 --- a/os_win/utils/win32utils.py +++ b/os_win/utils/win32utils.py @@ -145,3 +145,16 @@ class Win32Utils(object): raise exceptions.Timeout() return ret_val + + def wait_for_single_object(self, handle, + milliseconds=w_const.INFINITE): + ret_val = self.run_and_check_output( + kernel32.WaitForSingleObject, + handle, + milliseconds, + kernel32_lib_func=True, + error_ret_vals=[w_const.WAIT_FAILED]) + if ret_val == w_const.ERROR_WAIT_TIMEOUT: + raise exceptions.Timeout() + + return ret_val diff --git a/os_win/utils/winapi/libs/kernel32.py b/os_win/utils/winapi/libs/kernel32.py index 5ec22b24..16783afe 100644 --- a/os_win/utils/winapi/libs/kernel32.py +++ b/os_win/utils/winapi/libs/kernel32.py @@ -95,6 +95,12 @@ def register(): ] lib_handle.CreateFileW.restype = wintypes.HANDLE + lib_handle.CreateMutexW.argtypes = [ + wintypes.LPCVOID, + wintypes.BOOL, + wintypes.LPCWSTR] + lib_handle.CreateMutexW.restype = wintypes.HANDLE + lib_handle.CreatePipe.argtypes = [ wintypes.PHANDLE, wintypes.PHANDLE, @@ -162,6 +168,9 @@ def register(): ] lib_handle.ReadFileEx.restype = wintypes.BOOL + lib_handle.ReleaseMutex.argtypes = [wintypes.HANDLE] + lib_handle.ReleaseMutex.restype = wintypes.BOOL + lib_handle.ResetEvent.argtypes = [wintypes.HANDLE] lib_handle.ResetEvent.restype = wintypes.BOOL @@ -171,6 +180,12 @@ def register(): lib_handle.SetLastError.argtypes = [wintypes.DWORD] lib_handle.SetLastError.restype = None + lib_handle.WaitForSingleObject.argtypes = [ + wintypes.HANDLE, + wintypes.DWORD + ] + lib_handle.WaitForSingleObject.restype = wintypes.DWORD + lib_handle.WaitForSingleObjectEx.argtypes = [ wintypes.HANDLE, wintypes.DWORD, diff --git a/os_win/utilsfactory.py b/os_win/utilsfactory.py index bd4ec04b..bb2daac3 100644 --- a/os_win/utilsfactory.py +++ b/os_win/utilsfactory.py @@ -19,6 +19,7 @@ from os_win._i18n import _ # noqa from os_win import exceptions from os_win.utils import hostutils from os_win.utils.io import namedpipe +from os_win.utils import processutils utils = hostutils.HostUtils() @@ -201,3 +202,7 @@ def get_processutils(): def get_ioutils(): return _get_class(class_type='ioutils') + + +def get_mutex(*args, **kwargs): + return processutils.Mutex(*args, **kwargs)