262 lines
8.1 KiB
Python
262 lines
8.1 KiB
Python
# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
|
|
|
|
__all__ = [
|
|
'KeysEqual',
|
|
]
|
|
|
|
from ..helpers import (
|
|
dict_subtract,
|
|
filter_values,
|
|
map_values,
|
|
)
|
|
from ._higherorder import (
|
|
AnnotatedMismatch,
|
|
PrefixedMismatch,
|
|
MismatchesAll,
|
|
)
|
|
from ._impl import Matcher, Mismatch
|
|
|
|
|
|
def LabelledMismatches(mismatches, details=None):
    """A collection of mismatches, each labelled.

    :param mismatches: A dict mapping labels to mismatches.  Each mismatch
        is wrapped in a ``PrefixedMismatch`` carrying its label, in sorted
        label order.
    :param details: Accepted so this can be used in the same place as
        ``DictMismatches`` (e.g. as ``_dict_to_mismatch``'s
        ``result_mismatch``); it is ignored.
    :return: An unwrapped ``MismatchesAll`` over the labelled mismatches.
    """
    # Materialize a list rather than handing over a generator.  A generator
    # can be consumed only once, so if MismatchesAll walks its argument on
    # every describe() (as it presumably does), a second description of the
    # same mismatch would otherwise come out silently empty.
    labelled = [
        PrefixedMismatch(label, mismatch)
        for (label, mismatch) in sorted(mismatches.items())]
    return MismatchesAll(labelled, wrap=False)
|
|
|
|
|
|
class MatchesAllDict(Matcher):
    """Matches if every matcher in a labelled dict of matchers matches.

    Much like ``MatchesAll``, except the matchers are supplied as a dict and
    each resulting mismatch is labelled with the key it was stored under.
    """

    def __init__(self, matchers):
        super(MatchesAllDict, self).__init__()
        # Mapping of label -> matcher.
        self.matchers = matchers

    def __str__(self):
        rendered = _format_matcher_dict(self.matchers)
        return 'MatchesAllDict(%s)' % (rendered,)

    def match(self, observed):
        # Run every matcher against the same observed value, keyed by label;
        # _dict_to_mismatch drops the passing (falsy) results.
        results = dict(
            (label, matcher.match(observed))
            for label, matcher in self.matchers.items())
        return _dict_to_mismatch(results, result_mismatch=LabelledMismatches)
|
|
|
|
|
|
class DictMismatches(Mismatch):
    """A mismatch with a dict of child mismatches.

    Describes itself as a ``{`` ... ``}`` block with one ``key: description``
    line per child, ordered by key.
    """

    def __init__(self, mismatches, details=None):
        super(DictMismatches, self).__init__(None, details=details)
        # Mapping of key -> child mismatch.
        self.mismatches = mismatches

    def _sorted_items(self):
        """Return the ``(key, mismatch)`` pairs ordered by key.

        Keys that cannot be compared with one another (e.g. ``1`` and ``'a'``
        on Python 3) are ordered by their ``repr`` instead, so describing the
        mismatch never fails merely because the keys have mixed types.
        """
        items = list(self.mismatches.items())
        try:
            # Dict keys are unique, so sorting on the key alone yields the
            # same order as sorting the (key, mismatch) tuples did, without
            # ever needing to compare mismatch objects.
            items.sort(key=lambda item: item[0])
        except TypeError:
            items.sort(key=lambda item: repr(item[0]))
        return items

    def describe(self):
        lines = ['{']
        lines.extend(
            [' %r: %s,' % (key, mismatch.describe())
             for (key, mismatch) in self._sorted_items()])
        lines.append('}')
        return '\n'.join(lines)
|
|
|
|
|
|
def _dict_to_mismatch(data, to_mismatch=None,
                      result_mismatch=DictMismatches):
    """Collapse a dict of per-key results into one mismatch, or ``None``.

    :param data: The per-key values (presumably a dict).
    :param to_mismatch: Optional callable handed to ``map_values`` together
        with ``data`` (presumably applied to each value) before filtering.
        Skipped when falsy.
    :param result_mismatch: Factory called with the filtered values to build
        the combined mismatch.
    :return: ``result_mismatch(filtered)`` when anything survives
        ``filter_values(bool, ...)``, otherwise ``None``.
    """
    converted = map_values(to_mismatch, data) if to_mismatch else data
    failing = filter_values(bool, converted)
    if not failing:
        return None
    return result_mismatch(failing)
|
|
|
|
|
|
class _MatchCommonKeys(Matcher):
    """Match on keys in a dictionary.

    Given a dictionary whose values are matchers, look only at the keys it
    shares with the matched dictionary, and match if and only if the value
    under every shared key satisfies the corresponding matcher.  Keys present
    on just one side are ignored.

    Thus::

      >>> structure = {'a': Equals('x'), 'b': Equals('y')}
      >>> print(_MatchCommonKeys(structure).match({'a': 'x', 'c': 'z'}))
      None
    """

    def __init__(self, dict_of_matchers):
        super(_MatchCommonKeys, self).__init__()
        self._matchers = dict_of_matchers

    def _compare_dicts(self, expected, observed):
        """Return ``{key: mismatch}`` for the shared keys that fail."""
        shared = set(expected.keys()) & set(observed.keys())
        outcomes = (
            (key, expected[key].match(observed[key])) for key in shared)
        return dict((key, problem) for key, problem in outcomes if problem)

    def match(self, observed):
        failures = self._compare_dicts(self._matchers, observed)
        if not failures:
            return None
        return DictMismatches(failures)
|
|
|
|
|
|
class _SubDictOf(Matcher):
    """Matches if the matched dict only has keys that are in the given dict.

    Each surplus entry is reported as a ``Mismatch`` whose description is
    ``format_value`` applied to that entry's value.
    """

    def __init__(self, super_dict, format_value=repr):
        super(_SubDictOf, self).__init__()
        self.super_dict = super_dict
        self.format_value = format_value

    def match(self, observed):
        def as_mismatch(value):
            return Mismatch(self.format_value(value))

        surplus = dict_subtract(observed, self.super_dict)
        return _dict_to_mismatch(surplus, as_mismatch)
|
|
|
|
|
|
class _SuperDictOf(Matcher):
    """Matches if all of the keys in the given dict are in the matched dict."""

    def __init__(self, sub_dict, format_value=repr):
        super(_SuperDictOf, self).__init__()
        self.sub_dict = sub_dict
        self.format_value = format_value

    def match(self, super_dict):
        # The matched dict covers sub_dict exactly when sub_dict has no keys
        # outside it, so reuse _SubDictOf with the two roles swapped.
        reversed_check = _SubDictOf(super_dict, self.format_value)
        return reversed_check.match(self.sub_dict)
|
|
|
|
|
|
def _format_matcher_dict(matchers):
|
|
return '{%s}' % (
|
|
', '.join(sorted('%r: %s' % (k, v) for k, v in matchers.items())))
|
|
|
|
|
|
class _CombinedMatcher(Matcher):
    """Many matchers labelled and combined into one uber-matcher.

    Subclass this and set ``matcher_factories`` to a dict of label -> factory,
    where each factory takes the single 'expected' value and returns a
    matcher.  The subclass matches only if every matcher built from those
    factories matches.

    Not **entirely** dissimilar from ``MatchesAll``.
    """

    # Overridden by subclasses: label -> callable(expected) -> matcher.
    matcher_factories = {}

    def __init__(self, expected):
        super(_CombinedMatcher, self).__init__()
        self._expected = expected

    def format_expected(self, expected):
        return repr(expected)

    def __str__(self):
        class_name = self.__class__.__name__
        return '%s(%s)' % (class_name, self.format_expected(self._expected))

    def match(self, observed):
        expected = self._expected
        built = {}
        for label, factory in self.matcher_factories.items():
            built[label] = factory(expected)
        return MatchesAllDict(built).match(observed)
|
|
|
|
|
|
class MatchesDict(_CombinedMatcher):
    """Match a dictionary exactly, by its keys.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  Any dictionary that matches it must have exactly
    the same keys, and each value must satisfy the matcher stored under the
    same key in the expected dict.
    """

    matcher_factories = {
        'Extra': _SubDictOf,
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)
|
|
|
|
|
|
class ContainsDict(_CombinedMatcher):
    """Match a dictionary that contains a specified sub-dictionary.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  Any dictionary that matches it must have **at
    least** these keys, and each of those values must satisfy the matcher
    stored under the same key in the expected dict.  Dictionaries with extra
    keys also match.

    In other words, any matching dictionary must contain the dictionary
    given to the constructor.

    Does not check for a strict sub-dictionary: equal dictionaries match.
    """

    matcher_factories = {
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)
|
|
|
|
|
|
class ContainedByDict(_CombinedMatcher):
    """Match a dictionary for which this is a super-dictionary.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  Any dictionary that matches it must have **only**
    these keys, and each of its values must satisfy the matcher stored under
    the same key in the expected dict.  Dictionaries with fewer keys can
    also match.

    In other words, any matching dictionary must be contained by the
    dictionary given to the constructor.

    Does not check for a strict super-dictionary: equal dictionaries match.
    """

    matcher_factories = {
        'Extra': _SubDictOf,
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)
|
|
|
|
|
|
class KeysEqual(Matcher):
    """Checks whether a dict has exactly a particular set of keys."""

    def __init__(self, *expected):
        """Create a `KeysEqual` Matcher.

        :param expected: The keys the matchee is expected to have.  As a
            special case, if exactly one argument is given and it has a
            ``keys()`` method, that mapping's keys are used as the expected
            set instead.
        """
        super(KeysEqual, self).__init__()
        wanted = expected
        if len(expected) == 1:
            try:
                wanted = expected[0].keys()
            except AttributeError:
                # Not a mapping: the lone argument is itself the one key.
                pass
        self.expected = list(wanted)

    def __str__(self):
        return "KeysEqual(%s)" % ', '.join(repr(key) for key in self.expected)

    def match(self, matchee):
        # Imported here rather than at module level, as the original does.
        from ._basic import _BinaryMismatch, Equals
        wanted = sorted(self.expected)
        if not Equals(wanted).match(sorted(matchee.keys())):
            return None
        return AnnotatedMismatch(
            'Keys not equal',
            _BinaryMismatch(wanted, 'does not match', matchee))
|