# Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # Copyright 2012 Hewlett-Packard Development Company, L.P. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. """Matcher classes to be used inside of the testtools assertThat framework.""" import pprint from lxml import etree import six from testtools import content class DictKeysMismatch(object): def __init__(self, d1only, d2only): self.d1only = d1only self.d2only = d2only def describe(self): return ('Keys in d1 and not d2: %(d1only)s.' ' Keys in d2 and not d1: %(d2only)s' % {'d1only': self.d1only, 'd2only': self.d2only}) def get_details(self): return {} class DictMismatch(object): def __init__(self, key, d1_value, d2_value): self.key = key self.d1_value = d1_value self.d2_value = d2_value def describe(self): return ("Dictionaries do not match at %(key)s." " d1: %(d1_value)s d2: %(d2_value)s" % {'key': self.key, 'd1_value': self.d1_value, 'd2_value': self.d2_value}) def get_details(self): return {} class DictMatches(object): def __init__(self, d1, approx_equal=False, tolerance=0.001): self.d1 = d1 self.approx_equal = approx_equal self.tolerance = tolerance def __str__(self): return 'DictMatches(%s)' % (pprint.pformat(self.d1)) # Useful assertions def match(self, d2): """Assert two dicts are equivalent. This is a 'deep' match in the sense that it handles nested dictionaries appropriately. NOTE: If you don't care (or don't know) a given value, you can specify the string DONTCARE as the value. This will cause that dict-item to be skipped. """ d1keys = set(self.d1.keys()) d2keys = set(d2.keys()) if d1keys != d2keys: d1only = sorted(d1keys - d2keys) d2only = sorted(d2keys - d1keys) return DictKeysMismatch(d1only, d2only) for key in d1keys: d1value = self.d1[key] d2value = d2[key] try: error = abs(float(d1value) - float(d2value)) within_tolerance = error <= self.tolerance except (ValueError, TypeError): # If both values aren't convertible to float, just ignore # ValueError if arg is a str, TypeError if it's something else # (like None) within_tolerance = False if hasattr(d1value, 'keys') and hasattr(d2value, 'keys'): matcher = DictMatches(d1value) did_match = matcher.match(d2value) if did_match is not None: return did_match elif 'DONTCARE' in (d1value, d2value): continue elif self.approx_equal and within_tolerance: continue elif d1value != d2value: return DictMismatch(key, d1value, d2value) class ListLengthMismatch(object): def __init__(self, len1, len2): self.len1 = len1 self.len2 = len2 def describe(self): return ('Length mismatch: len(L1)=%(len1)d != ' 'len(L2)=%(len2)d' % {'len1': self.len1, 'len2': self.len2}) def get_details(self): return {} class DictListMatches(object): def __init__(self, l1, approx_equal=False, tolerance=0.001): self.l1 = l1 self.approx_equal = approx_equal self.tolerance = tolerance def __str__(self): return 'DictListMatches(%s)' % (pprint.pformat(self.l1)) # Useful assertions def match(self, l2): """Assert a list of dicts are equivalent.""" l1count = len(self.l1) l2count = len(l2) if l1count != l2count: return ListLengthMismatch(l1count, l2count) for d1, d2 in zip(self.l1, l2): matcher = DictMatches(d2, approx_equal=self.approx_equal, tolerance=self.tolerance) did_match = matcher.match(d1) if did_match: return did_match class SubDictMismatch(object): def __init__(self, key=None, sub_value=None, super_value=None, keys=False): self.key = key self.sub_value = sub_value self.super_value = super_value self.keys = keys def describe(self): if self.keys: return "Keys between dictionaries did not match" else: return("Dictionaries do not match at %s. d1: %s d2: %s" % (self.key, self.super_value, self.sub_value)) def get_details(self): return {} class IsSubDictOf(object): def __init__(self, super_dict): self.super_dict = super_dict def __str__(self): return 'IsSubDictOf(%s)' % (self.super_dict) def match(self, sub_dict): """Assert a sub_dict is subset of super_dict.""" if not set(sub_dict.keys()).issubset(set(self.super_dict.keys())): return SubDictMismatch(keys=True) for k, sub_value in sub_dict.items(): super_value = self.super_dict[k] if isinstance(sub_value, dict): matcher = IsSubDictOf(super_value) did_match = matcher.match(sub_value) if did_match is not None: return did_match elif 'DONTCARE' in (sub_value, super_value): continue else: if sub_value != super_value: return SubDictMismatch(k, sub_value, super_value) class FunctionCallMatcher(object): def __init__(self, expected_func_calls): self.expected_func_calls = expected_func_calls self.actual_func_calls = [] def call(self, *args, **kwargs): func_call = {'args': args, 'kwargs': kwargs} self.actual_func_calls.append(func_call) def match(self): dict_list_matcher = DictListMatches(self.expected_func_calls) return dict_list_matcher.match(self.actual_func_calls) class XMLMismatch(object): """Superclass for XML mismatch.""" def __init__(self, state): self.path = str(state) self.expected = state.expected self.actual = state.actual def describe(self): return "%(path)s: XML does not match" % {'path': self.path} def get_details(self): return { 'expected': content.text_content(self.expected), 'actual': content.text_content(self.actual), } class XMLDocInfoMismatch(XMLMismatch): """XML version or encoding doesn't match.""" def __init__(self, state, expected_doc_info, actual_doc_info): super(XMLDocInfoMismatch, self).__init__(state) self.expected_doc_info = expected_doc_info self.actual_doc_info = actual_doc_info def describe(self): return ("%(path)s: XML information mismatch(version, encoding) " "expected version %(expected_version)s, " "expected encoding %(expected_encoding)s; " "actual version %(actual_version)s, " "actual encoding %(actual_encoding)s" % {'path': self.path, 'expected_version': self.expected_doc_info['version'], 'expected_encoding': self.expected_doc_info['encoding'], 'actual_version': self.actual_doc_info['version'], 'actual_encoding': self.actual_doc_info['encoding']}) class XMLTagMismatch(XMLMismatch): """XML tags don't match.""" def __init__(self, state, idx, expected_tag, actual_tag): super(XMLTagMismatch, self).__init__(state) self.idx = idx self.expected_tag = expected_tag self.actual_tag = actual_tag def describe(self): return ("%(path)s: XML tag mismatch at index %(idx)d: " "expected tag <%(expected_tag)s>; " "actual tag <%(actual_tag)s>" % {'path': self.path, 'idx': self.idx, 'expected_tag': self.expected_tag, 'actual_tag': self.actual_tag}) class XMLAttrKeysMismatch(XMLMismatch): """XML attribute keys don't match.""" def __init__(self, state, expected_only, actual_only): super(XMLAttrKeysMismatch, self).__init__(state) self.expected_only = ', '.join(sorted(expected_only)) self.actual_only = ', '.join(sorted(actual_only)) def describe(self): return ("%(path)s: XML attributes mismatch: " "keys only in expected: %(expected_only)s; " "keys only in actual: %(actual_only)s" % {'path': self.path, 'expected_only': self.expected_only, 'actual_only': self.actual_only}) class XMLAttrValueMismatch(XMLMismatch): """XML attribute values don't match.""" def __init__(self, state, key, expected_value, actual_value): super(XMLAttrValueMismatch, self).__init__(state) self.key = key self.expected_value = expected_value self.actual_value = actual_value def describe(self): return ("%(path)s: XML attribute value mismatch: " "expected value of attribute %(key)s: %(expected_value)r; " "actual value: %(actual_value)r" % {'path': self.path, 'key': self.key, 'expected_value': self.expected_value, 'actual_value': self.actual_value}) class XMLTextValueMismatch(XMLMismatch): """XML text values don't match.""" def __init__(self, state, expected_text, actual_text): super(XMLTextValueMismatch, self).__init__(state) self.expected_text = expected_text self.actual_text = actual_text def describe(self): return ("%(path)s: XML text value mismatch: " "expected text value: %(expected_text)r; " "actual value: %(actual_text)r" % {'path': self.path, 'expected_text': self.expected_text, 'actual_text': self.actual_text}) class XMLUnexpectedChild(XMLMismatch): """Unexpected child present in XML.""" def __init__(self, state, tag, idx): super(XMLUnexpectedChild, self).__init__(state) self.tag = tag self.idx = idx def describe(self): return ("%(path)s: XML unexpected child element <%(tag)s> " "present at index %(idx)d" % {'path': self.path, 'tag': self.tag, 'idx': self.idx}) class XMLExpectedChild(XMLMismatch): """Expected child not present in XML.""" def __init__(self, state, tag, idx): super(XMLExpectedChild, self).__init__(state) self.tag = tag self.idx = idx def describe(self): return ("%(path)s: XML expected child element <%(tag)s> " "not present at index %(idx)d" % {'path': self.path, 'tag': self.tag, 'idx': self.idx}) class XMLMatchState(object): """Maintain some state for matching. Tracks the XML node path and saves the expected and actual full XML text, for use by the XMLMismatch subclasses. """ def __init__(self, expected, actual): self.path = [] self.expected = expected self.actual = actual def __enter__(self): pass def __exit__(self, exc_type, exc_value, exc_tb): self.path.pop() return False def __str__(self): return '/' + '/'.join(self.path) def node(self, tag, idx): """Adds tag and index to the path; they will be popped off when the corresponding 'with' statement exits. :param tag: The element tag :param idx: If not None, the integer index of the element within its parent. Not included in the path element if None. """ if idx is not None: self.path.append("%s[%d]" % (tag, idx)) else: self.path.append(tag) return self class XMLMatches(object): """Compare XML strings. More complete than string comparison.""" SKIP_TAGS = (etree.Comment, etree.ProcessingInstruction) def __init__(self, expected, allow_mixed_nodes=False, skip_empty_text_nodes=True, skip_values=('DONTCARE',)): self.expected_xml = expected self.expected = etree.parse(six.StringIO(expected)) self.allow_mixed_nodes = allow_mixed_nodes self.skip_empty_text_nodes = skip_empty_text_nodes self.skip_values = set(skip_values) def __str__(self): return 'XMLMatches(%r)' % self.expected_xml def match(self, actual_xml): actual = etree.parse(six.StringIO(actual_xml)) state = XMLMatchState(self.expected_xml, actual_xml) expected_doc_info = self._get_xml_docinfo(self.expected) actual_doc_info = self._get_xml_docinfo(actual) if expected_doc_info != actual_doc_info: return XMLDocInfoMismatch(state, expected_doc_info, actual_doc_info) result = self._compare_node(self.expected.getroot(), actual.getroot(), state, None) if result is False: return XMLMismatch(state) elif result is not True: return result @staticmethod def _get_xml_docinfo(xml_document): return {'version': xml_document.docinfo.xml_version, 'encoding': xml_document.docinfo.encoding} def _compare_text_nodes(self, expected, actual, state): expected_text = [expected.text] expected_text.extend(child.tail for child in expected) actual_text = [actual.text] actual_text.extend(child.tail for child in actual) if self.skip_empty_text_nodes: expected_text = [text for text in expected_text if text and not text.isspace()] actual_text = [text for text in actual_text if text and not text.isspace()] if self.skip_values.intersection( expected_text + actual_text): return if self.allow_mixed_nodes: # lets sort text nodes because they can be mixed expected_text = sorted(expected_text) actual_text = sorted(actual_text) if expected_text != actual_text: return XMLTextValueMismatch(state, expected_text, actual_text) def _compare_node(self, expected, actual, state, idx): """Recursively compares nodes within the XML tree.""" # Start by comparing the tags if expected.tag != actual.tag: return XMLTagMismatch(state, idx, expected.tag, actual.tag) with state.node(expected.tag, idx): # Compare the attribute keys expected_attrs = set(expected.attrib.keys()) actual_attrs = set(actual.attrib.keys()) if expected_attrs != actual_attrs: expected_only = expected_attrs - actual_attrs actual_only = actual_attrs - expected_attrs return XMLAttrKeysMismatch(state, expected_only, actual_only) # Compare the attribute values for key in expected_attrs: expected_value = expected.attrib[key] actual_value = actual.attrib[key] if self.skip_values.intersection( [expected_value, actual_value]): continue elif expected_value != actual_value: return XMLAttrValueMismatch(state, key, expected_value, actual_value) # Compare text nodes text_nodes_mismatch = self._compare_text_nodes( expected, actual, state) if text_nodes_mismatch: return text_nodes_mismatch # Compare the contents of the node matched_actual_child_idxs = set() # first_actual_child_idx - pointer to next actual child # used with allow_mixed_nodes=False ONLY # prevent to visit actual child nodes twice first_actual_child_idx = 0 for expected_child in expected: if expected_child.tag in self.SKIP_TAGS: continue related_actual_child_idx = None if self.allow_mixed_nodes: first_actual_child_idx = 0 for actual_child_idx in range( first_actual_child_idx, len(actual)): if actual[actual_child_idx].tag in self.SKIP_TAGS: first_actual_child_idx += 1 continue if actual_child_idx in matched_actual_child_idxs: continue # Compare the nodes result = self._compare_node(expected_child, actual[actual_child_idx], state, actual_child_idx) first_actual_child_idx += 1 if result is not True: if self.allow_mixed_nodes: continue else: return result else: # nodes match related_actual_child_idx = actual_child_idx break if related_actual_child_idx is not None: matched_actual_child_idxs.add(actual_child_idx) else: return XMLExpectedChild(state, expected_child.tag, actual_child_idx + 1) # Make sure we consumed all nodes in actual for actual_child_idx, actual_child in enumerate(actual): if (actual_child.tag not in self.SKIP_TAGS and actual_child_idx not in matched_actual_child_idxs): return XMLUnexpectedChild(state, actual_child.tag, actual_child_idx) # The nodes match return True