Add TestCase.patch() and new testtools.monkey module for help with monkey-patching.

This commit is contained in:
Jonathan Lange
2010-08-04 14:07:42 +01:00
7 changed files with 361 additions and 3 deletions

18
MANUAL
View File

@@ -52,6 +52,24 @@ given the exc_info for the exception, and can use this opportunity to attach
more data (via the addDetails API) and potentially other uses. more data (via the addDetails API) and potentially other uses.
TestCase.patch
~~~~~~~~~~~~~~
``patch`` is a convenient way to monkey-patch a Python object for the duration
of your test. It's especially useful for testing legacy code. e.g.::
def test_foo(self):
my_stream = StringIO()
self.patch(sys, 'stderr', my_stream)
run_some_code_that_prints_to_stderr()
self.assertEqual('', my_stream.getvalue())
The call to ``patch`` above masks sys.stderr with 'my_stream' so that anything
printed to stderr will be captured in a StringIO variable that can be actually
tested. Once the test is done, the real sys.stderr is restored to its rightful
place.
TestCase.skipTest TestCase.skipTest
~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~

3
NEWS
View File

@@ -16,6 +16,9 @@ Improvements
* New 'Is' matcher, which lets you assert that a thing is identical to * New 'Is' matcher, which lets you assert that a thing is identical to
another thing. another thing.
* TestCase now has a 'patch()' method to make it easier to monkey-patching
objects in tests. See the manual for more information. Fixes bug #310770.
0.9.5 0.9.5
~~~~~ ~~~~~

97
testtools/monkey.py Normal file
View File

@@ -0,0 +1,97 @@
# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
"""Helpers for monkey-patching Python code."""
__all__ = [
'MonkeyPatcher',
'patch',
]
class MonkeyPatcher(object):
"""A set of monkey-patches that can be applied and removed all together.
Use this to cover up attributes with new objects. Particularly useful for
testing difficult code.
"""
# Marker used to indicate that the patched attribute did not exist on the
# object before we patched it.
_NO_SUCH_ATTRIBUTE = object()
def __init__(self, *patches):
"""Construct a `MonkeyPatcher`.
:param *patches: The patches to apply, each should be (obj, name,
new_value). Providing patches here is equivalent to calling
`add_patch`.
"""
# List of patches to apply in (obj, name, value).
self._patches_to_apply = []
# List of the original values for things that have been patched.
# (obj, name, value) format.
self._originals = []
for patch in patches:
self.add_patch(*patch)
def add_patch(self, obj, name, value):
"""Add a patch to overwrite 'name' on 'obj' with 'value'.
The attribute C{name} on C{obj} will be assigned to C{value} when
C{patch} is called or during C{run_with_patches}.
You can restore the original values with a call to restore().
"""
self._patches_to_apply.append((obj, name, value))
def patch(self):
"""Apply all of the patches that have been specified with `add_patch`.
Reverse this operation using L{restore}.
"""
for obj, name, value in self._patches_to_apply:
original_value = getattr(obj, name, self._NO_SUCH_ATTRIBUTE)
self._originals.append((obj, name, original_value))
setattr(obj, name, value)
def restore(self):
"""Restore all original values to any patched objects.
If the patched attribute did not exist on an object before it was
patched, `restore` will delete the attribute so as to return the
object to its original state.
"""
while self._originals:
obj, name, value = self._originals.pop()
if value is self._NO_SUCH_ATTRIBUTE:
delattr(obj, name)
else:
setattr(obj, name, value)
def run_with_patches(self, f, *args, **kw):
"""Run 'f' with the given args and kwargs with all patches applied.
Restores all objects to their original state when finished.
"""
self.patch()
try:
return f(*args, **kw)
finally:
self.restore()
def patch(obj, attribute, value):
"""Set 'obj.attribute' to 'value' and return a callable to restore 'obj'.
If 'attribute' is not set on 'obj' already, then the returned callable
will delete the attribute when called.
:param obj: An object to monkey-patch.
:param attribute: The name of the attribute to patch.
:param value: The value to set 'obj.attribute' to.
:return: A nullary callable that, when run, will restore 'obj' to its
original state.
"""
patcher = MonkeyPatcher((obj, attribute, value))
patcher.patch()
return patcher.restore

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2008, 2009 Jonathan M. Lange. See LICENSE for details. # Copyright (c) 2008-2010 Jonathan M. Lange. See LICENSE for details.
"""Test case related stuff.""" """Test case related stuff."""
@@ -23,6 +23,7 @@ import unittest
from testtools import content from testtools import content
from testtools.compat import advance_iterator from testtools.compat import advance_iterator
from testtools.monkey import patch
from testtools.runtest import RunTest from testtools.runtest import RunTest
from testtools.testresult import TestResult from testtools.testresult import TestResult
@@ -121,6 +122,19 @@ class TestCase(unittest.TestCase):
""" """
return self.__details return self.__details
def patch(self, obj, attribute, value):
"""Monkey-patch 'obj.attribute' to 'value' while the test is running.
If 'obj' has no attribute, then the monkey-patch will still go ahead,
and the attribute will be deleted instead of restored to its original
value.
:param obj: The object to patch. Can be anything.
:param attribute: The attribute on 'obj' to patch.
:param value: The value to set 'obj.attribute' to.
"""
self.addCleanup(patch(obj, attribute, value))
def shortDescription(self): def shortDescription(self):
return self.id() return self.id()
@@ -137,7 +151,7 @@ class TestCase(unittest.TestCase):
""" """
raise self.skipException(reason) raise self.skipException(reason)
# skipTest is how python2.7 spells this. Sometime in the future # skipTest is how python2.7 spells this. Sometime in the future
# This should be given a deprecation decorator - RBC 20100611. # This should be given a deprecation decorator - RBC 20100611.
skip = skipTest skip = skipTest
@@ -444,7 +458,7 @@ if types.FunctionType not in copy._copy_dispatch:
def clone_test_with_new_id(test, new_id): def clone_test_with_new_id(test, new_id):
"""Copy a TestCase, and give the copied test a new id. """Copy a TestCase, and give the copied test a new id.
This is only expected to be used on tests that have been constructed but This is only expected to be used on tests that have been constructed but
not executed. not executed.
""" """

View File

@@ -8,6 +8,7 @@ from testtools.tests import (
test_content, test_content,
test_content_type, test_content_type,
test_matchers, test_matchers,
test_monkey,
test_runtest, test_runtest,
test_testtools, test_testtools,
test_testresult, test_testresult,
@@ -22,6 +23,7 @@ def test_suite():
test_content, test_content,
test_content_type, test_content_type,
test_matchers, test_matchers,
test_monkey,
test_runtest, test_runtest,
test_testresult, test_testresult,
test_testsuite, test_testsuite,

View File

@@ -0,0 +1,166 @@
# Copyright (c) 2010 Twisted Matrix Laboratories.
# See LICENSE for details.
"""Tests for testtools.monkey."""
from testtools import TestCase
from testtools.monkey import MonkeyPatcher, patch
class TestObj:
def __init__(self):
self.foo = 'foo value'
self.bar = 'bar value'
self.baz = 'baz value'
class MonkeyPatcherTest(TestCase):
"""
Tests for 'MonkeyPatcher' monkey-patching class.
"""
def setUp(self):
super(MonkeyPatcherTest, self).setUp()
self.test_object = TestObj()
self.original_object = TestObj()
self.monkey_patcher = MonkeyPatcher()
def test_empty(self):
# A monkey patcher without patches doesn't change a thing.
self.monkey_patcher.patch()
# We can't assert that all state is unchanged, but at least we can
# check our test object.
self.assertEquals(self.original_object.foo, self.test_object.foo)
self.assertEquals(self.original_object.bar, self.test_object.bar)
self.assertEquals(self.original_object.baz, self.test_object.baz)
def test_construct_with_patches(self):
# Constructing a 'MonkeyPatcher' with patches adds all of the given
# patches to the patch list.
patcher = MonkeyPatcher((self.test_object, 'foo', 'haha'),
(self.test_object, 'bar', 'hehe'))
patcher.patch()
self.assertEquals('haha', self.test_object.foo)
self.assertEquals('hehe', self.test_object.bar)
self.assertEquals(self.original_object.baz, self.test_object.baz)
def test_patch_existing(self):
# Patching an attribute that exists sets it to the value defined in the
# patch.
self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha')
self.monkey_patcher.patch()
self.assertEquals(self.test_object.foo, 'haha')
def test_patch_non_existing(self):
# Patching a non-existing attribute sets it to the value defined in
# the patch.
self.monkey_patcher.add_patch(self.test_object, 'doesntexist', 'value')
self.monkey_patcher.patch()
self.assertEquals(self.test_object.doesntexist, 'value')
def test_restore_non_existing(self):
# Restoring a value that didn't exist before the patch deletes the
# value.
self.monkey_patcher.add_patch(self.test_object, 'doesntexist', 'value')
self.monkey_patcher.patch()
self.monkey_patcher.restore()
marker = object()
self.assertIs(marker, getattr(self.test_object, 'doesntexist', marker))
def test_patch_already_patched(self):
# Adding a patch for an object and attribute that already have a patch
# overrides the existing patch.
self.monkey_patcher.add_patch(self.test_object, 'foo', 'blah')
self.monkey_patcher.add_patch(self.test_object, 'foo', 'BLAH')
self.monkey_patcher.patch()
self.assertEquals(self.test_object.foo, 'BLAH')
self.monkey_patcher.restore()
self.assertEquals(self.test_object.foo, self.original_object.foo)
def test_restore_twice_is_a_no_op(self):
# Restoring an already-restored monkey patch is a no-op.
self.monkey_patcher.add_patch(self.test_object, 'foo', 'blah')
self.monkey_patcher.patch()
self.monkey_patcher.restore()
self.assertEquals(self.test_object.foo, self.original_object.foo)
self.monkey_patcher.restore()
self.assertEquals(self.test_object.foo, self.original_object.foo)
def test_run_with_patches_decoration(self):
# run_with_patches runs the given callable, passing in all arguments
# and keyword arguments, and returns the return value of the callable.
log = []
def f(a, b, c=None):
log.append((a, b, c))
return 'foo'
result = self.monkey_patcher.run_with_patches(f, 1, 2, c=10)
self.assertEquals('foo', result)
self.assertEquals([(1, 2, 10)], log)
def test_repeated_run_with_patches(self):
# We can call the same function with run_with_patches more than
# once. All patches apply for each call.
def f():
return (self.test_object.foo, self.test_object.bar,
self.test_object.baz)
self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha')
result = self.monkey_patcher.run_with_patches(f)
self.assertEquals(
('haha', self.original_object.bar, self.original_object.baz),
result)
result = self.monkey_patcher.run_with_patches(f)
self.assertEquals(
('haha', self.original_object.bar, self.original_object.baz),
result)
def test_run_with_patches_restores(self):
# run_with_patches restores the original values after the function has
# executed.
self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha')
self.assertEquals(self.original_object.foo, self.test_object.foo)
self.monkey_patcher.run_with_patches(lambda: None)
self.assertEquals(self.original_object.foo, self.test_object.foo)
def test_run_with_patches_restores_on_exception(self):
# run_with_patches restores the original values even when the function
# raises an exception.
def _():
self.assertEquals(self.test_object.foo, 'haha')
self.assertEquals(self.test_object.bar, 'blahblah')
raise RuntimeError, "Something went wrong!"
self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha')
self.monkey_patcher.add_patch(self.test_object, 'bar', 'blahblah')
self.assertRaises(
RuntimeError, self.monkey_patcher.run_with_patches, _)
self.assertEquals(self.test_object.foo, self.original_object.foo)
self.assertEquals(self.test_object.bar, self.original_object.bar)
class TestPatchHelper(TestCase):
def test_patch_patches(self):
# patch(obj, name, value) sets obj.name to value.
test_object = TestObj()
patch(test_object, 'foo', 42)
self.assertEqual(42, test_object.foo)
def test_patch_returns_cleanup(self):
# patch(obj, name, value) returns a nullary callable that restores obj
# to its original state when run.
test_object = TestObj()
original = test_object.foo
cleanup = patch(test_object, 'foo', 42)
cleanup()
self.assertEqual(original, test_object.foo)
def test_suite():
from unittest import TestLoader
return TestLoader().loadTestsFromName(__name__)

View File

@@ -829,6 +829,64 @@ class TestOnException(TestCase):
self.assertThat(events, Equals([])) self.assertThat(events, Equals([]))
class TestPatchSupport(TestCase):
class Case(TestCase):
def test(self):
pass
def test_patch(self):
# TestCase.patch masks obj.attribute with the new value.
self.foo = 'original'
test = self.Case('test')
test.patch(self, 'foo', 'patched')
self.assertEqual('patched', self.foo)
def test_patch_restored_after_run(self):
# TestCase.patch masks obj.attribute with the new value, but restores
# the original value after the test is finished.
self.foo = 'original'
test = self.Case('test')
test.patch(self, 'foo', 'patched')
test.run()
self.assertEqual('original', self.foo)
def test_successive_patches_apply(self):
# TestCase.patch can be called multiple times per test. Each time you
# call it, it overrides the original value.
self.foo = 'original'
test = self.Case('test')
test.patch(self, 'foo', 'patched')
test.patch(self, 'foo', 'second')
self.assertEqual('second', self.foo)
def test_successive_patches_restored_after_run(self):
# TestCase.patch restores the original value, no matter how many times
# it was called.
self.foo = 'original'
test = self.Case('test')
test.patch(self, 'foo', 'patched')
test.patch(self, 'foo', 'second')
test.run()
self.assertEqual('original', self.foo)
def test_patch_nonexistent_attribute(self):
# TestCase.patch can be used to patch a non-existent attribute.
test = self.Case('test')
test.patch(self, 'doesntexist', 'patched')
self.assertEqual('patched', self.doesntexist)
def test_restore_nonexistent_attribute(self):
# TestCase.patch can be used to patch a non-existent attribute, after
# the test run, the attribute is then removed from the object.
test = self.Case('test')
test.patch(self, 'doesntexist', 'patched')
test.run()
marker = object()
value = getattr(self, 'doesntexist', marker)
self.assertIs(marker, value)
def test_suite(): def test_suite():
from unittest import TestLoader from unittest import TestLoader
return TestLoader().loadTestsFromName(__name__) return TestLoader().loadTestsFromName(__name__)