diff --git a/MANUAL b/MANUAL index 4867dae..db21366 100644 --- a/MANUAL +++ b/MANUAL @@ -52,6 +52,24 @@ given the exc_info for the exception, and can use this opportunity to attach more data (via the addDetails API) and potentially other uses. +TestCase.patch +~~~~~~~~~~~~~~ + +``patch`` is a convenient way to monkey-patch a Python object for the duration +of your test. It's especially useful for testing legacy code. e.g.:: + + def test_foo(self): + my_stream = StringIO() + self.patch(sys, 'stderr', my_stream) + run_some_code_that_prints_to_stderr() + self.assertEqual('', my_stream.getvalue()) + +The call to ``patch`` above masks sys.stderr with 'my_stream' so that anything +printed to stderr will be captured in a StringIO variable that can be actually +tested. Once the test is done, the real sys.stderr is restored to its rightful +place. + + TestCase.skipTest ~~~~~~~~~~~~~~~~~ diff --git a/NEWS b/NEWS index 7505cbc..ca26368 100644 --- a/NEWS +++ b/NEWS @@ -16,6 +16,9 @@ Improvements * New 'Is' matcher, which lets you assert that a thing is identical to another thing. + * TestCase now has a 'patch()' method to make it easier to monkey-patching + objects in tests. See the manual for more information. Fixes bug #310770. + 0.9.5 ~~~~~ diff --git a/testtools/monkey.py b/testtools/monkey.py new file mode 100644 index 0000000..bb24764 --- /dev/null +++ b/testtools/monkey.py @@ -0,0 +1,97 @@ +# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details. + +"""Helpers for monkey-patching Python code.""" + +__all__ = [ + 'MonkeyPatcher', + 'patch', + ] + + +class MonkeyPatcher(object): + """A set of monkey-patches that can be applied and removed all together. + + Use this to cover up attributes with new objects. Particularly useful for + testing difficult code. + """ + + # Marker used to indicate that the patched attribute did not exist on the + # object before we patched it. + _NO_SUCH_ATTRIBUTE = object() + + def __init__(self, *patches): + """Construct a `MonkeyPatcher`. + + :param *patches: The patches to apply, each should be (obj, name, + new_value). Providing patches here is equivalent to calling + `add_patch`. + """ + # List of patches to apply in (obj, name, value). + self._patches_to_apply = [] + # List of the original values for things that have been patched. + # (obj, name, value) format. + self._originals = [] + for patch in patches: + self.add_patch(*patch) + + def add_patch(self, obj, name, value): + """Add a patch to overwrite 'name' on 'obj' with 'value'. + + The attribute C{name} on C{obj} will be assigned to C{value} when + C{patch} is called or during C{run_with_patches}. + + You can restore the original values with a call to restore(). + """ + self._patches_to_apply.append((obj, name, value)) + + def patch(self): + """Apply all of the patches that have been specified with `add_patch`. + + Reverse this operation using L{restore}. + """ + for obj, name, value in self._patches_to_apply: + original_value = getattr(obj, name, self._NO_SUCH_ATTRIBUTE) + self._originals.append((obj, name, original_value)) + setattr(obj, name, value) + + def restore(self): + """Restore all original values to any patched objects. + + If the patched attribute did not exist on an object before it was + patched, `restore` will delete the attribute so as to return the + object to its original state. + """ + while self._originals: + obj, name, value = self._originals.pop() + if value is self._NO_SUCH_ATTRIBUTE: + delattr(obj, name) + else: + setattr(obj, name, value) + + def run_with_patches(self, f, *args, **kw): + """Run 'f' with the given args and kwargs with all patches applied. + + Restores all objects to their original state when finished. + """ + self.patch() + try: + return f(*args, **kw) + finally: + self.restore() + + +def patch(obj, attribute, value): + """Set 'obj.attribute' to 'value' and return a callable to restore 'obj'. + + If 'attribute' is not set on 'obj' already, then the returned callable + will delete the attribute when called. + + :param obj: An object to monkey-patch. + :param attribute: The name of the attribute to patch. + :param value: The value to set 'obj.attribute' to. + :return: A nullary callable that, when run, will restore 'obj' to its + original state. + """ + patcher = MonkeyPatcher((obj, attribute, value)) + patcher.patch() + return patcher.restore diff --git a/testtools/testcase.py b/testtools/testcase.py index a232eec..53bea40 100644 --- a/testtools/testcase.py +++ b/testtools/testcase.py @@ -1,4 +1,4 @@ -# Copyright (c) 2008, 2009 Jonathan M. Lange. See LICENSE for details. +# Copyright (c) 2008-2010 Jonathan M. Lange. See LICENSE for details. """Test case related stuff.""" @@ -23,6 +23,7 @@ import unittest from testtools import content from testtools.compat import advance_iterator +from testtools.monkey import patch from testtools.runtest import RunTest from testtools.testresult import TestResult @@ -121,6 +122,19 @@ class TestCase(unittest.TestCase): """ return self.__details + def patch(self, obj, attribute, value): + """Monkey-patch 'obj.attribute' to 'value' while the test is running. + + If 'obj' has no attribute, then the monkey-patch will still go ahead, + and the attribute will be deleted instead of restored to its original + value. + + :param obj: The object to patch. Can be anything. + :param attribute: The attribute on 'obj' to patch. + :param value: The value to set 'obj.attribute' to. + """ + self.addCleanup(patch(obj, attribute, value)) + def shortDescription(self): return self.id() @@ -137,7 +151,7 @@ class TestCase(unittest.TestCase): """ raise self.skipException(reason) - # skipTest is how python2.7 spells this. Sometime in the future + # skipTest is how python2.7 spells this. Sometime in the future # This should be given a deprecation decorator - RBC 20100611. skip = skipTest @@ -444,7 +458,7 @@ if types.FunctionType not in copy._copy_dispatch: def clone_test_with_new_id(test, new_id): """Copy a TestCase, and give the copied test a new id. - + This is only expected to be used on tests that have been constructed but not executed. """ diff --git a/testtools/tests/__init__.py b/testtools/tests/__init__.py index 1c04ea5..5e22000 100644 --- a/testtools/tests/__init__.py +++ b/testtools/tests/__init__.py @@ -8,6 +8,7 @@ from testtools.tests import ( test_content, test_content_type, test_matchers, + test_monkey, test_runtest, test_testtools, test_testresult, @@ -22,6 +23,7 @@ def test_suite(): test_content, test_content_type, test_matchers, + test_monkey, test_runtest, test_testresult, test_testsuite, diff --git a/testtools/tests/test_monkey.py b/testtools/tests/test_monkey.py new file mode 100644 index 0000000..09388b2 --- /dev/null +++ b/testtools/tests/test_monkey.py @@ -0,0 +1,166 @@ +# Copyright (c) 2010 Twisted Matrix Laboratories. +# See LICENSE for details. + +"""Tests for testtools.monkey.""" + +from testtools import TestCase +from testtools.monkey import MonkeyPatcher, patch + + +class TestObj: + + def __init__(self): + self.foo = 'foo value' + self.bar = 'bar value' + self.baz = 'baz value' + + +class MonkeyPatcherTest(TestCase): + """ + Tests for 'MonkeyPatcher' monkey-patching class. + """ + + def setUp(self): + super(MonkeyPatcherTest, self).setUp() + self.test_object = TestObj() + self.original_object = TestObj() + self.monkey_patcher = MonkeyPatcher() + + def test_empty(self): + # A monkey patcher without patches doesn't change a thing. + self.monkey_patcher.patch() + + # We can't assert that all state is unchanged, but at least we can + # check our test object. + self.assertEquals(self.original_object.foo, self.test_object.foo) + self.assertEquals(self.original_object.bar, self.test_object.bar) + self.assertEquals(self.original_object.baz, self.test_object.baz) + + def test_construct_with_patches(self): + # Constructing a 'MonkeyPatcher' with patches adds all of the given + # patches to the patch list. + patcher = MonkeyPatcher((self.test_object, 'foo', 'haha'), + (self.test_object, 'bar', 'hehe')) + patcher.patch() + self.assertEquals('haha', self.test_object.foo) + self.assertEquals('hehe', self.test_object.bar) + self.assertEquals(self.original_object.baz, self.test_object.baz) + + def test_patch_existing(self): + # Patching an attribute that exists sets it to the value defined in the + # patch. + self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha') + self.monkey_patcher.patch() + self.assertEquals(self.test_object.foo, 'haha') + + def test_patch_non_existing(self): + # Patching a non-existing attribute sets it to the value defined in + # the patch. + self.monkey_patcher.add_patch(self.test_object, 'doesntexist', 'value') + self.monkey_patcher.patch() + self.assertEquals(self.test_object.doesntexist, 'value') + + def test_restore_non_existing(self): + # Restoring a value that didn't exist before the patch deletes the + # value. + self.monkey_patcher.add_patch(self.test_object, 'doesntexist', 'value') + self.monkey_patcher.patch() + self.monkey_patcher.restore() + marker = object() + self.assertIs(marker, getattr(self.test_object, 'doesntexist', marker)) + + def test_patch_already_patched(self): + # Adding a patch for an object and attribute that already have a patch + # overrides the existing patch. + self.monkey_patcher.add_patch(self.test_object, 'foo', 'blah') + self.monkey_patcher.add_patch(self.test_object, 'foo', 'BLAH') + self.monkey_patcher.patch() + self.assertEquals(self.test_object.foo, 'BLAH') + self.monkey_patcher.restore() + self.assertEquals(self.test_object.foo, self.original_object.foo) + + def test_restore_twice_is_a_no_op(self): + # Restoring an already-restored monkey patch is a no-op. + self.monkey_patcher.add_patch(self.test_object, 'foo', 'blah') + self.monkey_patcher.patch() + self.monkey_patcher.restore() + self.assertEquals(self.test_object.foo, self.original_object.foo) + self.monkey_patcher.restore() + self.assertEquals(self.test_object.foo, self.original_object.foo) + + def test_run_with_patches_decoration(self): + # run_with_patches runs the given callable, passing in all arguments + # and keyword arguments, and returns the return value of the callable. + log = [] + + def f(a, b, c=None): + log.append((a, b, c)) + return 'foo' + + result = self.monkey_patcher.run_with_patches(f, 1, 2, c=10) + self.assertEquals('foo', result) + self.assertEquals([(1, 2, 10)], log) + + def test_repeated_run_with_patches(self): + # We can call the same function with run_with_patches more than + # once. All patches apply for each call. + def f(): + return (self.test_object.foo, self.test_object.bar, + self.test_object.baz) + + self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha') + result = self.monkey_patcher.run_with_patches(f) + self.assertEquals( + ('haha', self.original_object.bar, self.original_object.baz), + result) + result = self.monkey_patcher.run_with_patches(f) + self.assertEquals( + ('haha', self.original_object.bar, self.original_object.baz), + result) + + def test_run_with_patches_restores(self): + # run_with_patches restores the original values after the function has + # executed. + self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha') + self.assertEquals(self.original_object.foo, self.test_object.foo) + self.monkey_patcher.run_with_patches(lambda: None) + self.assertEquals(self.original_object.foo, self.test_object.foo) + + def test_run_with_patches_restores_on_exception(self): + # run_with_patches restores the original values even when the function + # raises an exception. + def _(): + self.assertEquals(self.test_object.foo, 'haha') + self.assertEquals(self.test_object.bar, 'blahblah') + raise RuntimeError, "Something went wrong!" + + self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha') + self.monkey_patcher.add_patch(self.test_object, 'bar', 'blahblah') + + self.assertRaises( + RuntimeError, self.monkey_patcher.run_with_patches, _) + self.assertEquals(self.test_object.foo, self.original_object.foo) + self.assertEquals(self.test_object.bar, self.original_object.bar) + + +class TestPatchHelper(TestCase): + + def test_patch_patches(self): + # patch(obj, name, value) sets obj.name to value. + test_object = TestObj() + patch(test_object, 'foo', 42) + self.assertEqual(42, test_object.foo) + + def test_patch_returns_cleanup(self): + # patch(obj, name, value) returns a nullary callable that restores obj + # to its original state when run. + test_object = TestObj() + original = test_object.foo + cleanup = patch(test_object, 'foo', 42) + cleanup() + self.assertEqual(original, test_object.foo) + + +def test_suite(): + from unittest import TestLoader + return TestLoader().loadTestsFromName(__name__) diff --git a/testtools/tests/test_testtools.py b/testtools/tests/test_testtools.py index 565efdd..d2a9b7c 100644 --- a/testtools/tests/test_testtools.py +++ b/testtools/tests/test_testtools.py @@ -829,6 +829,64 @@ class TestOnException(TestCase): self.assertThat(events, Equals([])) +class TestPatchSupport(TestCase): + + class Case(TestCase): + def test(self): + pass + + def test_patch(self): + # TestCase.patch masks obj.attribute with the new value. + self.foo = 'original' + test = self.Case('test') + test.patch(self, 'foo', 'patched') + self.assertEqual('patched', self.foo) + + def test_patch_restored_after_run(self): + # TestCase.patch masks obj.attribute with the new value, but restores + # the original value after the test is finished. + self.foo = 'original' + test = self.Case('test') + test.patch(self, 'foo', 'patched') + test.run() + self.assertEqual('original', self.foo) + + def test_successive_patches_apply(self): + # TestCase.patch can be called multiple times per test. Each time you + # call it, it overrides the original value. + self.foo = 'original' + test = self.Case('test') + test.patch(self, 'foo', 'patched') + test.patch(self, 'foo', 'second') + self.assertEqual('second', self.foo) + + def test_successive_patches_restored_after_run(self): + # TestCase.patch restores the original value, no matter how many times + # it was called. + self.foo = 'original' + test = self.Case('test') + test.patch(self, 'foo', 'patched') + test.patch(self, 'foo', 'second') + test.run() + self.assertEqual('original', self.foo) + + def test_patch_nonexistent_attribute(self): + # TestCase.patch can be used to patch a non-existent attribute. + test = self.Case('test') + test.patch(self, 'doesntexist', 'patched') + self.assertEqual('patched', self.doesntexist) + + def test_restore_nonexistent_attribute(self): + # TestCase.patch can be used to patch a non-existent attribute, after + # the test run, the attribute is then removed from the object. + test = self.Case('test') + test.patch(self, 'doesntexist', 'patched') + test.run() + marker = object() + value = getattr(self, 'doesntexist', marker) + self.assertIs(marker, value) + + def test_suite(): from unittest import TestLoader return TestLoader().loadTestsFromName(__name__)