diff --git a/ryu/tests/unit/lib/test_import_module.py b/ryu/tests/unit/lib/test_import_module.py index 25264c36..b8561d20 100644 --- a/ryu/tests/unit/lib/test_import_module.py +++ b/ryu/tests/unit/lib/test_import_module.py @@ -44,9 +44,8 @@ class Test_import_module(unittest.TestCase): eq_("this is ccc", ccc.name) ddd = import_module('./lib/test_mod/ddd/mod.py') # Note: When importing a module by filename, if module file name - # is duplicated, import_module returns a module instance which is - # imported before. - eq_("this is ccc", ddd.name) + # is duplicated, import_module reload (override) a module instance. + eq_("this is ddd", ddd.name) def test_import_same_module1(self): from ryu.tests.unit.lib.test_mod import eee as eee1 diff --git a/ryu/utils.py b/ryu/utils.py index 44d6bf3d..d8bbc53b 100644 --- a/ryu/utils.py +++ b/ryu/utils.py @@ -80,29 +80,34 @@ def _find_loaded_module(modpath): return None -def import_module(modname): +def _import_module_file(path): + abspath = os.path.abspath(path) + # Backup original sys.path before appending path to file + original_path = list(sys.path) + sys.path.append(os.path.dirname(abspath)) + modname = chop_py_suffix(os.path.basename(abspath)) try: - # Import module with python module path - # e.g.) modname = 'module.path.module_name' - return importlib.import_module(modname) - except (ImportError, TypeError): - # In this block, we retry to import module when modname is filename - # e.g.) modname = 'module/path/module_name.py' - abspath = os.path.abspath(modname) - # Check if specified modname is already imported - mod = _find_loaded_module(abspath) - if mod: - return mod - # Backup original sys.path before appending path to file - original_path = list(sys.path) - sys.path.append(os.path.dirname(abspath)) - # Remove python suffix - name = chop_py_suffix(os.path.basename(modname)) - # Retry to import - mod = importlib.import_module(name) - # Restore sys.path + return load_source(modname, abspath) + finally: + # Restore original sys.path sys.path = original_path - return mod + + +def import_module(modname): + if os.path.exists(modname): + try: + # Try to import module since 'modname' is a valid path to a file + # e.g.) modname = './path/to/module/name.py' + return _import_module_file(modname) + except SyntaxError: + # The file didn't parse as valid Python code, try + # importing module assuming 'modname' is a Python module name + # e.g.) modname = 'path.to.module.name' + return importlib.import_module(modname) + else: + # Import module assuming 'modname' is a Python module name + # e.g.) modname = 'path.to.module.name' + return importlib.import_module(modname) def round_up(x, y):