neutron-lib/tools/pyir.py

1534 lines
45 KiB
Python
Executable File

#!/usr/bin/env python
# Copyright 2016 VMware, Inc.
# All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
import abc
import contextlib
import imp
import inspect
import os
from os import path
import re
import shutil
import sys
import tempfile
import argparse
import six
from oslo_serialization import jsonutils
__version__ = '0.0.2'
# NOTE(boden): This is a prototype and needs additional love
_MOCK_SRC = '''
class _PyIREmptyMock_(object):
def __init__(self, *args, **kwargs):
self._args_ = args
self._kwargs_ = kwargs
def __or__(self, o):
return o
def __ror__(self, o):
return o
def __xor__(self, o):
return o
def __rxor__(self, o):
return o
def __and__(self, o):
return o
def __rand__(self, o):
return o
def __rshift__(self, o):
return o
def __rrshift__(self, o):
return o
def __lshift__(self, o):
return o
def __rlshift__(self, o):
return o
def __pow__(self, o):
return o
def __rpow__(self, o):
return o
def __divmod__(self, o):
return o
def __rdivmod__(self, o):
return o
def __mod__(self, o):
return o
def __rmod__(self, o):
return o
def __floordiv__(self, o):
return o
def __rfloordiv__(self, o):
return o
def __truediv__(self, o):
return o
def __rtrudiv__(self, o):
return o
def __add__(self, o):
return o
def __radd__(self, o):
return o
def __sub__(self, o):
return o
def __rsub__(self, o):
return o
def __mul__(self, o):
return o
def __rmul__(self, o):
return o
def __matmul__(self, o):
return o
def __rmatmul__(self, o):
return o
def __getattribute__(self, name):
return _PyIREmptyMock_()
def __call__(self, *args, **kwargs):
return _PyIREmptyMock_()
def __iter__(self):
return [].__iter__()
def __getitem__(self, item):
return _PyIREmptyMock_()
def __setitem__(self, key, value):
pass
def __delitem__(self, key):
pass
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class _PyIREmptyImport_(_PyIREmptyMock_):
pass
'''
_MOCK_CLASS_NAME = '_PyIREmptyMock_'
_MOCK_IMPORT_CLASS_NAME = '_PyIREmptyImport_'
UNKNOWN_VAL = 'PYIR UNKNOWN VALUE'
_BLACKLIST = [re.compile(".*\.%s" % _MOCK_CLASS_NAME),
re.compile(".*\.%s" % _MOCK_IMPORT_CLASS_NAME)]
def blacklist_filter(value):
for pattern in _BLACKLIST:
if pattern.match(value):
return False
return True
def add_blacklist_from_csv_str(csv_str):
global _BLACKLIST
_BLACKLIST.extend([re.compile(p)
for p in split_on_token(csv_str, ',')])
def for_tokens(the_str, tokens, callback):
in_str = []
tokens = list(tokens)
index = 0
def _compare_tokens(idx):
hits = []
for token in tokens:
if the_str[idx:].startswith(token):
hits.append(token)
return hits
for c in the_str:
if c == '\'' or c == '\"':
if in_str and in_str[len(in_str) - 1] == c:
in_str.pop()
else:
in_str.append(c)
elif not in_str:
matching_tokens = _compare_tokens(index)
if matching_tokens:
callback(matching_tokens, index, the_str[index:])
index += 1
def token_indexes(the_str, tokens):
indexes = []
def _count(toks, idx, substring):
indexes.append(idx)
for_tokens(the_str, tokens, _count)
return indexes
def split_on_token(the_str, token):
indexes = token_indexes(the_str, [token])
if not indexes:
return [the_str]
strs = []
indexes.insert(0, None)
for start, end in zip(indexes, indexes[1:] + [None]):
start = 0 if start is None else start + 1
if end is None:
end = len(the_str)
strs.append(the_str[start:end])
return strs
def count_tokens(the_str, tokens):
return len(token_indexes(the_str, tokens))
def remove_tokens(the_str, tokens):
in_str = []
tokens = list(tokens)
index = 0
new_str = ''
def _token(idx):
for token in tokens:
if the_str[idx:].startswith(token):
return token
return None
while index < len(the_str):
c = the_str[index]
if c == '\'' or c == '\"':
if in_str and in_str[len(in_str) - 1] == c:
in_str.pop()
else:
in_str.append(c)
elif not in_str:
tok = _token(index)
if tok:
index += len(tok)
continue
new_str += c
index += 1
return new_str
def remove_brackets(the_str):
return remove_tokens(the_str, ['(', ')'])
def parent_path(file_path):
if not file_path or file_path == '/':
return None
return path.abspath(path.join(file_path, '..'))
def is_py_file(file_path):
file_path = file_path if filter(blacklist_filter, [file_path]) else None
return (file_path and
path.isfile(file_path) and
file_path.endswith('.py'))
def is_py_dir(dir_path):
if not filter(blacklist_filter, [dir_path]):
return False
if path.isdir(dir_path):
for f in os.listdir(dir_path):
f = path.join(dir_path, f)
if is_py_file(f):
return True
return False
def is_py_package_dir(dir_path):
if not filter(blacklist_filter, [dir_path]):
return False
if path.isdir(dir_path):
return '__init__.py' in os.listdir(dir_path)
return False
def parent_package_names(file_or_dir_path):
pkg_names = []
file_or_dir_path = parent_path(file_or_dir_path)
while file_or_dir_path:
if is_py_package_dir(file_or_dir_path):
pkg_names.append(os.path.basename(file_or_dir_path))
else:
break
file_or_dir_path = parent_path(file_or_dir_path)
return None if not pkg_names else reversed(pkg_names)
def whitespace(line):
if line.isspace():
return line, ''
char_index = 0
for char_index in range(len(line)):
if not line[char_index].isspace():
break
return line[:char_index], line[char_index:].strip()
def ordered(obj):
if isinstance(obj, dict):
return sorted((k, ordered(v)) for k, v in obj.items())
if isinstance(obj, list):
return sorted(ordered(x) for x in obj)
else:
return obj
def json_primitive(val):
if isinstance(val, (six.string_types, six.text_type,
six.integer_types, bool)):
return str(val)
elif str(val).startswith('<') or type(val) in [dict, list, set, tuple]:
return str(type(val))
elif (str(val).count(_MOCK_CLASS_NAME) or
str(val).count(_MOCK_IMPORT_CLASS_NAME)):
return UNKNOWN_VAL
return val
def is_mock_import(obj):
return _MOCK_IMPORT_CLASS_NAME in str(obj)
def _member_filter(obj):
return not inspect.isbuiltin(obj) and not inspect.ismodule(obj)
class PyFiles(object):
def __init__(self, files):
self._files = set(PyFiles.check_py_paths(files))
@staticmethod
def check_py_paths(py_paths):
checked = []
for f in py_paths:
f = path.abspath(f)
assert path.exists(f)
if path.isfile(f):
if not is_py_file(f):
raise IOError("'%s' is not a .py file." % f)
else:
if not is_py_package_dir(f):
raise IOError("'%s' doesn't contain __init__.py." % f)
checked.append(f)
return checked
@property
def files(self):
return set(self._files)
@property
def has_files(self):
return len(self._files) > 0
def _path_to_tmp_tree(self, tree_dir, src_path):
tree_dest = path.join(tree_dir, path.basename(src_path))
parent_dirs = list(parent_package_names(src_path) or [])
if parent_dirs:
tree_dest = path.join(tree_dir, *tuple(parent_dirs))
os.makedirs(tree_dest)
subpath = tree_dir
for subdir in parent_dirs:
subpath = path.join(subpath, subdir)
open(path.join(subpath, '__init__.py'), 'a').close()
tree_dest = path.join(tree_dest, path.basename(src_path))
copy_fn = shutil.copytree if path.isdir(src_path) else shutil.copyfile
copy_fn(src_path, tree_dest)
return tree_dest
def to_tmp_tree(self, tree_dir=None):
tree_dir = tree_dir or tempfile.mkdtemp()
assert path.isdir(tree_dir)
subtrees = []
for f in self._files:
subtrees.append(self._path_to_tmp_tree(tree_dir, f))
return tree_dir, subtrees
@contextlib.contextmanager
def tmp_tree(self, delete_on_exit=True):
tree = None
try:
tree, subtress = self.to_tmp_tree()
yield tree
finally:
if tree and delete_on_exit:
shutil.rmtree(tree)
@staticmethod
def filter_all_py_files(root_dir, filters):
for child in os.listdir(root_dir):
child_path = path.join(root_dir, child)
if is_py_file(child_path):
PyFile.rewrite(child_path, filters)
elif is_py_dir(child_path):
PyFiles.filter_all_py_files(child_path, filters)
class PyLine(object):
def __init__(self, ws, logical_line, py_file):
self.ws = '' if ws is None else ws
if ws == "\n":
self.ws = ''
self.logical = '' if logical_line is None else logical_line
self._py_file = py_file
@property
def is_str_line(self):
return ((self.logical.startswith('\'') and
self.logical.endswith('\'')) or
(self.logical.startswith('\"') and
self.logical.endswith('\"')))
@property
def is_empty_line(self):
return len(self.logical.strip()) == 0
@property
def indent(self):
return self.ws.count(' ') + (self.ws.count("\t") * 4)
@property
def bracket_tics(self):
return (count_tokens(self.logical, PyLineTokens.OPEN_B) -
count_tokens(self.logical, PyLineTokens.CLOSED_B))
@property
def physical_line(self):
return str(self)
@property
def is_comment(self):
return self.logical.startswith(PyLineTokens.COMMENT)
def comment_out(self):
if not self.is_comment:
self.logical = PyLineTokens.COMMENT + self.logical
@property
def has_unmatched_brackets(self):
return self.bracket_tics != 0
@property
def is_continuation(self):
return (self.logical.endswith(PyLineTokens.BACKSLASH) or
self.has_unmatched_brackets)
@property
def is_space(self):
if self.logical == '':
return self.ws.isspace()
return self.logical.isspace()
@property
def file_path(self):
return self._py_file.name
@staticmethod
def from_string_lines(lines, py_file=None):
py_lines = []
for l in lines:
ws, logical = whitespace(l)
py_lines.append(PyLine(ws, logical, py_file))
return py_lines
def __str__(self):
return self.ws + self.logical
class FilterMarker(object):
def __init__(self, filt, markers=None):
self._filter = filt
self.markers = markers or []
def mark(self, line):
if self._filter.mark(line):
self.markers.append(line)
def filter(self, py_file):
for marker in self.markers:
self._filter.filter(marker, py_file)
def reset(self):
self.markers = []
class PyFile(object):
def __init__(self, py_filters):
self._markers = []
self._add_filters(py_filters)
self._lines = []
def prepend_lines(self, lines):
lines = list(lines)
lines.extend(self._lines)
self._lines = lines
def reset(self):
self._lines = []
for m in self._markers:
m.reset()
def first_line(self):
return self._lines[0] if self._lines else None
def next_line(self, py_line):
if py_line not in self._lines:
return None
index = self._lines.index(py_line) + 1
if index >= len(self._lines):
return None
return self._lines[index]
def prev_line(self, py_line):
if py_line not in self._lines:
return None
index = self._lines.index(py_line) - 1
if index <= 0:
return None
return self._lines[index]
def del_line(self, py_line):
self._lines.remove(py_line)
def get_line(self, py_line):
return (None if not self.contains_line(py_line)
else self._lines[self._lines.index(py_line)])
def contains_line(self, py_line):
return py_line in self._lines
def _add_filters(self, filters):
self._markers.extend([FilterMarker(f) for f in filters])
def _mark_line_filter(self, line):
for marker in self._markers:
marker.mark(line)
def load_path(self, py_path):
with open(py_path, 'r') as py_file:
for line in py_file:
ws, logical = whitespace(line)
line = PyLine(ws, logical, py_file)
self._lines.append(line)
self._mark_line_filter(line)
def filter(self):
if not self._lines or not self._markers:
return None
for marker in self._markers:
marker.filter(self)
def insert_after(self, py_line, py_line_to_add):
if py_line not in self._lines:
return False
self._lines.insert(self._lines.index(py_line) + 1, py_line_to_add)
return True
def to_file_str(self):
buff = ''
for line in self._lines:
buff += str(line) + "\n"
return buff
def save(self, py_path):
with open(py_path, 'w') as py_file:
py_file.write(self.to_file_str())
@staticmethod
def filter_to_file_str(py_path, filters):
py_file = PyFile(filters)
py_file.load_path(py_path)
py_file.filter()
return py_file.to_file_str()
@staticmethod
def rewrite(py_path, filters):
py_file = PyFile(filters)
py_file.load_path(py_path)
py_file.filter()
py_file.save(py_path)
class ImportParser(object):
def __init__(self):
self.names = []
self.modules = []
def _segs(self, the_str, token=' '):
return [s.strip() for s in the_str.split(token)
if s and not s.isspace()]
def _lstrip(self, the_str, to_strip):
return the_str[len(to_strip):].strip()
def _next_token(self, the_str, delim=' ', strip=True):
try:
idx = the_str.index(delim)
content = the_str[:idx]
remainder = the_str[idx:]
if strip:
content.strip()
remainder.strip()
return content, remainder
except ValueError:
return None, None
def _parse_from(self, import_str):
import_str = self._lstrip(import_str, 'from ')
module_name, import_str = self._next_token(import_str)
import_str = self._lstrip(import_str, 'import ')
for name_def in self._segs(import_str, token=','):
if name_def.count(' as '):
segs = self._segs(name_def, token=' as ')
self.names.append(segs[1])
self.modules.append(module_name + '.' + segs[0])
else:
self.names.extend(self._segs(name_def, '.'))
self.modules.append(module_name)
def _parse_import(self, import_str):
import_str = self._lstrip(import_str, 'import ')
for name_def in self._segs(import_str, token=','):
if name_def.count(' as '):
segs = self._segs(name_def, token=' as ')
self.names.append(segs[1])
self.modules.append(segs[0])
else:
self.names.extend(self._segs(name_def, '.'))
self.modules.append(name_def)
def reset(self):
self.names = []
self.modules = []
def is_statement(self, import_str):
return import_str.startswith(('import ', 'from ', ))
def parse(self, import_str):
self.reset()
import_str = import_str.replace('(', '').replace(')', '')
if import_str.startswith('import '):
self._parse_import(import_str)
elif import_str.startswith('from '):
self._parse_from(import_str)
else:
raise IOError("Invalid import string: %s" % import_str)
return self
class PyLineTokens(object):
COMMENT = '#'
BACKSLASH = '\\'
DECORATOR = '@'
OPEN_B = '('
CLOSED_B = ')'
@six.add_metaclass(abc.ABCMeta)
class AbstractFilter(object):
@abc.abstractmethod
def mark(self, py_line):
pass
@abc.abstractmethod
def filter(self, py_line, py_file):
pass
@six.add_metaclass(abc.ABCMeta)
class AbstractPerFileFilter(AbstractFilter):
def __init__(self):
self._marked = []
def mark(self, py_line):
if py_line.file_path not in self._marked:
self._marked.append(py_line.file_path)
return True
return False
@abc.abstractmethod
def _filter(self, py_line, py_file):
pass
def filter(self, py_line, py_file):
if py_line.file_path not in self._marked:
return
self._marked.remove(py_line.file_path)
return self._filter(py_line, py_file)
class CommentOutDecorators(AbstractFilter):
def mark(self, py_line):
if py_line.is_str_line:
return False
if py_line.logical.startswith(PyLineTokens.DECORATOR):
return True
return False
def filter(self, py_line, py_file):
if not py_file.get_line(py_line):
return
py_line.comment_out()
class StripTrailingComments(AbstractFilter):
_RE = re.compile('^([^#]*)#(.*)$')
def mark(self, py_line):
if (py_line.is_str_line or
not count_tokens(py_line.logical, PyLineTokens.COMMENT)):
return False
m = StripTrailingComments._RE.match(py_line.logical)
return True if m else False
def filter(self, py_line, py_file):
if not py_file.get_line(py_line):
return
m = StripTrailingComments._RE.match(py_line.logical)
py_line.logical = m.group(1).strip()
class AddMockDefinitions(AbstractPerFileFilter):
_LINES = PyLine.from_string_lines(_MOCK_SRC.split("\n"))
def _filter(self, py_line, py_file):
py_file.prepend_lines(AddMockDefinitions._LINES)
class PassEmptyDef(AbstractPerFileFilter):
def _has_body(self, def_py_line, py_file):
indent = def_py_line.indent
line = py_file.next_line(def_py_line)
while line:
if line.is_empty_line:
line = py_file.next_line(line)
continue
elif line.indent > indent:
return True
elif line.indent <= indent:
return False
else:
line = py_file.next_line(line)
return False
def _filter(self, py_line, py_file):
line = py_file.first_line()
while line:
if (line.logical.startswith(('class ', 'def ',)) and
not self._has_body(line, py_file)):
pass_line = PyLine(line.ws + " ", 'pass', py_file)
py_file.insert_after(line, pass_line)
line = py_file.next_line(pass_line)
else:
line = py_file.next_line(line)
class RemoveDocStrings(AbstractPerFileFilter):
_COMMENT = '"""'
def _comment_count(self, py_line):
return count_tokens(py_line.logical, RemoveDocStrings._COMMENT)
def _safe_delete_line(self, py_line, py_file):
if py_line.logical.endswith((',', ')',)):
return
py_file.del_line(py_line)
def _filter(self, py_line, py_file):
in_comment = False
last_line = line = py_file.first_line()
while line:
comment_count = self._comment_count(line)
if comment_count:
if in_comment:
in_comment = False
elif comment_count == 1:
in_comment = True
py_file.del_line(line)
elif in_comment:
py_file.del_line(line)
if not py_file.contains_line(line):
if not py_file.contains_line(last_line):
last_line = line = py_file.first_line()
else:
line = py_file.next_line(last_line)
else:
next_line = py_file.next_line(line)
last_line = line
line = next_line
class RemoveCommentLines(AbstractFilter):
def mark(self, py_line):
return py_line.logical.startswith(PyLineTokens.COMMENT)
def filter(self, py_line, py_file):
py_line = py_file.get_line(py_line)
if py_line and self.mark(py_line):
py_file.del_line(py_line)
@six.add_metaclass(abc.ABCMeta)
class AbstractMultiLineCollector(AbstractFilter):
def __init__(self):
self._comment_stripper = StripTrailingComments()
def _strip_backslash(self, py_line):
if py_line.logical.endswith(PyLineTokens.BACKSLASH):
py_line.logical = py_line.logical[:-1].strip()
return True
return False
def _collect(self, py_line, py_file, continue_fn):
self._strip_backslash(py_line)
next_line = py_file.next_line(py_line)
while next_line:
if not next_line.is_comment:
if self._comment_stripper.mark(next_line):
self._comment_stripper.filter(next_line, py_file)
if not next_line.is_space:
py_line.logical += ' ' + next_line.logical
py_file.del_line(next_line)
if continue_fn(py_line):
next_line = py_file.next_line(py_line)
continue
else:
break
def _collect_backslash(self, py_line, py_file):
self._strip_backslash(py_line)
self._collect(py_line, py_file, self._strip_backslash)
def _collect_brackets(self, py_line, py_file):
self._collect(py_line, py_file, lambda l: l.has_unmatched_brackets)
def filter(self, py_line, py_file):
if py_line.logical.endswith(PyLineTokens.BACKSLASH):
self._collect_backslash(py_line, py_file)
else:
self._collect_brackets(py_line, py_file)
class MergeMultiLineImports(AbstractMultiLineCollector):
def mark(self, py_line):
if py_line.is_str_line:
return False
logical = py_line.logical
return (logical.startswith(('import ', 'from ',)) and
py_line.is_continuation)
def filter(self, py_line, py_file):
super(MergeMultiLineImports, self).filter(py_line, py_file)
py_line.logical = remove_brackets(py_line.logical)
class MergeMultiLineClass(AbstractMultiLineCollector):
def mark(self, py_line):
if py_line.is_str_line:
return False
return py_line.logical.startswith('class ') and py_line.is_continuation
class MergeMultiLineDef(AbstractMultiLineCollector):
def mark(self, py_line):
if py_line.is_str_line:
return False
return py_line.logical.startswith('def ') and py_line.is_continuation
class MergeMultiLineDecorator(AbstractMultiLineCollector):
def mark(self, py_line):
if py_line.is_str_line:
return False
return (py_line.logical.startswith(PyLineTokens.DECORATOR) and
py_line.is_continuation)
class MockParentClass(AbstractFilter):
_PARENT_RE = re.compile('class \w*\((.*)\)\:$')
def mark(self, py_line):
return (py_line.logical.startswith('class ') and
not py_line.is_str_line)
def filter(self, py_line, py_file):
if not py_file.get_line(py_line):
return
m = MockParentClass._PARENT_RE.match(py_line.logical)
if m:
py_line.logical = py_line.logical.replace(
"(%s):" % m.group(1), "(%s):" % _MOCK_CLASS_NAME)
class MockImports(AbstractFilter):
def __init__(self):
self._parser = ImportParser()
def mark(self, py_line):
return self._parser.is_statement(remove_brackets(py_line.logical))
def filter(self, py_line, py_file):
if not py_file.contains_line(py_line) or not self.mark(py_line):
return
py_line.logical = remove_brackets(py_line.logical)
self._parser.parse(py_line.logical)
if '*' in self._parser.names:
inferred_names = []
for module in self._parser.modules:
if not module.startswith('.'):
inferred_names.extend(module.split('.'))
self._parser.names = inferred_names
if not self._parser.names:
py_line.comment_out()
return
py_line.logical = ', '.join(self._parser.names) + ' = ' + ', '.join(
[_MOCK_IMPORT_CLASS_NAME + '()' for n in self._parser.names])
if '_' in self._parser.names:
# TODO(boden): one off
mock_translate = PyLine(py_line.ws, '_ = lambda s: str(s)',
py_file)
py_file.insert_after(py_line, mock_translate)
class APISignature(object):
class SignatureType(object):
CLASS = 'class'
FUNCTION = 'function'
METHOD = 'method'
CLASS_ATTR = 'class_attribute'
MODULE_ATTR = 'module_attribute'
def __init__(self, signature_type, qualified_name, member, arg_spec):
self.signature_type = signature_type
self.qualified_name = qualified_name
self.member = member
self.arg_spec = arg_spec
def to_dict(self):
defaults = ([json_primitive(d) for d in self.arg_spec.defaults]
if self.arg_spec.defaults else None)
return {
'member_type': self.signature_type,
'qualified_name': self.qualified_name,
'member_value': json_primitive(self.member),
'arg_spec': {
'args': self.arg_spec.args,
'varargs': self.arg_spec.varargs,
'keywords': self.arg_spec.keywords,
'defaults': defaults
}
}
@staticmethod
def arg_spec_from_dict(arg_spec_dict):
defaults = arg_spec_dict['defaults']
if defaults is not None:
defaults = tuple(defaults)
return inspect.ArgSpec(arg_spec_dict['args'],
arg_spec_dict['varargs'],
arg_spec_dict['keywords'],
defaults)
@staticmethod
def from_dict(api_dict):
return APISignature(
api_dict['member_type'],
api_dict['qualified_name'],
api_dict['member_value'],
APISignature.arg_spec_from_dict(api_dict['arg_spec']))
@property
def signature(self):
return self._build_signature(self.to_dict())
@staticmethod
def get_signature(signature):
if isinstance(signature, dict):
signature = APISignature.from_dict(signature)
return signature.signature
def _build_callable_signature(self, signature_dict):
arg_spec = signature_dict['arg_spec']
arg_str = ''
defaults = arg_spec['defaults'] or []
named_args = arg_spec['args'] or []
named_kwargs = []
if defaults:
named_args = arg_spec['args'][:-len(defaults)]
named_kwargs = arg_spec['args'][-len(defaults):]
if named_args:
arg_str += ", ".join(named_args)
if named_kwargs:
kw_args = []
for kw_name, kw_default in zip(named_kwargs, defaults):
kw_args.append("%s=%s" % (kw_name, kw_default))
arg_str += ", %s" % ", ".join(kw_args)
if arg_spec['varargs'] is not None:
arg_str = "*%s%s" % (arg_spec['varargs'],
'' if not arg_str else ', ' + arg_str)
if arg_spec['keywords'] is not None:
arg_str += "%s**%s" % (', ' if arg_str
else '', arg_spec['keywords'])
if arg_str.startswith(','):
arg_str = arg_str[1:]
return "%s(%s)" % (signature_dict['qualified_name'], arg_str.strip())
def _build_variable_signature(self, signature_dict):
return "%s = %s" % (signature_dict['qualified_name'],
signature_dict['member_value'])
def _build_class_signature(self, signature_dict):
return signature_dict['qualified_name']
def _build_signature(self, signature_dict):
if (signature_dict['member_type'] in
[APISignature.SignatureType.FUNCTION,
APISignature.SignatureType.METHOD]):
return self._build_callable_signature(signature_dict)
elif signature_dict['member_type'] == APISignature.SignatureType.CLASS:
return self._build_class_signature(signature_dict)
else:
return self._build_variable_signature(signature_dict)
class ModuleParser(object):
def __init__(self, listeners, abort_on_load_failure=False):
self.listeners = listeners
self.abort_on_load_failure = abort_on_load_failure
def _notify(self, signature_type, qualified_name, member, arg_spec=None):
for listener in self.listeners:
notify = getattr(listener, 'parse_' + signature_type)
notify(APISignature(signature_type, qualified_name,
member, arg_spec or
inspect.ArgSpec(None, None, None, None)))
def _collect_paths(self, paths, recurse=True):
inits, mods = [], []
if not paths:
return inits, mods
for py_path in paths:
if is_py_file(py_path):
if path.basename(py_path) == '__init__.py':
inits.append(py_path)
else:
mods.append(py_path)
elif is_py_dir(py_path) and recurse:
c_inits, c_mods = self._collect_paths(
[path.join(py_path, c) for c in os.listdir(py_path)],
recurse=recurse)
inits.extend(c_inits)
mods.extend(c_mods)
return inits, mods
def _load_path(self, module_path):
module_name = path.basename(path.splitext(module_path)[0])
pkg_name = '.'.join(parent_package_names(module_path) or '')
defined_name = ('%s.%s' % (pkg_name, module_name) if pkg_name
else module_name)
if module_name == '__init__':
defined_name = pkg_name
search_paths = [parent_path(module_path)]
f = None
try:
if defined_name in sys.modules:
del sys.modules[defined_name]
f, p, d = imp.find_module(module_name, search_paths)
module = imp.load_module(defined_name, f, p, d)
if defined_name == '__init__':
setattr(module, '__path__', search_paths)
return module
except Exception as e:
sys.stderr.write("Failed to load module '%s' due to: %s" %
(module_path, e))
if self.abort_on_load_failure:
raise e
finally:
if f:
f.close()
def load_modules(self, init_paths, module_paths):
init_mods, mods = [], []
failed_to_load = []
def _load(paths, store):
for m_path in paths:
module = self._load_path(m_path)
if module:
store.append(module)
else:
failed_to_load.append(m_path)
_load(init_paths, init_mods)
_load(module_paths, mods)
return init_mods, mods, failed_to_load
def _fully_qualified_name(self, parent, name):
if inspect.isclass(parent):
prefix = parent.__module__ + '.' + parent.__name__
else:
prefix = parent.__name__
return prefix + '.' + name
def parse_modules(self, modules):
for module in modules:
for member_name, member in inspect.getmembers(
module, _member_filter):
if member_name.startswith('__') and member_name.endswith('__'):
continue
fqn = self._fully_qualified_name(module, member_name)
if inspect.isclass(member):
self._notify(APISignature.SignatureType.CLASS,
fqn, member)
self.parse_modules([member])
elif inspect.isfunction(member):
self._notify(APISignature.SignatureType.FUNCTION,
fqn, member,
arg_spec=inspect.getargspec(member))
elif inspect.ismethod(member):
self._notify(APISignature.SignatureType.METHOD,
fqn, member,
arg_spec=inspect.getargspec(member))
else:
event = (APISignature.SignatureType.MODULE_ATTR
if inspect.ismodule(module)
else APISignature.SignatureType.CLASS_ATTR)
self._notify(event, fqn, member)
def parse_paths(self, py_paths, recurse=True):
init_paths, mod_paths = self._collect_paths(
py_paths, recurse=recurse)
init_mods, pkg_mods, failed_mods = self.load_modules(
init_paths, mod_paths)
self.parse_modules(init_mods)
self.parse_modules(pkg_mods)
class APIReport(object):
def __init__(self, abort_on_load_failure=False):
self._api = {}
self._parser = ModuleParser(
[self], abort_on_load_failure=abort_on_load_failure)
def _add(self, event):
if is_mock_import(event.member):
return
uuid = str(event.qualified_name)
if uuid in self._api:
# TODO(boden): configurable bail on duplicate flag
sys.stderr.write("Duplicate API signature: %s" % uuid)
return
if not filter(blacklist_filter, [uuid]):
return
self._api[uuid] = event.to_dict()
def parse_method(self, event):
self._add(event)
def parse_function(self, event):
self._add(event)
def parse_class(self, event):
self._add(event)
def parse_class_attribute(self, event):
self._add(event)
def parse_module_attribute(self, event):
self._add(event)
def parse_api_paths(self, api_paths, recurse=True):
self._parser.parse_paths(api_paths, recurse=recurse)
def parse_api_path(self, api_path, recurse=True):
self.parse_api_paths([api_path], recurse=recurse)
@property
def api(self):
return dict(self._api)
def to_json(self):
return jsonutils.dumps(self._api)
@staticmethod
def from_json(json_str):
api = APIReport()
api._api = jsonutils.loads(json_str)
return api
@staticmethod
def from_json_file(file_path):
with open(file_path, 'r') as json_file:
data = json_file.read()
return APIReport.from_json(data)
@staticmethod
def api_diff_files(new_api, old_api):
new_api = APIReport.from_json_file(new_api)
old_api = APIReport.from_json_file(old_api)
return new_api.api_diff(old_api)
def get_filtered_signatures(self):
return filter(blacklist_filter, self.get_signatures())
def get_signatures(self):
return sorted([APISignature.get_signature(s)
for s in self._api.values()])
def api_diff(self, other_api):
our_keys = sorted(self.api.keys())
other_keys = sorted(other_api.api.keys())
new_keys = set(our_keys) - set(other_keys)
removed_keys = set(other_keys) - set(our_keys)
common_keys = set(our_keys) & set(other_keys)
common_key_changes = [k for k in common_keys
if ordered(self.api[k]) !=
ordered(other_api.api[k])]
for k in common_key_changes:
if (not blacklist_filter(self.api[k]['member_value']) and not
blacklist_filter(other_api.api[k]['member_value'])):
common_key_changes.remove(k)
def _build_report(new_api):
apis = APIReport()
apis._api = new_api
return apis
return {
'new': _build_report({k: self.api[k] for k in new_keys}),
'removed': _build_report({k: other_api.api[k]
for k in removed_keys}),
'unchanged': _build_report({k: self.api[k]
for k in
set(common_keys) -
set(common_key_changes)}),
'new_changed': _build_report({k: self.api[k]
for k in common_key_changes}),
'old_changed': _build_report({k: other_api.api[k]
for k in common_key_changes})
}
@six.add_metaclass(abc.ABCMeta)
class AbstractCommand(object):
@abc.abstractmethod
def get_parser(self):
pass
@abc.abstractmethod
def run(self, args):
pass
def _add_blacklist_opt(parser):
parser.add_argument(
'--blacklist',
help='One or more regular expressions used to filter out '
'API paths from the report. File path segments, module '
'names, class names, etc. are all subject to filtering. '
'Multiple regexes can be specified using a comma in the '
'--blacklist argument.')
class GenerateReportCommand(AbstractCommand):
PY_LINE_FILTERS = [RemoveDocStrings(),
RemoveCommentLines(),
StripTrailingComments(),
MergeMultiLineImports(),
MergeMultiLineClass(),
MergeMultiLineDef(),
MergeMultiLineDecorator(),
CommentOutDecorators(),
PassEmptyDef(),
MockParentClass(),
AddMockDefinitions(),
MockImports()]
def __init__(self):
self._parser = argparse.ArgumentParser(
prog='generate',
description='Generate an interface report for python '
'source. The paths given can be a python '
'package or project directory, or a single '
'python source file. The program replaces '
'your imports with mocks, so no dependencies '
'are needed in the python env.')
_add_blacklist_opt(self._parser)
self._parser.add_argument(
'--debug',
help='Exit parsing on failure to load a module and '
'leave temp staging dir intact..',
action='store_const',
const=True)
self._parser.add_argument('PATH', nargs='+', metavar='PATH')
def get_parser(self):
return self._parser
def run(self, args):
if args.blacklist:
add_blacklist_from_csv_str(args.blacklist)
files = PyFiles(args.PATH)
with files.tmp_tree(delete_on_exit=(
not args.debug)) as tmp_root:
PyFiles.filter_all_py_files(
tmp_root, GenerateReportCommand.PY_LINE_FILTERS)
report = APIReport(abort_on_load_failure=args.debug)
for child in os.listdir(tmp_root):
child_path = path.join(tmp_root, child)
report.parse_api_paths([child_path])
print("%s" % report.to_json())
class PrintReportCommand(AbstractCommand):
def __init__(self):
self._parser = argparse.ArgumentParser(
prog='print',
description='Given a JSON API file, print the API signatures '
'to STDOUT.')
_add_blacklist_opt(self._parser)
self._parser.add_argument('REPORT_FILE',
help='Path to JSON report file.')
def get_parser(self):
return self._parser
def run(self, args):
if args.blacklist:
add_blacklist_from_csv_str(args.blacklist)
report = APIReport.from_json_file(args.REPORT_FILE)
for signature in report.get_filtered_signatures():
print(signature)
class DiffReportCommand(AbstractCommand):
def __init__(self):
self._parser = argparse.ArgumentParser(
prog='diff',
description='Given a new and old JSON interface report '
'files, calculate the changes between new '
'and old and echo them to STDOUT.')
_add_blacklist_opt(self._parser)
self._parser.add_argument(
'--unchanged',
help='Used with --diff to specify that unchanged '
'public APIs should be reported in addition to '
'new and removed.',
action='store_const',
const=True)
self._parser.add_argument('NEW_REPORT_FILE',
help='Path to new report file.')
self._parser.add_argument('OLD_REPORT_FILE',
help='Path to old report file.')
def get_parser(self):
return self._parser
def _print_row(self, heading, content_list):
print(heading)
print("-----------------------------------------------------")
for content in content_list:
print(str(content))
print("-----------------------------------------------------\n")
def run(self, args):
if args.blacklist:
add_blacklist_from_csv_str(args.blacklist)
api_diff = APIReport.api_diff_files(
args.NEW_REPORT_FILE, args.OLD_REPORT_FILE)
self._print_row("New API Signatures",
api_diff['new'].get_filtered_signatures())
self._print_row("Removed API Signatures",
api_diff['removed'].get_filtered_signatures())
new_sigs = api_diff['new_changed'].get_filtered_signatures()
old_sigs = api_diff['old_changed'].get_filtered_signatures()
if len(new_sigs) != len(old_sigs):
new_sigs = []
old_sigs = []
new_changed = api_diff['new_changed'].api
old_changed = api_diff['old_changed'].api
for n, new_spec in new_changed.items():
display_old = blacklist_filter(old_changed[n]['member_value'])
display_new = blacklist_filter(new_spec['member_value'])
if display_old and display_new:
new_sigs.append(APISignature.get_signature(new_spec))
old_sigs.append(APISignature.get_signature(old_changed[n]))
elif not display_old:
new_sigs.append(APISignature.get_signature(new_spec))
old_sigs.append(n)
else:
new_sigs.append('UNKNOWN')
old_sigs.append(n)
self._print_row("Changed API Signatures",
["%s [is now] %s" %
(old_sigs[i], new_sigs[i])
for i in range(len(new_sigs))])
if args.unchanged:
self._print_row("Unchanged API Signatures",
api_diff['unchanged'].get_filtered_signatures())
class CLI(object):
def __init__(self, commands):
self._commands = {c.get_parser().prog: c for c in commands}
self.parser = argparse.ArgumentParser(
prog='pyir',
description='Python API report tooling.',
usage="pyir <%s> [args]" % "|".join(self._commands.keys()),
add_help=True)
self.parser.add_argument(
'command',
help='The command to run. Known commands: '
'%s . Try \'pyir <command> --help\' for more info '
'on a specific command. ' % ", ".join(self._commands.keys()))
args = self.parser.parse_args(sys.argv[1:2])
if args.command not in self._commands.keys():
print("Unknown command: %s" % args.command)
self.parser.print_help()
exit(1)
cmd = self._commands[args.command]
cmd.get_parser().prog = self.parser.prog + ' ' + cmd.get_parser().prog
cmd.run(cmd.get_parser().parse_args(sys.argv[2:]))
def main():
CLI([DiffReportCommand(), GenerateReportCommand(), PrintReportCommand()])
if __name__ == '__main__':
main()
exit(0)