Fix import line location

There were two issues with the previous implementation of
_add_import_for_test_uuid():

1) Sometimes the import line was added to the end of the file.
   For example,
     - when the import line belonged to the first line,
     - when there was no tempest import line.

2) Sometimes the import line was not added to the correct import
   group (2nd - related third party imports) defined by the pep8
   style guide.

This patch makes sure that both issues 1) and 2) are solved. The
import line is now by default added between the tempest imports.
If there is no tempest import the import line is added to the
second import group.

Change-Id: Icbac702d295f7f75b3259ad68dd2345cc1e4d90b
This commit is contained in:
lkuchlan 2020-01-07 12:53:55 +02:00 committed by Lukas Piwowarski
parent 50ec7d74c3
commit c8b966ff61
2 changed files with 169 additions and 60 deletions

View File

@ -16,6 +16,7 @@
import argparse
import ast
import contextlib
import importlib
import inspect
import os
@ -28,7 +29,7 @@ import six.moves.urllib.parse as urlparse
DECORATOR_MODULE = 'decorators'
DECORATOR_NAME = 'idempotent_id'
DECORATOR_IMPORT = 'tempest.%s' % DECORATOR_MODULE
DECORATOR_IMPORT = 'tempest.lib.%s' % DECORATOR_MODULE
IMPORT_LINE = 'from tempest.lib import %s' % DECORATOR_MODULE
DECORATOR_TEMPLATE = "@%s.%s('%%s')" % (DECORATOR_MODULE,
DECORATOR_NAME)
@ -180,34 +181,125 @@ class TestChecker(object):
elif isinstance(node, ast.ImportFrom):
return '%s.%s' % (node.module, node.names[0].name)
@contextlib.contextmanager
def ignore_site_packages_paths(self):
"""Removes site-packages directories from the sys.path
Source:
- StackOverflow: https://stackoverflow.com/questions/22195382/
- Author: https://stackoverflow.com/users/485844/
"""
paths = sys.path
# remove all third-party paths
# so that only stdlib imports will succeed
sys.path = list(filter(
None,
filter(lambda i: 'site-packages' not in i, sys.path)
))
yield
sys.path = paths
def is_std_lib(self, module):
"""Checks whether the module is part of the stdlib or not
Source:
- StackOverflow: https://stackoverflow.com/questions/22195382/
- Author: https://stackoverflow.com/users/485844/
"""
if module in sys.builtin_module_names:
return True
with self.ignore_site_packages_paths():
imported_module = sys.modules.pop(module, None)
try:
importlib.import_module(module)
except ImportError:
return False
else:
return True
finally:
if imported_module:
sys.modules[module] = imported_module
def _add_import_for_test_uuid(self, patcher, src_parsed, source_path):
with open(source_path) as f:
src_lines = f.read().split('\n')
line_no = 0
tempest_imports = [node for node in src_parsed.body
import_list = [node for node in src_parsed.body
if isinstance(node, ast.Import) or
isinstance(node, ast.ImportFrom)]
if not import_list:
print("(WARNING) %s: The file is not valid as it does not contain "
"any import line! Therefore the import needed by "
"@decorators.idempotent_id is not added!" % source_path)
return
tempest_imports = [node for node in import_list
if self._import_name(node) and
'tempest.' in self._import_name(node)]
if not tempest_imports:
import_snippet = '\n'.join(('', IMPORT_LINE, ''))
else:
for node in tempest_imports:
if self._import_name(node) < DECORATOR_IMPORT:
continue
else:
line_no = node.lineno
import_snippet = IMPORT_LINE
break
for node in tempest_imports:
if self._import_name(node) < DECORATOR_IMPORT:
continue
else:
line_no = tempest_imports[-1].lineno
while True:
if (not src_lines[line_no - 1] or
getattr(self._next_node(src_parsed.body,
tempest_imports[-1]),
'lineno') == line_no or
line_no == len(src_lines)):
break
line_no += 1
import_snippet = '\n'.join((IMPORT_LINE, ''))
line_no = node.lineno
break
else:
if tempest_imports:
line_no = tempest_imports[-1].lineno + 1
# Insert import line between existing tempest imports
if tempest_imports:
patcher.add_patch(source_path, IMPORT_LINE, line_no)
return
# Group space separated imports together
grouped_imports = {}
first_import_line = import_list[0].lineno
for idx, import_line in enumerate(import_list, first_import_line):
group_no = import_line.lineno - idx
group = grouped_imports.get(group_no, [])
group.append(import_line)
grouped_imports[group_no] = group
if len(grouped_imports) > 3:
print("(WARNING) %s: The file contains more than three import "
"groups! This is not valid according to the PEP8 "
"style guide. " % source_path)
# Divide grouped_imports into groupes based on PEP8 style guide
pep8_groups = {}
package_name = self.package.__name__.split(".")[0]
for key in grouped_imports:
module = self._import_name(grouped_imports[key][0]).split(".")[0]
if module.startswith(package_name):
group = pep8_groups.get('3rd_group', [])
pep8_groups['3rd_group'] = group + grouped_imports[key]
elif self.is_std_lib(module):
group = pep8_groups.get('1st_group', [])
pep8_groups['1st_group'] = group + grouped_imports[key]
else:
group = pep8_groups.get('2nd_group', [])
pep8_groups['2nd_group'] = group + grouped_imports[key]
for node in pep8_groups.get('2nd_group', []):
if self._import_name(node) < DECORATOR_IMPORT:
continue
else:
line_no = node.lineno
import_snippet = IMPORT_LINE
break
else:
if pep8_groups.get('2nd_group', []):
line_no = pep8_groups['2nd_group'][-1].lineno + 1
import_snippet = IMPORT_LINE
elif pep8_groups.get('1st_group', []):
line_no = pep8_groups['1st_group'][-1].lineno + 1
import_snippet = '\n' + IMPORT_LINE
else:
line_no = pep8_groups['3rd_group'][0].lineno
import_snippet = IMPORT_LINE + '\n\n'
patcher.add_patch(source_path, import_snippet, line_no)
def get_tests(self):

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import ast
import importlib
import tempfile
from unittest import mock
@ -52,6 +53,8 @@ class TestSourcePatcher(base.TestCase):
class TestTestChecker(base.TestCase):
IMPORT_LINE = "from tempest.lib import decorators\n"
def _test_add_uuid_to_test(self, source_file):
class Fake_test_node():
lineno = 1
@ -84,55 +87,69 @@ class TestTestChecker(base.TestCase):
" pass")
self._test_add_uuid_to_test(source_file)
@staticmethod
def get_mocked_ast_object(lineno, col_offset, module, name, object_type):
ast_object = mock.Mock(spec=object_type)
name_obj = mock.Mock()
ast_object.lineno = lineno
ast_object.col_offset = col_offset
name_obj.name = name
ast_object.module = module
ast_object.names = [name_obj]
return ast_object
def test_add_import_for_test_uuid_no_tempest(self):
patcher = check_uuid.SourcePatcher()
checker = check_uuid.TestChecker(importlib.import_module('tempest'))
fake_file = tempfile.NamedTemporaryFile("w+t")
fake_file = tempfile.NamedTemporaryFile("w+t", delete=False)
source_code = "from unittest import mock\n"
fake_file.write(source_code)
fake_file.close()
class Fake_src_parsed():
body = ['test_node']
checker._import_name = mock.Mock(return_value='fake_module')
body = [TestTestChecker.get_mocked_ast_object(
1, 4, 'unittest', 'mock', ast.ImportFrom)]
checker._add_import_for_test_uuid(patcher, Fake_src_parsed(),
checker._add_import_for_test_uuid(patcher, Fake_src_parsed,
fake_file.name)
(patch_id, patch), = patcher.patches.items()
self.assertEqual(patcher._quote('\n' + check_uuid.IMPORT_LINE + '\n'),
patch)
self.assertEqual('{%s:s}' % patch_id,
patcher.source_files[fake_file.name])
patcher.apply_patches()
with open(fake_file.name, "r") as f:
expected_result = source_code + '\n' + TestTestChecker.IMPORT_LINE
self.assertTrue(expected_result == f.read())
def test_add_import_for_test_uuid_tempest(self):
patcher = check_uuid.SourcePatcher()
checker = check_uuid.TestChecker(importlib.import_module('tempest'))
fake_file = tempfile.NamedTemporaryFile("w+t", delete=False)
test1 = (" def test_test():\n"
" pass\n")
test2 = (" def test_another_test():\n"
" pass\n")
source_code = test1 + test2
source_code = "from tempest import a_fake_module\n"
fake_file.write(source_code)
fake_file.close()
def fake_import_name(node):
return node.name
checker._import_name = fake_import_name
class Fake_src_parsed:
body = [TestTestChecker.get_mocked_ast_object(
1, 4, 'tempest', 'a_fake_module', ast.ImportFrom)]
class Fake_node():
def __init__(self, lineno, col_offset, name):
self.lineno = lineno
self.col_offset = col_offset
self.name = name
class Fake_src_parsed():
body = [Fake_node(1, 4, 'tempest.a_fake_module'),
Fake_node(3, 4, 'another_fake_module')]
checker._add_import_for_test_uuid(patcher, Fake_src_parsed(),
checker._add_import_for_test_uuid(patcher, Fake_src_parsed,
fake_file.name)
(patch_id, patch), = patcher.patches.items()
self.assertEqual(patcher._quote(check_uuid.IMPORT_LINE + '\n'),
patch)
expected_source = patcher._quote(test1) + '{' + patch_id + ':s}' +\
patcher._quote(test2)
self.assertEqual(expected_source,
patcher.source_files[fake_file.name])
patcher.apply_patches()
with open(fake_file.name, "r") as f:
expected_result = source_code + TestTestChecker.IMPORT_LINE
self.assertTrue(expected_result == f.read())
def test_add_import_no_import(self):
patcher = check_uuid.SourcePatcher()
patcher.add_patch = mock.Mock()
checker = check_uuid.TestChecker(importlib.import_module('tempest'))
fake_file = tempfile.NamedTemporaryFile("w+t", delete=False)
fake_file.close()
class Fake_src_parsed:
body = []
checker._add_import_for_test_uuid(patcher, Fake_src_parsed,
fake_file.name)
self.assertTrue(not patcher.add_patch.called)