diff --git a/tests/unit/test_mock.py b/tests/unit/test_mock.py index 91873368..1b451a4f 100644 --- a/tests/unit/test_mock.py +++ b/tests/unit/test_mock.py @@ -11,6 +11,7 @@ # under the License. import ast +import itertools import os import re @@ -19,6 +20,30 @@ import six.moves from tests.unit import test +class Variants(object): + def __init__(self, variants, print_prefix="mock_"): + self.variants = variants + self.print_prefix = print_prefix + + def __repr__(self): + variants = self.variants + if len(variants) > 3: + variants = variants[:3] + variants = [repr(self.print_prefix + var) for var in variants] + return "{" + ", ".join(variants) + ( + ", ...}" if len(self.variants) > 3 else "}") + + def __eq__(self, val): + return getattr(val, "variants", val) == self.variants + + def __contains__(self, val): + return val in self.variants + + +def pairwise_isinstance(*args): + return all(itertools.starmap(isinstance, args)) + + class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): """Recursively visit an AST looking for misusage of mocks in tests. @@ -26,11 +51,15 @@ class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): object name against the argument names. The following is the correct usages:: - @mock.patch("module.abc") + @mock.patch("module.abc") # or + # or @mock.patch(MODULE + ".abc") + # or @mock.patch("%s.abc" % MODULE) where MODULE="module" def test_foobar(self, mock_module_abc): # or `mock_abc' ... - @mock.patch("pkg.ClassName.abc") + @mock.patch("pkg.ClassName.abc") # or + # or @mock.patch(CLASSNAME + ".abc") + # or @mock.patch("%s.abc" % CLASSNAME) where CLASSNAME="pkg.ClassName" def test_foobar(self, mock_class_name_abc): ... @@ -80,15 +109,23 @@ class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): if isinstance(node, ast.Str): val = node.s elif isinstance(node, ast.BinOp): - if isinstance(node.op, ast.Mod): + if pairwise_isinstance( + (node.op, ast.Mod), (node.left, ast.Str), + (node.right, ast.Name)): val = node.left.s % self.globals_[node.right.id] - elif isinstance(node.op, ast.Add): + elif pairwise_isinstance( + (node.op, ast.Add), (node.left, ast.Name), + (node.right, ast.Str)): val = self.globals_[node.left.id] + node.right.s elif isinstance(node, ast.Name): val = self.globals_[node.id] if val is None: - raise ValueError("Unable to find value in %s" % ast.dump(node)) + raise ValueError( + "Unable to find value in %s, only the following are parsed: " + "GLOBAL, 'pkg.foobar', '%%s.foobar' %% GLOBAL or 'GLOBAL + " + "'.foobar'" + % ast.dump(node)) return val @@ -103,40 +140,38 @@ class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): name = regexp.sub(cls.CAMELCASE_SPLIT_REPL, name) return name.lower() - def _get_mocked_class_value_regexp(self, class_name, mocked_name): + def _get_mocked_class_value_variants(self, class_name, mocked_name): class_name = self._camelcase_to_python(class_name) mocked_name = self._camelcase_to_python(mocked_name) if class_name == self.classname_python: # Optional, since class name of the mocked package is the same as # class name of the *TestCase - return "(?:" + class_name + "_)?" + mocked_name + return [mocked_name, class_name + "_" + mocked_name] # Full class name is required otherwise - return class_name + "_" + mocked_name + return [class_name + "_" + mocked_name] - def _get_pkg_optional_regexp(self, tokens): - pkg_regexp = "" - for token in map(self._camelcase_to_python, tokens): - pkg_regexp = ("(?:" + pkg_regexp + "_)?" + token - if pkg_regexp else token) - return "(?:" + pkg_regexp + "_)?" + def _add_pkg_optional_prefixes(self, tokens, variants): + prefixed_variants = list(variants) + for token in map(self._camelcase_to_python, reversed(tokens)): + prefixed_variants.append(token + "_" + prefixed_variants[-1]) + return prefixed_variants - def _get_mocked_name_regexp(self, name): + def _get_mocked_name_variants(self, name): tokens = name.split(".") - if len(tokens) > 1: - name = self._camelcase_to_python(tokens.pop()) + variants = [self._camelcase_to_python(tokens.pop())] + if tokens: if tokens[-1][0].isupper(): # Mocked something inside a class, check if we should require # the class name to be present in mock argument - name = self._get_mocked_class_value_regexp( - class_name=tokens[-1], - mocked_name=name) - pkg_regexp = self._get_pkg_optional_regexp(tokens) - name = pkg_regexp + name - return name + variants = self._get_mocked_class_value_variants( + class_name=tokens.pop(), + mocked_name=variants[0]) + variants = self._add_pkg_optional_prefixes(tokens, variants) + return Variants(variants) - def _get_mock_decorators_regexp(self, funccall): + def _get_mock_decorators_variants(self, funccall): """Return all the mock.patch{,.object} decorated for function.""" mock_decorators = [] @@ -153,9 +188,9 @@ class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): else: continue - decname = self._get_mocked_name_regexp(decname) - - mock_decorators.append(decname) + mock_decorators.append( + self._get_mocked_name_variants(decname) + ) return mock_decorators @@ -184,6 +219,8 @@ class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): exec(code, self.globals_) except Exception: pass + self.globals_.pop("__builtins__", None) + self.globals_.pop("builtins", None) def visit_ClassDef(self, node): classname_camel = node.name @@ -194,28 +231,55 @@ class FuncMockArgsDecoratorsChecker(ast.NodeVisitor): self.generic_visit(node) - def check_name(self, arg, dec): - return (arg is not None and dec is not None - and (arg == dec or re.match(dec, arg))) + def check_name(self, arg, dec_vars): + return (dec_vars is not None and arg in dec_vars) def visit_FunctionDef(self, node): self.generic_visit(node) - mock_decs = self._get_mock_decorators_regexp(node) + mock_decs = self._get_mock_decorators_variants(node) if not mock_decs: return mock_args = self._get_mock_args(node) - for arg, dec in six.moves.zip_longest(mock_args, mock_decs): - if not self.check_name(arg, dec): + error_msgs = [] + mismatched = False + for arg, dec_vars in six.moves.zip_longest(mock_args, mock_decs): + if not self.check_name(arg, dec_vars): + if arg and dec_vars: + error_msgs.append( + ("Argument '%(arg)s' misnamed; should be either of " + "%(dec)s that is derived from the mock decorator " + "args.\n") % { + "arg": arg, "dec": dec_vars} + ) + elif not arg: + error_msgs.append( + "Missing or malformed argument for %s decorator." + % dec_vars) + mismatched = True + elif not dec_vars: + error_msgs.append( + "Missing or malformed decorator for '%s' argument." + % arg) + mismatched = True + + if error_msgs: + if mismatched: self.errors.append({ "lineno": node.lineno, "args": mock_args, - "decs": mock_decs + "decs": mock_decs, + "messages": error_msgs + }) + else: + self.errors.append({ + "lineno": node.lineno, + "mismatch_pairs": list(zip(mock_args, mock_decs)), + "messages": error_msgs }) - break class MockUsageCheckerTestCase(test.TestCase): @@ -245,4 +309,24 @@ class MockUsageCheckerTestCase(test.TestCase): dict(filename=filename, **error) for error in visitor.errors) + if errors: + print(FuncMockArgsDecoratorsChecker.__doc__) + print( + "\n\n" + "The following errors were found during the described check:") + for error in errors: + print("\n\n" + "Errors at file %(filename)s line %(lineno)d:\n\n" + "%(message)s" % { + "message": "\n".join(error["messages"]), + "filename": error["filename"], + "lineno": error["lineno"]}) + + # NOTE(pboldin): When the STDOUT is shuted the below is the last + # resort to know what is wrong with the mock names. + for error in errors: + error["messages"] = [ + message.rstrip().replace("\n", " ").replace("\t", "") + for message in error["messages"] + ] self.assertEqual([], errors)