diff --git a/testtools/helpers.py b/testtools/helpers.py index 626f38f..b2ce391 100644 --- a/testtools/helpers.py +++ b/testtools/helpers.py @@ -10,6 +10,10 @@ def try_import(module_name, alternative=None): When supporting multiple versions of Python or optional dependencies, it is useful to be able to try to import a module. + + :param module_name: The name of the module to import, e.g. 'os.path'. + :param alternative: The value to return if no module can be imported. + Defaults to None. """ try: module = __import__(module_name) @@ -19,3 +23,22 @@ def try_import(module_name, alternative=None): for segment in segments: module = getattr(module, segment) return module + + +def try_imports(module_names, alternative=None): + """Attempt to import modules. + + Tries to import the first module in `module_names`. If it can be + imported, we return it. If not, we go on to the second module and try + that. The process continues until we run out of modules to try. If none + of the modules can be imported, return the provided `alternative` value. + + :param module_names: A sequence of module names to try to import. + :param alternative: The value to return if no module can be imported. + Defaults to 'None'. + """ + for module_name in module_names: + module = try_import(module_name) + if module: + return module + return alternative diff --git a/testtools/tests/test_helpers.py b/testtools/tests/test_helpers.py index 9f62ac9..129025d 100644 --- a/testtools/tests/test_helpers.py +++ b/testtools/tests/test_helpers.py @@ -1,7 +1,10 @@ # Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details. from testtools import TestCase -from testtools.helpers import try_import +from testtools.helpers import ( + try_import, + try_imports, + ) from testtools.matchers import Is @@ -32,7 +35,7 @@ class TestTryImport(TestCase): import os self.assertThat(result, Is(os.path)) - def try_nonexistent_submodule(self): + def test_nonexistent_submodule(self): # try_import('thing.another', foo) imports 'thing' and returns foo if # 'another' doesn't exist. marker = object() @@ -40,6 +43,51 @@ class TestTryImport(TestCase): self.assertThat(result, Is(marker)) +class TestTryImports(TestCase): + + def test_doesnt_exist(self): + # try_imports('thing', foo) returns foo if 'thing' doesn't exist. + marker = object() + result = try_imports(['doesntexist'], marker) + self.assertThat(result, Is(marker)) + + def test_fallback(self): + result = try_imports(['doesntexist', 'os']) + import os + self.assertThat(result, Is(os)) + + def test_None_is_default_alternative(self): + # try_imports('thing') returns None if 'thing' doesn't exist. + result = try_imports(['doesntexist']) + self.assertThat(result, Is(None)) + + def test_existing_module(self): + # try_imports('thing', foo) imports 'thing' and returns it if it's a + # module that exists. + result = try_imports(['os'], object()) + import os + self.assertThat(result, Is(os)) + + def test_existing_submodule(self): + # try_imports('thing.another', foo) imports 'thing' and returns it if + # it's a module that exists. + result = try_imports(['os.path'], object()) + import os + self.assertThat(result, Is(os.path)) + + def test_nonexistent_submodule(self): + # try_imports('thing.another', foo) imports 'thing' and returns foo if + # 'another' doesn't exist. + marker = object() + result = try_imports(['os.doesntexist'], marker) + self.assertThat(result, Is(marker)) + + def test_fallback_submodule(self): + result = try_imports(['os.doesntexist', 'os.path']) + import os + self.assertThat(result, Is(os.path)) + + def test_suite(): from unittest import TestLoader return TestLoader().loadTestsFromName(__name__)