diff --git a/NEWS b/NEWS index 617b197..8b699c8 100644 --- a/NEWS +++ b/NEWS @@ -6,6 +6,12 @@ Changes and improvements to testtools_, grouped by release. NEXT ~~~~ +Improvements +------------ + +* ``ExpectedException`` now accepts a msg parameter for describing an error, + much the same as assertEquals etc. (Robert Collins) + 0.9.30 ~~~~~~ diff --git a/doc/for-test-authors.rst b/doc/for-test-authors.rst index faaf36b..96cf5b0 100644 --- a/doc/for-test-authors.rst +++ b/doc/for-test-authors.rst @@ -163,7 +163,8 @@ The first argument to ``ExpectedException`` is the type of exception you expect to see raised. The second argument is optional, and can be either a regular expression or a matcher. If it is a regular expression, the ``str()`` of the raised exception must match the regular expression. If it is a matcher, -then the raised exception object must match it. +then the raised exception object must match it. The optional third argument +``msg`` will cause the raised error to be annotated with that message. assertIn, assertNotIn diff --git a/testtools/testcase.py b/testtools/testcase.py index 628ddf2..53324e5 100644 --- a/testtools/testcase.py +++ b/testtools/testcase.py @@ -818,26 +818,33 @@ class ExpectedException: exception is raised, an AssertionError will be raised. """ - def __init__(self, exc_type, value_re=None): + def __init__(self, exc_type, value_re=None, msg=None): """Construct an `ExpectedException`. :param exc_type: The type of exception to expect. :param value_re: A regular expression to match against the 'str()' of the raised exception. + :param msg: An optional message explaining the failure. """ self.exc_type = exc_type self.value_re = value_re + self.msg = msg def __enter__(self): pass def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: - raise AssertionError('%s not raised.' % self.exc_type.__name__) + error_msg = '%s not raised.' % self.exc_type.__name__ + if self.msg: + error_msg = error_msg + ' : ' + self.msg + raise AssertionError(error_msg) if exc_type != self.exc_type: return False if self.value_re: matcher = MatchesException(self.exc_type, self.value_re) + if self.msg: + matcher = Annotate(self.msg, matcher) mismatch = matcher.match((exc_type, exc_value, traceback)) if mismatch: raise AssertionError(mismatch.describe()) diff --git a/testtools/tests/test_with_with.py b/testtools/tests/test_with_with.py index e06adeb..4305c62 100644 --- a/testtools/tests/test_with_with.py +++ b/testtools/tests/test_with_with.py @@ -11,6 +11,7 @@ from testtools import ( from testtools.matchers import ( AfterPreprocessing, Equals, + EndsWith, ) @@ -71,3 +72,17 @@ class TestExpectedException(TestCase): def test_pass_on_raise_any_message(self): with ExpectedException(ValueError): raise ValueError('whatever') + + def test_annotate(self): + def die(): + with ExpectedException(ValueError, msg="foo"): + pass + exc = self.assertRaises(AssertionError, die) + self.assertThat(exc.args[0], EndsWith(': foo')) + + def test_annotated_matcher(self): + def die(): + with ExpectedException(ValueError, 'bar', msg="foo"): + pass + exc = self.assertRaises(AssertionError, die) + self.assertThat(exc.args[0], EndsWith(': foo'))