ec2-api/ec2api/tests/unit/matchers.py

565 lines
18 KiB
Python

# Copyright 2014
# The Cloudscaling Group, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Matcher classes to be used inside of the testtools assertThat framework."""
import copy
import pprint
from lxml import etree
from testtools import content
class DictKeysMismatch(object):
def __init__(self, d1only, d2only):
self.d1only = d1only
self.d2only = d2only
def describe(self):
return ('Keys in d1 and not d2: %(d1only)s.'
' Keys in d2 and not d1: %(d2only)s' %
{'d1only': self.d1only, 'd2only': self.d2only})
def get_details(self):
return {}
class ValueMismatch(object):
def __init__(self, v1, v2):
self.v1 = v1
self.v2 = v2
def describe(self):
return ("Values do not match. v1: %(v1)s v2: %(v2)s" %
{'v1': self.v1, 'v2': self.v2})
def get_details(self):
return {}
class DictMismatch(object):
def __init__(self, key, d1_value, d2_value):
self.key = key
self.d1_value = d1_value
self.d2_value = d2_value
def describe(self):
return ("Dictionaries do not match at %(key)s."
" d1: %(d1_value)s d2: %(d2_value)s" %
{'key': self.key, 'd1_value': self.d1_value,
'd2_value': self.d2_value})
def get_details(self):
return {}
class ValueMatches(object):
def __init__(self, v1, approx_equal=False, tolerance=0.001,
orderless_lists=False):
self.v1 = v1
self.approx_equal = approx_equal
self.tolerance = tolerance
self.orderless_lists = orderless_lists
def __str__(self):
return 'ValueMatches(%s)' % (pprint.pformat(self.v1))
# Useful assertions
def match(self, v2):
"""Assert two values are equivalent.
NOTE:
If you don't care (or don't know) a given value, you can specify
the string DONTCARE as the value. This will cause that item
to be skipped.
"""
try:
error = abs(float(self.v1) - float(v2))
within_tolerance = error <= self.tolerance
except (ValueError, TypeError):
# If both values aren't convertible to float, just ignore
# ValueError if arg is a str, TypeError if it's something else
# (like None)
within_tolerance = False
if hasattr(self.v1, 'keys') and hasattr(v2, 'keys'):
matcher = DictMatches(self.v1, self.approx_equal, self.tolerance,
self.orderless_lists)
did_match = matcher.match(v2)
if did_match is not None:
return did_match
elif isinstance(self.v1, list) and isinstance(v2, list):
matcher = ListMatches(self.v1, self.approx_equal, self.tolerance,
self.orderless_lists)
did_match = matcher.match(v2)
if did_match is not None:
return did_match
elif 'DONTCARE' in (self.v1, v2):
return
elif self.approx_equal and within_tolerance:
return
elif self.v1 != v2:
return ValueMismatch(self.v1, v2)
class DictMatches(object):
def __init__(self, d1, approx_equal=False, tolerance=0.001,
orderless_lists=False):
self.d1 = d1
self.approx_equal = approx_equal
self.tolerance = tolerance
self.orderless_lists = orderless_lists
def __str__(self):
return 'DictMatches(%s)' % (pprint.pformat(self.d1))
# Useful assertions
def match(self, d2):
"""Assert two dicts are equivalent.
This is a 'deep' match in the sense that it handles nested
dictionaries appropriately.
NOTE:
If you don't care (or don't know) a given value, you can specify
the string DONTCARE as the value. This will cause that dict-item
to be skipped.
"""
d1keys = set(self.d1.keys())
d2keys = set(d2.keys())
if d1keys != d2keys:
d1only = d1keys - d2keys
d2only = d2keys - d1keys
return DictKeysMismatch(d1only, d2only)
for key in d1keys:
d1value = self.d1[key]
d2value = d2[key]
matcher = ValueMatches(d1value, self.approx_equal, self.tolerance,
self.orderless_lists)
did_match = matcher.match(d2value)
if did_match is not None:
return did_match
class ListMismatch(object):
def __init__(self, l1, l2):
self.l1 = l1
self.l2 = l2
def describe(self):
return ('Lists mismatch: L1=%(l1)s != L2=%(l2)s' %
{'l1': self.l1, 'l2': self.l2})
def get_details(self):
return {}
class ListLengthMismatch(object):
def __init__(self, l1, l2):
self.l1 = l1
self.l2 = l2
def describe(self):
return ('Lists lengths mismatch: L1=%(l1)s != L2=%(l2)s' %
{'l1': self.l1, 'l2': self.l2})
def get_details(self):
return {}
class ListMatches(object):
def __init__(self, l1, approx_equal=False, tolerance=0.001,
orderless_lists=False):
self.l1 = l1
self.approx_equal = approx_equal
self.tolerance = tolerance
self.orderless_lists = orderless_lists
def __str__(self):
return 'ListMatches(%s)' % (pprint.pformat(self.l1))
# Useful assertions
def match(self, l2):
"""Assert a list of dicts are equivalent."""
l1count = len(self.l1)
l2count = len(l2)
if l1count != l2count:
return ListLengthMismatch(self.l1, l2)
l2 = copy.deepcopy(l2)
for v1 in self.l1:
for v2 in l2:
matcher = ValueMatches(v1, self.approx_equal, self.tolerance,
self.orderless_lists)
did_match = matcher.match(v2)
if did_match is None:
l2.remove(v2)
break
if not self.orderless_lists:
break
if did_match is not None:
return did_match
class SubDictMismatch(object):
def __init__(self,
key=None,
sub_value=None,
super_value=None,
keys=False):
self.key = key
self.sub_value = sub_value
self.super_value = super_value
self.keys = keys
def describe(self):
if self.keys:
return "Keys between dictionaries did not match"
else:
return("Dictionaries do not match at %s. d1: %s d2: %s"
% (self.key,
self.super_value,
self.sub_value))
def get_details(self):
return {}
class IsSubDictOf(object):
def __init__(self, super_dict):
self.super_dict = super_dict
def __str__(self):
return 'IsSubDictOf(%s)' % (self.super_dict)
def match(self, sub_dict):
"""Assert a sub_dict is subset of super_dict."""
if not set(sub_dict.keys()).issubset(set(self.super_dict.keys())):
return SubDictMismatch(keys=True)
for k, sub_value in sub_dict.items():
super_value = self.super_dict[k]
if isinstance(sub_value, dict):
matcher = IsSubDictOf(super_value)
did_match = matcher.match(sub_value)
if did_match is not None:
return did_match
elif 'DONTCARE' in (sub_value, super_value):
continue
else:
if sub_value != super_value:
return SubDictMismatch(k, sub_value, super_value)
class FunctionCallMatcher(object):
def __init__(self, expected_func_calls):
self.expected_func_calls = expected_func_calls
self.actual_func_calls = []
def call(self, *args, **kwargs):
func_call = {'args': args, 'kwargs': kwargs}
self.actual_func_calls.append(func_call)
def match(self):
dict_list_matcher = ListMatches(self.expected_func_calls)
return dict_list_matcher.match(self.actual_func_calls)
class XMLMismatch(object):
"""Superclass for XML mismatch."""
def __init__(self, state):
self.path = str(state)
self.expected = state.expected
self.actual = state.actual
def describe(self):
return "%(path)s: XML does not match" % {'path': self.path}
def get_details(self):
return {
'expected': content.text_content(self.expected),
'actual': content.text_content(self.actual),
}
class XMLTagMismatch(XMLMismatch):
"""XML tags don't match."""
def __init__(self, state, idx, expected_tag, actual_tag):
super(XMLTagMismatch, self).__init__(state)
self.idx = idx
self.expected_tag = expected_tag
self.actual_tag = actual_tag
def describe(self):
return ("%(path)s: XML tag mismatch at index %(idx)d: "
"expected tag <%(expected_tag)s>; "
"actual tag <%(actual_tag)s>" %
{'path': self.path, 'idx': self.idx,
'expected_tag': self.expected_tag,
'actual_tag': self.actual_tag})
class XMLAttrKeysMismatch(XMLMismatch):
"""XML attribute keys don't match."""
def __init__(self, state, expected_only, actual_only):
super(XMLAttrKeysMismatch, self).__init__(state)
self.expected_only = ', '.join(sorted(expected_only))
self.actual_only = ', '.join(sorted(actual_only))
def describe(self):
return ("%(path)s: XML attributes mismatch: "
"keys only in expected: %(expected_only)s; "
"keys only in actual: %(actual_only)s" %
{'path': self.path, 'expected_only': self.expected_only,
'actual_only': self.actual_only})
class XMLAttrValueMismatch(XMLMismatch):
"""XML attribute values don't match."""
def __init__(self, state, key, expected_value, actual_value):
super(XMLAttrValueMismatch, self).__init__(state)
self.key = key
self.expected_value = expected_value
self.actual_value = actual_value
def describe(self):
return ("%(path)s: XML attribute value mismatch: "
"expected value of attribute %(key)s: %(expected_value)r; "
"actual value: %(actual_value)r" %
{'path': self.path, 'key': self.key,
'expected_value': self.expected_value,
'actual_value': self.actual_value})
class XMLTextValueMismatch(XMLMismatch):
"""XML text values don't match."""
def __init__(self, state, expected_text, actual_text):
super(XMLTextValueMismatch, self).__init__(state)
self.expected_text = expected_text
self.actual_text = actual_text
def describe(self):
return ("%(path)s: XML text value mismatch: "
"expected text value: %(expected_text)r; "
"actual value: %(actual_text)r" %
{'path': self.path, 'expected_text': self.expected_text,
'actual_text': self.actual_text})
class XMLUnexpectedChild(XMLMismatch):
"""Unexpected child present in XML."""
def __init__(self, state, tag, idx):
super(XMLUnexpectedChild, self).__init__(state)
self.tag = tag
self.idx = idx
def describe(self):
return ("%(path)s: XML unexpected child element <%(tag)s> "
"present at index %(idx)d" %
{'path': self.path, 'tag': self.tag, 'idx': self.idx})
class XMLExpectedChild(XMLMismatch):
"""Expected child not present in XML."""
def __init__(self, state, tag, idx):
super(XMLExpectedChild, self).__init__(state)
self.tag = tag
self.idx = idx
def describe(self):
return ("%(path)s: XML expected child element <%(tag)s> "
"not present at index %(idx)d" %
{'path': self.path, 'tag': self.tag, 'idx': self.idx})
class XMLMatchState(object):
"""Maintain some state for matching.
Tracks the XML node path and saves the expected and actual full
XML text, for use by the XMLMismatch subclasses.
"""
def __init__(self, expected, actual):
self.path = []
self.expected = expected
self.actual = actual
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_tb):
self.path.pop()
return False
def __str__(self):
return '/' + '/'.join(self.path)
def node(self, tag, idx):
"""Adds tag and index to the path; they will be popped off
when the corresponding 'with' statement exits.
:param tag: The element tag
:param idx: If not None, the integer index of the element
within its parent. Not included in the path
element if None.
"""
if idx is not None:
self.path.append("%s[%d]" % (tag, idx))
else:
self.path.append(tag)
return self
class XMLMatches(object):
"""Compare XML strings. More complete than string comparison."""
def __init__(self, expected, orderless_sequence=False):
self.expected_xml = expected
self.orderless_sequence = orderless_sequence
self.expected = etree.fromstring(expected)
if self.orderless_sequence:
self._sort_xml(self.expected)
def __str__(self):
return 'XMLMatches(%r)' % self.expected_xml
def _sort_xml(self, xml):
for parent in xml.xpath('//*[./*]'):
parent[:] = sorted(parent, key=lambda x: x.tag)
def match(self, actual_xml):
actual = etree.fromstring(actual_xml)
if self.orderless_sequence:
self._sort_xml(actual)
state = XMLMatchState(self.expected_xml, actual_xml)
result = self._compare_node(self.expected, actual, state, None)
if result is False:
return XMLMismatch(state)
elif result is not True:
return result
def _compare_node(self, expected, actual, state, idx):
"""Recursively compares nodes within the XML tree."""
# Start by comparing the tags
if expected.tag != actual.tag:
return XMLTagMismatch(state, idx, expected.tag, actual.tag)
with state.node(expected.tag, idx):
# Compare the attribute keys
expected_attrs = set(expected.attrib.keys())
actual_attrs = set(actual.attrib.keys())
if expected_attrs != actual_attrs:
expected_only = expected_attrs - actual_attrs
actual_only = actual_attrs - expected_attrs
return XMLAttrKeysMismatch(state, expected_only, actual_only)
# Compare the attribute values
for key in expected_attrs:
expected_value = expected.attrib[key]
actual_value = actual.attrib[key]
if 'DONTCARE' in (expected_value, actual_value):
continue
elif expected_value != actual_value:
return XMLAttrValueMismatch(state, key, expected_value,
actual_value)
# Compare the contents of the node
if len(expected) == 0 and len(actual) == 0:
# No children, compare text values
if ('DONTCARE' not in (expected.text, actual.text) and
expected.text != actual.text):
return XMLTextValueMismatch(state, expected.text,
actual.text)
else:
expected_idx = 0
actual_idx = 0
while (expected_idx < len(expected) and
actual_idx < len(actual)):
# Ignore comments and processing instructions
# TODO(Vek): may interpret PIs in the future, to
# allow for, say, arbitrary ordering of some
# elements
if (expected[expected_idx].tag in
(etree.Comment, etree.ProcessingInstruction)):
expected_idx += 1
continue
# Compare the nodes
result = self._compare_node(expected[expected_idx],
actual[actual_idx], state,
actual_idx)
if result is not True:
return result
# Step on to comparing the next nodes...
expected_idx += 1
actual_idx += 1
# Make sure we consumed all nodes in actual
if actual_idx < len(actual):
return XMLUnexpectedChild(state, actual[actual_idx].tag,
actual_idx)
# Make sure we consumed all nodes in expected
if expected_idx < len(expected):
for node in expected[expected_idx:]:
if (node.tag in
(etree.Comment, etree.ProcessingInstruction)):
continue
return XMLExpectedChild(state, node.tag, actual_idx)
# The nodes match
return True