diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index cf41f057d4..06d2959d0d 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -30,6 +30,7 @@ from textwrap import dedent import threading import time import unittest +import fcntl from Queue import Queue, Empty from getpass import getuser from shutil import rmtree @@ -40,7 +41,7 @@ from tempfile import TemporaryFile, NamedTemporaryFile from mock import MagicMock, patch from swift.common.exceptions import (Timeout, MessageTimeout, - ConnectionTimeout) + ConnectionTimeout, LockTimeout) from swift.common import utils from swift.common.swob import Response @@ -1324,6 +1325,47 @@ log_name = %(yarr)s''' utils.tpool_reraise, MagicMock(side_effect=BaseException('test3'))) + def test_lock_file(self): + flags = os.O_CREAT | os.O_RDWR + with NamedTemporaryFile(delete=False) as nt: + nt.write("test string") + nt.flush() + nt.close() + with utils.lock_file(nt.name, unlink=False) as f: + self.assertEqual(f.read(), "test string") + # we have a lock, now let's try to get a newer one + fd = os.open(nt.name, flags) + self.assertRaises(IOError, fcntl.flock, fd, + fcntl.LOCK_EX | fcntl.LOCK_NB) + + with utils.lock_file(nt.name, unlink=False, append=True) as f: + self.assertEqual(f.read(), "test string") + f.seek(0) + f.write("\nanother string") + f.flush() + f.seek(0) + self.assertEqual(f.read(), "test string\nanother string") + + # we have a lock, now let's try to get a newer one + fd = os.open(nt.name, flags) + self.assertRaises(IOError, fcntl.flock, fd, + fcntl.LOCK_EX | fcntl.LOCK_NB) + + with utils.lock_file(nt.name, timeout=3, unlink=False) as f: + try: + with utils.lock_file(nt.name, timeout=1, unlink=False) as f: + self.assertTrue(False, "Expected LockTimeout exception") + except LockTimeout: + pass + + with utils.lock_file(nt.name, unlink=True) as f: + self.assertEqual(f.read(), "test string\nanother string") + # we have a lock, now let's try to get a newer one + fd = os.open(nt.name, flags) + self.assertRaises(IOError, fcntl.flock, fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + + self.assertRaises(OSError, os.remove, nt.name) + class TestStatsdLogging(unittest.TestCase): def test_get_logger_statsd_client_not_specified(self):