diff --git a/keystone/tests/unit/core.py b/keystone/tests/unit/core.py index 74b0e2e384..732da590a9 100644 --- a/keystone/tests/unit/core.py +++ b/keystone/tests/unit/core.py @@ -562,6 +562,19 @@ class BaseTestCase(testtools.TestCase): if not os.environ.get(env_var): self.skipTest('Env variable %s is not set.' % env_var) + def skip_test_overrides(self, *args, **kwargs): + if self._check_for_method_in_parents(self._testMethodName): + return super(BaseTestCase, self).skipTest(*args, **kwargs) + raise Exception('%r is not a previously defined test method' + % self._testMethodName) + + def _check_for_method_in_parents(self, name): + # skip first to get to parents + for cls in self.__class__.__mro__[1:]: + if hasattr(cls, name): + return True + return False + class TestCase(BaseTestCase): diff --git a/keystone/tests/unit/tests/test_core.py b/keystone/tests/unit/tests/test_core.py index 56e42bcc71..0460e8da84 100644 --- a/keystone/tests/unit/tests/test_core.py +++ b/keystone/tests/unit/tests/test_core.py @@ -33,6 +33,41 @@ class BaseTestTestCase(unit.BaseTestCase): matchers.raises(unit.UnexpectedExit)) +class TestOverrideSkipping(unit.BaseTestCase): + + class TestParent(unit.BaseTestCase): + def test_in_parent(self): + pass + + class TestChild(TestParent): + def test_in_parent(self): + self.skip_test_overrides('some message') + + def test_not_in_parent(self): + self.skip_test_overrides('some message') + + def test_skip_test_override_success(self): + # NOTE(dstanek): let's run the test and see what happens + test = self.TestChild('test_in_parent') + result = test.run() + + # NOTE(dstanek): reach into testtools to ensure the test succeeded + self.assertEqual([], result.decorated.errors) + + def test_skip_test_override_fails_for_missing_parent_test_case(self): + # NOTE(dstanek): let's run the test and see what happens + test = self.TestChild('test_not_in_parent') + result = test.run() + + # NOTE(dstanek): reach into testtools to ensure the test failed + # the way we expected + observed_error = result.decorated.errors[0] + observed_error_msg = observed_error[1] + expected_error_msg = ("'test_not_in_parent' is not a previously " + "defined test method") + self.assertIn(expected_error_msg, observed_error_msg) + + class TestTestCase(unit.TestCase): def test_bad_log(self):