Merge "Fix import line location"

This commit is contained in:
Zuul 2020-09-30 02:21:35 +00:00 committed by Gerrit Code Review
commit a3d91b227c
2 changed files with 169 additions and 60 deletions

View File

@ -16,6 +16,7 @@
import argparse
import ast
import contextlib
import importlib
import inspect
import os
@ -28,7 +29,7 @@ import six.moves.urllib.parse as urlparse
DECORATOR_MODULE = 'decorators'
DECORATOR_NAME = 'idempotent_id'
DECORATOR_IMPORT = 'tempest.%s' % DECORATOR_MODULE
DECORATOR_IMPORT = 'tempest.lib.%s' % DECORATOR_MODULE
IMPORT_LINE = 'from tempest.lib import %s' % DECORATOR_MODULE
DECORATOR_TEMPLATE = "@%s.%s('%%s')" % (DECORATOR_MODULE,
DECORATOR_NAME)
@ -180,34 +181,125 @@ class TestChecker(object):
elif isinstance(node, ast.ImportFrom):
return '%s.%s' % (node.module, node.names[0].name)
@contextlib.contextmanager
def ignore_site_packages_paths(self):
"""Removes site-packages directories from the sys.path
Source:
- StackOverflow: https://stackoverflow.com/questions/22195382/
- Author: https://stackoverflow.com/users/485844/
"""
paths = sys.path
# remove all third-party paths
# so that only stdlib imports will succeed
sys.path = list(filter(
None,
filter(lambda i: 'site-packages' not in i, sys.path)
))
yield
sys.path = paths
def is_std_lib(self, module):
"""Checks whether the module is part of the stdlib or not
Source:
- StackOverflow: https://stackoverflow.com/questions/22195382/
- Author: https://stackoverflow.com/users/485844/
"""
if module in sys.builtin_module_names:
return True
with self.ignore_site_packages_paths():
imported_module = sys.modules.pop(module, None)
try:
importlib.import_module(module)
except ImportError:
return False
else:
return True
finally:
if imported_module:
sys.modules[module] = imported_module
def _add_import_for_test_uuid(self, patcher, src_parsed, source_path):
with open(source_path) as f:
src_lines = f.read().split('\n')
line_no = 0
tempest_imports = [node for node in src_parsed.body
import_list = [node for node in src_parsed.body
if isinstance(node, ast.Import) or
isinstance(node, ast.ImportFrom)]
if not import_list:
print("(WARNING) %s: The file is not valid as it does not contain "
"any import line! Therefore the import needed by "
"@decorators.idempotent_id is not added!" % source_path)
return
tempest_imports = [node for node in import_list
if self._import_name(node) and
'tempest.' in self._import_name(node)]
if not tempest_imports:
import_snippet = '\n'.join(('', IMPORT_LINE, ''))
else:
for node in tempest_imports:
if self._import_name(node) < DECORATOR_IMPORT:
continue
else:
line_no = node.lineno
import_snippet = IMPORT_LINE
break
for node in tempest_imports:
if self._import_name(node) < DECORATOR_IMPORT:
continue
else:
line_no = tempest_imports[-1].lineno
while True:
if (not src_lines[line_no - 1] or
getattr(self._next_node(src_parsed.body,
tempest_imports[-1]),
'lineno') == line_no or
line_no == len(src_lines)):
break
line_no += 1
import_snippet = '\n'.join((IMPORT_LINE, ''))
line_no = node.lineno
break
else:
if tempest_imports:
line_no = tempest_imports[-1].lineno + 1
# Insert import line between existing tempest imports
if tempest_imports:
patcher.add_patch(source_path, IMPORT_LINE, line_no)
return
# Group space separated imports together
grouped_imports = {}
first_import_line = import_list[0].lineno
for idx, import_line in enumerate(import_list, first_import_line):
group_no = import_line.lineno - idx
group = grouped_imports.get(group_no, [])
group.append(import_line)
grouped_imports[group_no] = group
if len(grouped_imports) > 3:
print("(WARNING) %s: The file contains more than three import "
"groups! This is not valid according to the PEP8 "
"style guide. " % source_path)
# Divide grouped_imports into groupes based on PEP8 style guide
pep8_groups = {}
package_name = self.package.__name__.split(".")[0]
for key in grouped_imports:
module = self._import_name(grouped_imports[key][0]).split(".")[0]
if module.startswith(package_name):
group = pep8_groups.get('3rd_group', [])
pep8_groups['3rd_group'] = group + grouped_imports[key]
elif self.is_std_lib(module):
group = pep8_groups.get('1st_group', [])
pep8_groups['1st_group'] = group + grouped_imports[key]
else:
group = pep8_groups.get('2nd_group', [])
pep8_groups['2nd_group'] = group + grouped_imports[key]
for node in pep8_groups.get('2nd_group', []):
if self._import_name(node) < DECORATOR_IMPORT:
continue
else:
line_no = node.lineno
import_snippet = IMPORT_LINE
break
else:
if pep8_groups.get('2nd_group', []):
line_no = pep8_groups['2nd_group'][-1].lineno + 1
import_snippet = IMPORT_LINE
elif pep8_groups.get('1st_group', []):
line_no = pep8_groups['1st_group'][-1].lineno + 1
import_snippet = '\n' + IMPORT_LINE
else:
line_no = pep8_groups['3rd_group'][0].lineno
import_snippet = IMPORT_LINE + '\n\n'
patcher.add_patch(source_path, import_snippet, line_no)
def get_tests(self):

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import ast
import importlib
import os
import sys
@ -95,6 +96,8 @@ class TestSourcePatcher(base.TestCase):
class TestTestChecker(base.TestCase):
IMPORT_LINE = "from tempest.lib import decorators\n"
def _test_add_uuid_to_test(self, source_file):
class Fake_test_node():
lineno = 1
@ -127,55 +130,69 @@ class TestTestChecker(base.TestCase):
" pass")
self._test_add_uuid_to_test(source_file)
@staticmethod
def get_mocked_ast_object(lineno, col_offset, module, name, object_type):
ast_object = mock.Mock(spec=object_type)
name_obj = mock.Mock()
ast_object.lineno = lineno
ast_object.col_offset = col_offset
name_obj.name = name
ast_object.module = module
ast_object.names = [name_obj]
return ast_object
def test_add_import_for_test_uuid_no_tempest(self):
patcher = check_uuid.SourcePatcher()
checker = check_uuid.TestChecker(importlib.import_module('tempest'))
fake_file = tempfile.NamedTemporaryFile("w+t")
fake_file = tempfile.NamedTemporaryFile("w+t", delete=False)
source_code = "from unittest import mock\n"
fake_file.write(source_code)
fake_file.close()
class Fake_src_parsed():
body = ['test_node']
checker._import_name = mock.Mock(return_value='fake_module')
body = [TestTestChecker.get_mocked_ast_object(
1, 4, 'unittest', 'mock', ast.ImportFrom)]
checker._add_import_for_test_uuid(patcher, Fake_src_parsed(),
checker._add_import_for_test_uuid(patcher, Fake_src_parsed,
fake_file.name)
(patch_id, patch), = patcher.patches.items()
self.assertEqual(patcher._quote('\n' + check_uuid.IMPORT_LINE + '\n'),
patch)
self.assertEqual('{%s:s}' % patch_id,
patcher.source_files[fake_file.name])
patcher.apply_patches()
with open(fake_file.name, "r") as f:
expected_result = source_code + '\n' + TestTestChecker.IMPORT_LINE
self.assertTrue(expected_result == f.read())
def test_add_import_for_test_uuid_tempest(self):
patcher = check_uuid.SourcePatcher()
checker = check_uuid.TestChecker(importlib.import_module('tempest'))
fake_file = tempfile.NamedTemporaryFile("w+t", delete=False)
test1 = (" def test_test():\n"
" pass\n")
test2 = (" def test_another_test():\n"
" pass\n")
source_code = test1 + test2
source_code = "from tempest import a_fake_module\n"
fake_file.write(source_code)
fake_file.close()
def fake_import_name(node):
return node.name
checker._import_name = fake_import_name
class Fake_src_parsed:
body = [TestTestChecker.get_mocked_ast_object(
1, 4, 'tempest', 'a_fake_module', ast.ImportFrom)]
class Fake_node():
def __init__(self, lineno, col_offset, name):
self.lineno = lineno
self.col_offset = col_offset
self.name = name
class Fake_src_parsed():
body = [Fake_node(1, 4, 'tempest.a_fake_module'),
Fake_node(3, 4, 'another_fake_module')]
checker._add_import_for_test_uuid(patcher, Fake_src_parsed(),
checker._add_import_for_test_uuid(patcher, Fake_src_parsed,
fake_file.name)
(patch_id, patch), = patcher.patches.items()
self.assertEqual(patcher._quote(check_uuid.IMPORT_LINE + '\n'),
patch)
expected_source = patcher._quote(test1) + '{' + patch_id + ':s}' +\
patcher._quote(test2)
self.assertEqual(expected_source,
patcher.source_files[fake_file.name])
patcher.apply_patches()
with open(fake_file.name, "r") as f:
expected_result = source_code + TestTestChecker.IMPORT_LINE
self.assertTrue(expected_result == f.read())
def test_add_import_no_import(self):
patcher = check_uuid.SourcePatcher()
patcher.add_patch = mock.Mock()
checker = check_uuid.TestChecker(importlib.import_module('tempest'))
fake_file = tempfile.NamedTemporaryFile("w+t", delete=False)
fake_file.close()
class Fake_src_parsed:
body = []
checker._add_import_for_test_uuid(patcher, Fake_src_parsed,
fake_file.name)
self.assertTrue(not patcher.add_patch.called)