2171 lines
69 KiB
Python
2171 lines
69 KiB
Python
# Copyright 2008 Google Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# This is a fork of the pymox library intended to work with Python 3.
|
|
# The file was modified by quermit@gmail.com and dawid.fatyga@gmail.com
|
|
|
|
"""Mox, an object-mocking framework for Python.
|
|
|
|
Mox works in the record-replay-verify paradigm. When you first create
|
|
a mock object, it is in record mode. You then programmatically set
|
|
the expected behavior of the mock object (what methods are to be
|
|
called on it, with what parameters, what they should return, and in
|
|
what order).
|
|
|
|
Once you have set up the expected mock behavior, you put it in replay
|
|
mode. Now the mock responds to method calls just as you told it to.
|
|
If an unexpected method (or an expected method with unexpected
|
|
parameters) is called, then an exception will be raised.
|
|
|
|
Once you are done interacting with the mock, you need to verify that
|
|
all the expected interactions occured. (Maybe your code exited
|
|
prematurely without calling some cleanup method!) The verify phase
|
|
ensures that every expected method was called; otherwise, an exception
|
|
will be raised.
|
|
|
|
WARNING! Mock objects created by Mox are not thread-safe. If you are
|
|
call a mock in multiple threads, it should be guarded by a mutex.
|
|
|
|
TODO(stevepm): Add the option to make mocks thread-safe!
|
|
|
|
Suggested usage / workflow:
|
|
|
|
# Create Mox factory
|
|
my_mox = Mox()
|
|
|
|
# Create a mock data access object
|
|
mock_dao = my_mox.CreateMock(DAOClass)
|
|
|
|
# Set up expected behavior
|
|
mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
|
|
mock_dao.DeletePerson(person)
|
|
|
|
# Put mocks in replay mode
|
|
my_mox.ReplayAll()
|
|
|
|
# Inject mock object and run test
|
|
controller.SetDao(mock_dao)
|
|
controller.DeletePersonById('1')
|
|
|
|
# Verify all methods were called as expected
|
|
my_mox.VerifyAll()
|
|
"""
|
|
|
|
import collections
|
|
import difflib
|
|
import inspect
|
|
import re
|
|
import types
|
|
import unittest
|
|
|
|
from mox3 import stubout
|
|
|
|
|
|
class Error(AssertionError):
|
|
"""Base exception for this module."""
|
|
|
|
pass
|
|
|
|
|
|
class ExpectedMethodCallsError(Error):
|
|
"""Raised when an expected method wasn't called.
|
|
|
|
This can occur if Verify() is called before all expected methods have been
|
|
called.
|
|
"""
|
|
|
|
def __init__(self, expected_methods):
|
|
"""Init exception.
|
|
|
|
Args:
|
|
# expected_methods: A sequence of MockMethod objects that should
|
|
# have been called.
|
|
expected_methods: [MockMethod]
|
|
|
|
Raises:
|
|
ValueError: if expected_methods contains no methods.
|
|
"""
|
|
|
|
if not expected_methods:
|
|
raise ValueError("There must be at least one expected method")
|
|
Error.__init__(self)
|
|
self._expected_methods = expected_methods
|
|
|
|
def __str__(self):
|
|
calls = "\n".join(["%3d. %s" % (i, m)
|
|
for i, m in enumerate(self._expected_methods)])
|
|
return "Verify: Expected methods never called:\n%s" % (calls,)
|
|
|
|
|
|
class UnexpectedMethodCallError(Error):
|
|
"""Raised when an unexpected method is called.
|
|
|
|
This can occur if a method is called with incorrect parameters, or out of
|
|
the specified order.
|
|
"""
|
|
|
|
def __init__(self, unexpected_method, expected):
|
|
"""Init exception.
|
|
|
|
Args:
|
|
# unexpected_method: MockMethod that was called but was not at the
|
|
# head of the expected_method queue.
|
|
# expected: MockMethod or UnorderedGroup the method should have
|
|
# been in.
|
|
unexpected_method: MockMethod
|
|
expected: MockMethod or UnorderedGroup
|
|
"""
|
|
|
|
Error.__init__(self)
|
|
if expected is None:
|
|
self._str = "Unexpected method call %s" % (unexpected_method,)
|
|
else:
|
|
differ = difflib.Differ()
|
|
diff = differ.compare(str(unexpected_method).splitlines(True),
|
|
str(expected).splitlines(True))
|
|
self._str = ("Unexpected method call."
|
|
" unexpected:- expected:+\n%s"
|
|
% ("\n".join(line.rstrip() for line in diff),))
|
|
|
|
def __str__(self):
|
|
return self._str
|
|
|
|
|
|
class UnknownMethodCallError(Error):
|
|
"""Raised if an unknown method is requested of the mock object."""
|
|
|
|
def __init__(self, unknown_method_name):
|
|
"""Init exception.
|
|
|
|
Args:
|
|
# unknown_method_name: Method call that is not part of the mocked
|
|
# class's public interface.
|
|
unknown_method_name: str
|
|
"""
|
|
|
|
Error.__init__(self)
|
|
self._unknown_method_name = unknown_method_name
|
|
|
|
def __str__(self):
|
|
return ("Method called is not a member of the object: %s" %
|
|
self._unknown_method_name)
|
|
|
|
|
|
class PrivateAttributeError(Error):
|
|
"""Raised if a MockObject is passed a private additional attribute name."""
|
|
|
|
def __init__(self, attr):
|
|
Error.__init__(self)
|
|
self._attr = attr
|
|
|
|
def __str__(self):
|
|
return ("Attribute '%s' is private and should not be available"
|
|
"in a mock object." % self._attr)
|
|
|
|
|
|
class ExpectedMockCreationError(Error):
|
|
"""Raised if mocks should have been created by StubOutClassWithMocks."""
|
|
|
|
def __init__(self, expected_mocks):
|
|
"""Init exception.
|
|
|
|
Args:
|
|
# expected_mocks: A sequence of MockObjects that should have been
|
|
# created
|
|
|
|
Raises:
|
|
ValueError: if expected_mocks contains no methods.
|
|
"""
|
|
|
|
if not expected_mocks:
|
|
raise ValueError("There must be at least one expected method")
|
|
Error.__init__(self)
|
|
self._expected_mocks = expected_mocks
|
|
|
|
def __str__(self):
|
|
mocks = "\n".join(["%3d. %s" % (i, m)
|
|
for i, m in enumerate(self._expected_mocks)])
|
|
return "Verify: Expected mocks never created:\n%s" % (mocks,)
|
|
|
|
|
|
class UnexpectedMockCreationError(Error):
|
|
"""Raised if too many mocks were created by StubOutClassWithMocks."""
|
|
|
|
def __init__(self, instance, *params, **named_params):
|
|
"""Init exception.
|
|
|
|
Args:
|
|
# instance: the type of obejct that was created
|
|
# params: parameters given during instantiation
|
|
# named_params: named parameters given during instantiation
|
|
"""
|
|
|
|
Error.__init__(self)
|
|
self._instance = instance
|
|
self._params = params
|
|
self._named_params = named_params
|
|
|
|
def __str__(self):
|
|
args = ", ".join(["%s" % v for i, v in enumerate(self._params)])
|
|
error = "Unexpected mock creation: %s(%s" % (self._instance, args)
|
|
|
|
if self._named_params:
|
|
error += ", " + ", ".join(["%s=%s" % (k, v) for k, v in
|
|
self._named_params.items()])
|
|
|
|
error += ")"
|
|
return error
|
|
|
|
|
|
class Mox(object):
|
|
"""Mox: a factory for creating mock objects."""
|
|
|
|
# A list of types that should be stubbed out with MockObjects (as
|
|
# opposed to MockAnythings).
|
|
_USE_MOCK_OBJECT = [types.FunctionType, types.ModuleType, types.MethodType]
|
|
|
|
def __init__(self):
|
|
"""Initialize a new Mox."""
|
|
|
|
self._mock_objects = []
|
|
self.stubs = stubout.StubOutForTesting()
|
|
|
|
def CreateMock(self, class_to_mock, attrs=None, bounded_to=None):
|
|
"""Create a new mock object.
|
|
|
|
Args:
|
|
# class_to_mock: the class to be mocked
|
|
class_to_mock: class
|
|
attrs: dict of attribute names to values that will be
|
|
set on the mock object. Only public attributes may be set.
|
|
bounded_to: optionally, when class_to_mock is not a class,
|
|
it points to a real class object, to which
|
|
attribute is bound
|
|
|
|
Returns:
|
|
MockObject that can be used as the class_to_mock would be.
|
|
"""
|
|
if attrs is None:
|
|
attrs = {}
|
|
new_mock = MockObject(class_to_mock, attrs=attrs,
|
|
class_to_bind=bounded_to)
|
|
self._mock_objects.append(new_mock)
|
|
return new_mock
|
|
|
|
def CreateMockAnything(self, description=None):
|
|
"""Create a mock that will accept any method calls.
|
|
|
|
This does not enforce an interface.
|
|
|
|
Args:
|
|
description: str. Optionally, a descriptive name for the mock object
|
|
being created, for debugging output purposes.
|
|
"""
|
|
new_mock = MockAnything(description=description)
|
|
self._mock_objects.append(new_mock)
|
|
return new_mock
|
|
|
|
def ReplayAll(self):
|
|
"""Set all mock objects to replay mode."""
|
|
|
|
for mock_obj in self._mock_objects:
|
|
mock_obj._Replay()
|
|
|
|
def VerifyAll(self):
|
|
"""Call verify on all mock objects created."""
|
|
|
|
for mock_obj in self._mock_objects:
|
|
mock_obj._Verify()
|
|
|
|
def ResetAll(self):
|
|
"""Call reset on all mock objects. This does not unset stubs."""
|
|
|
|
for mock_obj in self._mock_objects:
|
|
mock_obj._Reset()
|
|
|
|
def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
|
|
"""Replace a method, attribute, etc. with a Mock.
|
|
|
|
This will replace a class or module with a MockObject, and everything
|
|
else (method, function, etc) with a MockAnything. This can be
|
|
overridden to always use a MockAnything by setting use_mock_anything
|
|
to True.
|
|
|
|
Args:
|
|
obj: A Python object (class, module, instance, callable).
|
|
attr_name: str. The name of the attribute to replace with a mock.
|
|
use_mock_anything: bool. True if a MockAnything should be used
|
|
regardless of the type of attribute.
|
|
"""
|
|
|
|
if inspect.isclass(obj):
|
|
class_to_bind = obj
|
|
else:
|
|
class_to_bind = None
|
|
|
|
attr_to_replace = getattr(obj, attr_name)
|
|
attr_type = type(attr_to_replace)
|
|
|
|
if attr_type == MockAnything or attr_type == MockObject:
|
|
raise TypeError('Cannot mock a MockAnything! Did you remember to '
|
|
'call UnsetStubs in your previous test?')
|
|
|
|
type_check = (
|
|
attr_type in self._USE_MOCK_OBJECT or
|
|
inspect.isclass(attr_to_replace) or
|
|
isinstance(attr_to_replace, object))
|
|
if type_check and not use_mock_anything:
|
|
stub = self.CreateMock(attr_to_replace, bounded_to=class_to_bind)
|
|
else:
|
|
stub = self.CreateMockAnything(
|
|
description='Stub for %s' % attr_to_replace)
|
|
stub.__name__ = attr_name
|
|
|
|
self.stubs.Set(obj, attr_name, stub)
|
|
|
|
def StubOutClassWithMocks(self, obj, attr_name):
|
|
"""Replace a class with a "mock factory" that will create mock objects.
|
|
|
|
This is useful if the code-under-test directly instantiates
|
|
dependencies. Previously some boilder plate was necessary to
|
|
create a mock that would act as a factory. Using
|
|
StubOutClassWithMocks, once you've stubbed out the class you may
|
|
use the stubbed class as you would any other mock created by mox:
|
|
during the record phase, new mock instances will be created, and
|
|
during replay, the recorded mocks will be returned.
|
|
|
|
In replay mode
|
|
|
|
# Example using StubOutWithMock (the old, clunky way):
|
|
|
|
mock1 = mox.CreateMock(my_import.FooClass)
|
|
mock2 = mox.CreateMock(my_import.FooClass)
|
|
foo_factory = mox.StubOutWithMock(my_import, 'FooClass',
|
|
use_mock_anything=True)
|
|
foo_factory(1, 2).AndReturn(mock1)
|
|
foo_factory(9, 10).AndReturn(mock2)
|
|
mox.ReplayAll()
|
|
|
|
my_import.FooClass(1, 2) # Returns mock1 again.
|
|
my_import.FooClass(9, 10) # Returns mock2 again.
|
|
mox.VerifyAll()
|
|
|
|
# Example using StubOutClassWithMocks:
|
|
|
|
mox.StubOutClassWithMocks(my_import, 'FooClass')
|
|
mock1 = my_import.FooClass(1, 2) # Returns a new mock of FooClass
|
|
mock2 = my_import.FooClass(9, 10) # Returns another mock instance
|
|
mox.ReplayAll()
|
|
|
|
my_import.FooClass(1, 2) # Returns mock1 again.
|
|
my_import.FooClass(9, 10) # Returns mock2 again.
|
|
mox.VerifyAll()
|
|
"""
|
|
attr_to_replace = getattr(obj, attr_name)
|
|
attr_type = type(attr_to_replace)
|
|
|
|
if attr_type == MockAnything or attr_type == MockObject:
|
|
raise TypeError('Cannot mock a MockAnything! Did you remember to '
|
|
'call UnsetStubs in your previous test?')
|
|
|
|
if not inspect.isclass(attr_to_replace):
|
|
raise TypeError('Given attr is not a Class. Use StubOutWithMock.')
|
|
|
|
factory = _MockObjectFactory(attr_to_replace, self)
|
|
self._mock_objects.append(factory)
|
|
self.stubs.Set(obj, attr_name, factory)
|
|
|
|
def UnsetStubs(self):
|
|
"""Restore stubs to their original state."""
|
|
|
|
self.stubs.UnsetAll()
|
|
|
|
|
|
def Replay(*args):
|
|
"""Put mocks into Replay mode.
|
|
|
|
Args:
|
|
# args is any number of mocks to put into replay mode.
|
|
"""
|
|
|
|
for mock in args:
|
|
mock._Replay()
|
|
|
|
|
|
def Verify(*args):
|
|
"""Verify mocks.
|
|
|
|
Args:
|
|
# args is any number of mocks to be verified.
|
|
"""
|
|
|
|
for mock in args:
|
|
mock._Verify()
|
|
|
|
|
|
def Reset(*args):
|
|
"""Reset mocks.
|
|
|
|
Args:
|
|
# args is any number of mocks to be reset.
|
|
"""
|
|
|
|
for mock in args:
|
|
mock._Reset()
|
|
|
|
|
|
class MockAnything(object):
|
|
"""A mock that can be used to mock anything.
|
|
|
|
This is helpful for mocking classes that do not provide a public interface.
|
|
"""
|
|
|
|
def __init__(self, description=None):
|
|
"""Initialize a new MockAnything.
|
|
|
|
Args:
|
|
description: str. Optionally, a descriptive name for the mock
|
|
object being created, for debugging output purposes.
|
|
"""
|
|
self._description = description
|
|
self._Reset()
|
|
|
|
def __repr__(self):
|
|
if self._description:
|
|
return '<MockAnything instance of %s>' % self._description
|
|
else:
|
|
return '<MockAnything instance>'
|
|
|
|
def __getattr__(self, method_name):
|
|
"""Intercept method calls on this object.
|
|
|
|
A new MockMethod is returned that is aware of the MockAnything's
|
|
state (record or replay). The call will be recorded or replayed
|
|
by the MockMethod's __call__.
|
|
|
|
Args:
|
|
# method name: the name of the method being called.
|
|
method_name: str
|
|
|
|
Returns:
|
|
A new MockMethod aware of MockAnything's state (record or replay).
|
|
"""
|
|
if method_name == '__dir__':
|
|
return self.__class__.__dir__.__get__(self, self.__class__)
|
|
|
|
return self._CreateMockMethod(method_name)
|
|
|
|
def __str__(self):
|
|
return self._CreateMockMethod('__str__')()
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
return self._CreateMockMethod('__call__')(*args, **kwargs)
|
|
|
|
def __getitem__(self, i):
|
|
return self._CreateMockMethod('__getitem__')(i)
|
|
|
|
def _CreateMockMethod(self, method_name, method_to_mock=None,
|
|
class_to_bind=object):
|
|
"""Create a new mock method call and return it.
|
|
|
|
Args:
|
|
# method_name: the name of the method being called.
|
|
# method_to_mock: The actual method being mocked, used for
|
|
# introspection.
|
|
# class_to_bind: Class to which method is bounded
|
|
# (object by default)
|
|
method_name: str
|
|
method_to_mock: a method object
|
|
|
|
Returns:
|
|
A new MockMethod aware of MockAnything's state (record or replay).
|
|
"""
|
|
|
|
return MockMethod(method_name, self._expected_calls_queue,
|
|
self._replay_mode, method_to_mock=method_to_mock,
|
|
description=self._description,
|
|
class_to_bind=class_to_bind)
|
|
|
|
def __nonzero__(self):
|
|
"""Return 1 for nonzero so the mock can be used as a conditional."""
|
|
|
|
return 1
|
|
|
|
def __bool__(self):
|
|
"""Return True for nonzero so the mock can be used as a conditional."""
|
|
return True
|
|
|
|
def __eq__(self, rhs):
|
|
"""Provide custom logic to compare objects."""
|
|
|
|
return (isinstance(rhs, MockAnything) and
|
|
self._replay_mode == rhs._replay_mode and
|
|
self._expected_calls_queue == rhs._expected_calls_queue)
|
|
|
|
def __ne__(self, rhs):
|
|
"""Provide custom logic to compare objects."""
|
|
|
|
return not self == rhs
|
|
|
|
def _Replay(self):
|
|
"""Start replaying expected method calls."""
|
|
|
|
self._replay_mode = True
|
|
|
|
def _Verify(self):
|
|
"""Verify that all of the expected calls have been made.
|
|
|
|
Raises:
|
|
ExpectedMethodCallsError: if there are still more method calls in
|
|
the expected queue.
|
|
"""
|
|
|
|
# If the list of expected calls is not empty, raise an exception
|
|
if self._expected_calls_queue:
|
|
# The last MultipleTimesGroup is not popped from the queue.
|
|
if (len(self._expected_calls_queue) == 1 and
|
|
isinstance(self._expected_calls_queue[0],
|
|
MultipleTimesGroup) and
|
|
self._expected_calls_queue[0].IsSatisfied()):
|
|
pass
|
|
else:
|
|
raise ExpectedMethodCallsError(self._expected_calls_queue)
|
|
|
|
def _Reset(self):
|
|
"""Reset the state of this mock to record mode with an empty queue."""
|
|
|
|
# Maintain a list of method calls we are expecting
|
|
self._expected_calls_queue = collections.deque()
|
|
|
|
# Make sure we are in setup mode, not replay mode
|
|
self._replay_mode = False
|
|
|
|
|
|
class MockObject(MockAnything):
|
|
"""Mock object that simulates the public/protected interface of a class."""
|
|
|
|
def __init__(self, class_to_mock, attrs=None, class_to_bind=None):
|
|
"""Initialize a mock object.
|
|
|
|
Determines the methods and properties of the class and stores them.
|
|
|
|
Args:
|
|
# class_to_mock: class to be mocked
|
|
class_to_mock: class
|
|
attrs: dict of attribute names to values that will be set on the
|
|
mock object. Only public attributes may be set.
|
|
class_to_bind: optionally, when class_to_mock is not a class at
|
|
all, it points to a real class
|
|
|
|
Raises:
|
|
PrivateAttributeError: if a supplied attribute is not public.
|
|
ValueError: if an attribute would mask an existing method.
|
|
"""
|
|
if attrs is None:
|
|
attrs = {}
|
|
|
|
# Used to hack around the mixin/inheritance of MockAnything, which
|
|
# is not a proper object (it can be anything. :-)
|
|
MockAnything.__dict__['__init__'](self)
|
|
|
|
# Get a list of all the public and special methods we should mock.
|
|
self._known_methods = set()
|
|
self._known_vars = set()
|
|
self._class_to_mock = class_to_mock
|
|
|
|
if inspect.isclass(class_to_mock):
|
|
self._class_to_bind = self._class_to_mock
|
|
else:
|
|
self._class_to_bind = class_to_bind
|
|
|
|
try:
|
|
if inspect.isclass(self._class_to_mock):
|
|
self._description = class_to_mock.__name__
|
|
else:
|
|
self._description = type(class_to_mock).__name__
|
|
except Exception:
|
|
pass
|
|
|
|
for method in dir(class_to_mock):
|
|
attr = getattr(class_to_mock, method)
|
|
if callable(attr):
|
|
self._known_methods.add(method)
|
|
elif not (type(attr) is property):
|
|
# treating properties as class vars makes little sense.
|
|
self._known_vars.add(method)
|
|
|
|
# Set additional attributes at instantiation time; this is quicker
|
|
# than manually setting attributes that are normally created in
|
|
# __init__.
|
|
for attr, value in attrs.items():
|
|
if attr.startswith("_"):
|
|
raise PrivateAttributeError(attr)
|
|
elif attr in self._known_methods:
|
|
raise ValueError("'%s' is a method of '%s' objects." % (attr,
|
|
class_to_mock))
|
|
else:
|
|
setattr(self, attr, value)
|
|
|
|
def _CreateMockMethod(self, *args, **kwargs):
|
|
"""Overridden to provide self._class_to_mock to class_to_bind."""
|
|
kwargs.setdefault("class_to_bind", self._class_to_bind)
|
|
return super(MockObject, self)._CreateMockMethod(*args, **kwargs)
|
|
|
|
def __getattr__(self, name):
|
|
"""Intercept attribute request on this object.
|
|
|
|
If the attribute is a public class variable, it will be returned and
|
|
not recorded as a call.
|
|
|
|
If the attribute is not a variable, it is handled like a method
|
|
call. The method name is checked against the set of mockable
|
|
methods, and a new MockMethod is returned that is aware of the
|
|
MockObject's state (record or replay). The call will be recorded
|
|
or replayed by the MockMethod's __call__.
|
|
|
|
Args:
|
|
# name: the name of the attribute being requested.
|
|
name: str
|
|
|
|
Returns:
|
|
Either a class variable or a new MockMethod that is aware of the
|
|
state of the mock (record or replay).
|
|
|
|
Raises:
|
|
UnknownMethodCallError if the MockObject does not mock the
|
|
requested method.
|
|
"""
|
|
|
|
if name in self._known_vars:
|
|
return getattr(self._class_to_mock, name)
|
|
|
|
if name in self._known_methods:
|
|
return self._CreateMockMethod(
|
|
name,
|
|
method_to_mock=getattr(self._class_to_mock, name))
|
|
|
|
raise UnknownMethodCallError(name)
|
|
|
|
def __eq__(self, rhs):
|
|
"""Provide custom logic to compare objects."""
|
|
|
|
return (isinstance(rhs, MockObject) and
|
|
self._class_to_mock == rhs._class_to_mock and
|
|
self._replay_mode == rhs._replay_mode and
|
|
self._expected_calls_queue == rhs._expected_calls_queue)
|
|
|
|
def __setitem__(self, key, value):
|
|
"""Custom logic for mocking classes that support item assignment.
|
|
|
|
Args:
|
|
key: Key to set the value for.
|
|
value: Value to set.
|
|
|
|
Returns:
|
|
Expected return value in replay mode. A MockMethod object for the
|
|
__setitem__ method that has already been called if not in replay
|
|
mode.
|
|
|
|
Raises:
|
|
TypeError if the underlying class does not support item assignment.
|
|
UnexpectedMethodCallError if the object does not expect the call to
|
|
__setitem__.
|
|
|
|
"""
|
|
# Verify the class supports item assignment.
|
|
if '__setitem__' not in dir(self._class_to_mock):
|
|
raise TypeError('object does not support item assignment')
|
|
|
|
# If we are in replay mode then simply call the mock __setitem__ method
|
|
if self._replay_mode:
|
|
return MockMethod('__setitem__', self._expected_calls_queue,
|
|
self._replay_mode)(key, value)
|
|
|
|
# Otherwise, create a mock method __setitem__.
|
|
return self._CreateMockMethod('__setitem__')(key, value)
|
|
|
|
def __getitem__(self, key):
|
|
"""Provide custom logic for mocking classes that are subscriptable.
|
|
|
|
Args:
|
|
key: Key to return the value for.
|
|
|
|
Returns:
|
|
Expected return value in replay mode. A MockMethod object for the
|
|
__getitem__ method that has already been called if not in replay
|
|
mode.
|
|
|
|
Raises:
|
|
TypeError if the underlying class is not subscriptable.
|
|
UnexpectedMethodCallError if the object does not expect the call to
|
|
__getitem__.
|
|
|
|
"""
|
|
# Verify the class supports item assignment.
|
|
if '__getitem__' not in dir(self._class_to_mock):
|
|
raise TypeError('unsubscriptable object')
|
|
|
|
# If we are in replay mode then simply call the mock __getitem__ method
|
|
if self._replay_mode:
|
|
return MockMethod('__getitem__', self._expected_calls_queue,
|
|
self._replay_mode)(key)
|
|
|
|
# Otherwise, create a mock method __getitem__.
|
|
return self._CreateMockMethod('__getitem__')(key)
|
|
|
|
def __iter__(self):
|
|
"""Provide custom logic for mocking classes that are iterable.
|
|
|
|
Returns:
|
|
Expected return value in replay mode. A MockMethod object for the
|
|
__iter__ method that has already been called if not in replay mode.
|
|
|
|
Raises:
|
|
TypeError if the underlying class is not iterable.
|
|
UnexpectedMethodCallError if the object does not expect the call to
|
|
__iter__.
|
|
|
|
"""
|
|
methods = dir(self._class_to_mock)
|
|
|
|
# Verify the class supports iteration.
|
|
if '__iter__' not in methods:
|
|
# If it doesn't have iter method and we are in replay method,
|
|
# then try to iterate using subscripts.
|
|
if '__getitem__' not in methods or not self._replay_mode:
|
|
raise TypeError('not iterable object')
|
|
else:
|
|
results = []
|
|
index = 0
|
|
try:
|
|
while True:
|
|
results.append(self[index])
|
|
index += 1
|
|
except IndexError:
|
|
return iter(results)
|
|
|
|
# If we are in replay mode then simply call the mock __iter__ method.
|
|
if self._replay_mode:
|
|
return MockMethod('__iter__', self._expected_calls_queue,
|
|
self._replay_mode)()
|
|
|
|
# Otherwise, create a mock method __iter__.
|
|
return self._CreateMockMethod('__iter__')()
|
|
|
|
def __contains__(self, key):
|
|
"""Provide custom logic for mocking classes that contain items.
|
|
|
|
Args:
|
|
key: Key to look in container for.
|
|
|
|
Returns:
|
|
Expected return value in replay mode. A MockMethod object for the
|
|
__contains__ method that has already been called if not in replay
|
|
mode.
|
|
|
|
Raises:
|
|
TypeError if the underlying class does not implement __contains__
|
|
UnexpectedMethodCaller if the object does not expect the call to
|
|
__contains__.
|
|
|
|
"""
|
|
contains = self._class_to_mock.__dict__.get('__contains__', None)
|
|
|
|
if contains is None:
|
|
raise TypeError('unsubscriptable object')
|
|
|
|
if self._replay_mode:
|
|
return MockMethod('__contains__', self._expected_calls_queue,
|
|
self._replay_mode)(key)
|
|
|
|
return self._CreateMockMethod('__contains__')(key)
|
|
|
|
def __call__(self, *params, **named_params):
|
|
"""Provide custom logic for mocking classes that are callable."""
|
|
|
|
# Verify the class we are mocking is callable.
|
|
is_callable = hasattr(self._class_to_mock, '__call__')
|
|
if not is_callable:
|
|
raise TypeError('Not callable')
|
|
|
|
# Because the call is happening directly on this object instead of
|
|
# a method, the call on the mock method is made right here
|
|
|
|
# If we are mocking a Function, then use the function, and not the
|
|
# __call__ method
|
|
method = None
|
|
if type(self._class_to_mock) in (types.FunctionType, types.MethodType):
|
|
method = self._class_to_mock
|
|
else:
|
|
method = getattr(self._class_to_mock, '__call__')
|
|
mock_method = self._CreateMockMethod('__call__', method_to_mock=method)
|
|
|
|
return mock_method(*params, **named_params)
|
|
|
|
@property
|
|
def __name__(self):
|
|
"""Return the name that is being mocked."""
|
|
return self._description
|
|
|
|
# TODO(dejw): this property stopped to work after I introduced changes with
|
|
# binding classes. Fortunately I found a solution in the form of
|
|
# __getattribute__ method below, but this issue should be investigated
|
|
@property
|
|
def __class__(self):
|
|
return self._class_to_mock
|
|
|
|
def __dir__(self):
|
|
"""Return only attributes of a class to mock."""
|
|
return dir(self._class_to_mock)
|
|
|
|
def __getattribute__(self, name):
|
|
"""Return _class_to_mock on __class__ attribute."""
|
|
if name == "__class__":
|
|
return super(MockObject, self).__getattribute__("_class_to_mock")
|
|
|
|
return super(MockObject, self).__getattribute__(name)
|
|
|
|
|
|
class _MockObjectFactory(MockObject):
|
|
"""A MockObjectFactory creates mocks and verifies __init__ params.
|
|
|
|
A MockObjectFactory removes the boiler plate code that was previously
|
|
necessary to stub out direction instantiation of a class.
|
|
|
|
The MockObjectFactory creates new MockObjects when called and verifies the
|
|
__init__ params are correct when in record mode. When replaying,
|
|
existing mocks are returned, and the __init__ params are verified.
|
|
|
|
See StubOutWithMock vs StubOutClassWithMocks for more detail.
|
|
"""
|
|
|
|
def __init__(self, class_to_mock, mox_instance):
|
|
MockObject.__init__(self, class_to_mock)
|
|
self._mox = mox_instance
|
|
self._instance_queue = collections.deque()
|
|
|
|
def __call__(self, *params, **named_params):
|
|
"""Instantiate and record that a new mock has been created."""
|
|
|
|
method = getattr(self._class_to_mock, '__init__')
|
|
mock_method = self._CreateMockMethod('__init__', method_to_mock=method)
|
|
# Note: calling mock_method() is deferred in order to catch the
|
|
# empty instance_queue first.
|
|
|
|
if self._replay_mode:
|
|
if not self._instance_queue:
|
|
raise UnexpectedMockCreationError(self._class_to_mock, *params,
|
|
**named_params)
|
|
|
|
mock_method(*params, **named_params)
|
|
|
|
return self._instance_queue.pop()
|
|
else:
|
|
mock_method(*params, **named_params)
|
|
|
|
instance = self._mox.CreateMock(self._class_to_mock)
|
|
self._instance_queue.appendleft(instance)
|
|
return instance
|
|
|
|
def _Verify(self):
|
|
"""Verify that all mocks have been created."""
|
|
if self._instance_queue:
|
|
raise ExpectedMockCreationError(self._instance_queue)
|
|
super(_MockObjectFactory, self)._Verify()
|
|
|
|
|
|
class MethodSignatureChecker(object):
|
|
"""Ensures that methods are called correctly."""
|
|
|
|
_NEEDED, _DEFAULT, _GIVEN = range(3)
|
|
|
|
def __init__(self, method, class_to_bind=None):
|
|
"""Creates a checker.
|
|
|
|
Args:
|
|
# method: A method to check.
|
|
# class_to_bind: optionally, a class used to type check first
|
|
# method parameter, only used with unbound methods
|
|
method: function
|
|
class_to_bind: type or None
|
|
|
|
Raises:
|
|
ValueError: method could not be inspected, so checks aren't
|
|
possible. Some methods and functions like built-ins
|
|
can't be inspected.
|
|
"""
|
|
try:
|
|
self._args, varargs, varkw, defaults = inspect.getargspec(method)
|
|
except TypeError:
|
|
raise ValueError('Could not get argument specification for %r'
|
|
% (method,))
|
|
if (inspect.ismethod(method) or class_to_bind or (
|
|
hasattr(self, '_args') and len(self._args) > 0 and
|
|
self._args[0] == 'self')):
|
|
self._args = self._args[1:] # Skip 'self'.
|
|
self._method = method
|
|
self._instance = None # May contain the instance this is bound to.
|
|
self._instance = getattr(method, "__self__", None)
|
|
|
|
# _bounded_to determines whether the method is bound or not
|
|
if self._instance:
|
|
self._bounded_to = self._instance.__class__
|
|
else:
|
|
self._bounded_to = class_to_bind or getattr(method, "im_class",
|
|
None)
|
|
|
|
self._has_varargs = varargs is not None
|
|
self._has_varkw = varkw is not None
|
|
if defaults is None:
|
|
self._required_args = self._args
|
|
self._default_args = []
|
|
else:
|
|
self._required_args = self._args[:-len(defaults)]
|
|
self._default_args = self._args[-len(defaults):]
|
|
|
|
def _RecordArgumentGiven(self, arg_name, arg_status):
|
|
"""Mark an argument as being given.
|
|
|
|
Args:
|
|
# arg_name: The name of the argument to mark in arg_status.
|
|
# arg_status: Maps argument names to one of
|
|
# _NEEDED, _DEFAULT, _GIVEN.
|
|
arg_name: string
|
|
arg_status: dict
|
|
|
|
Raises:
|
|
AttributeError: arg_name is already marked as _GIVEN.
|
|
"""
|
|
if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN:
|
|
raise AttributeError('%s provided more than once' % (arg_name,))
|
|
arg_status[arg_name] = MethodSignatureChecker._GIVEN
|
|
|
|
def Check(self, params, named_params):
|
|
"""Ensures that the parameters used while recording a call are valid.
|
|
|
|
Args:
|
|
# params: A list of positional parameters.
|
|
# named_params: A dict of named parameters.
|
|
params: list
|
|
named_params: dict
|
|
|
|
Raises:
|
|
AttributeError: the given parameters don't work with the given
|
|
method.
|
|
"""
|
|
arg_status = dict((a, MethodSignatureChecker._NEEDED)
|
|
for a in self._required_args)
|
|
for arg in self._default_args:
|
|
arg_status[arg] = MethodSignatureChecker._DEFAULT
|
|
|
|
# WARNING: Suspect hack ahead.
|
|
#
|
|
# Check to see if this is an unbound method, where the instance
|
|
# should be bound as the first argument. We try to determine if
|
|
# the first argument (param[0]) is an instance of the class, or it
|
|
# is equivalent to the class (used to account for Comparators).
|
|
#
|
|
# NOTE: If a Func() comparator is used, and the signature is not
|
|
# correct, this will cause extra executions of the function.
|
|
if inspect.ismethod(self._method) or self._bounded_to:
|
|
# The extra param accounts for the bound instance.
|
|
if len(params) > len(self._required_args):
|
|
expected = self._bounded_to
|
|
|
|
# Check if the param is an instance of the expected class,
|
|
# or check equality (useful for checking Comparators).
|
|
|
|
# This is a hack to work around the fact that the first
|
|
# parameter can be a Comparator, and the comparison may raise
|
|
# an exception during this comparison, which is OK.
|
|
try:
|
|
param_equality = (params[0] == expected)
|
|
except Exception:
|
|
param_equality = False
|
|
|
|
if isinstance(params[0], expected) or param_equality:
|
|
params = params[1:]
|
|
# If the IsA() comparator is being used, we need to check the
|
|
# inverse of the usual case - that the given instance is a
|
|
# subclass of the expected class. For example, the code under
|
|
# test does late binding to a subclass.
|
|
elif (isinstance(params[0], IsA) and
|
|
params[0]._IsSubClass(expected)):
|
|
params = params[1:]
|
|
|
|
# Check that each positional param is valid.
|
|
for i in range(len(params)):
|
|
try:
|
|
arg_name = self._args[i]
|
|
except IndexError:
|
|
if not self._has_varargs:
|
|
raise AttributeError(
|
|
'%s does not take %d or more positional '
|
|
'arguments' % (self._method.__name__, i))
|
|
else:
|
|
self._RecordArgumentGiven(arg_name, arg_status)
|
|
|
|
# Check each keyword argument.
|
|
for arg_name in named_params:
|
|
if arg_name not in arg_status and not self._has_varkw:
|
|
raise AttributeError('%s is not expecting keyword argument %s'
|
|
% (self._method.__name__, arg_name))
|
|
self._RecordArgumentGiven(arg_name, arg_status)
|
|
|
|
# Ensure all the required arguments have been given.
|
|
still_needed = [k for k, v in arg_status.items()
|
|
if v == MethodSignatureChecker._NEEDED]
|
|
if still_needed:
|
|
raise AttributeError('No values given for arguments: %s'
|
|
% (' '.join(sorted(still_needed))))
|
|
|
|
|
|
class MockMethod(object):
|
|
"""Callable mock method.
|
|
|
|
A MockMethod should act exactly like the method it mocks, accepting
|
|
parameters and returning a value, or throwing an exception (as specified).
|
|
When this method is called, it can optionally verify whether the called
|
|
method (name and signature) matches the expected method.
|
|
"""
|
|
|
|
def __init__(self, method_name, call_queue, replay_mode,
|
|
method_to_mock=None, description=None, class_to_bind=None):
|
|
"""Construct a new mock method.
|
|
|
|
Args:
|
|
# method_name: the name of the method
|
|
# call_queue: deque of calls, verify this call against the head,
|
|
# or add this call to the queue.
|
|
# replay_mode: False if we are recording, True if we are verifying
|
|
# calls against the call queue.
|
|
# method_to_mock: The actual method being mocked, used for
|
|
# introspection.
|
|
# description: optionally, a descriptive name for this method.
|
|
# Typically this is equal to the descriptive name of
|
|
# the method's class.
|
|
# class_to_bind: optionally, a class that is used for unbound
|
|
# methods (or functions in Python3) to which method
|
|
# is bound, in order not to loose binding
|
|
# information. If given, it will be used for
|
|
# checking the type of first method parameter
|
|
method_name: str
|
|
call_queue: list or deque
|
|
replay_mode: bool
|
|
method_to_mock: a method object
|
|
description: str or None
|
|
class_to_bind: type or None
|
|
"""
|
|
|
|
self._name = method_name
|
|
self.__name__ = method_name
|
|
self._call_queue = call_queue
|
|
if not isinstance(call_queue, collections.deque):
|
|
self._call_queue = collections.deque(self._call_queue)
|
|
self._replay_mode = replay_mode
|
|
self._description = description
|
|
|
|
self._params = None
|
|
self._named_params = None
|
|
self._return_value = None
|
|
self._exception = None
|
|
self._side_effects = None
|
|
|
|
try:
|
|
self._checker = MethodSignatureChecker(method_to_mock,
|
|
class_to_bind=class_to_bind)
|
|
except ValueError:
|
|
self._checker = None
|
|
|
|
def __call__(self, *params, **named_params):
|
|
"""Log parameters and return the specified return value.
|
|
|
|
If the Mock(Anything/Object) associated with this call is in record
|
|
mode, this MockMethod will be pushed onto the expected call queue.
|
|
If the mock is in replay mode, this will pop a MockMethod off the
|
|
top of the queue and verify this call is equal to the expected call.
|
|
|
|
Raises:
|
|
UnexpectedMethodCall if this call is supposed to match an expected
|
|
method call and it does not.
|
|
"""
|
|
|
|
self._params = params
|
|
self._named_params = named_params
|
|
|
|
if not self._replay_mode:
|
|
if self._checker is not None:
|
|
self._checker.Check(params, named_params)
|
|
self._call_queue.append(self)
|
|
return self
|
|
|
|
expected_method = self._VerifyMethodCall()
|
|
|
|
if expected_method._side_effects:
|
|
result = expected_method._side_effects(*params, **named_params)
|
|
if expected_method._return_value is None:
|
|
expected_method._return_value = result
|
|
|
|
if expected_method._exception:
|
|
raise expected_method._exception
|
|
|
|
return expected_method._return_value
|
|
|
|
def __getattr__(self, name):
|
|
"""Raise an AttributeError with a helpful message."""
|
|
|
|
raise AttributeError(
|
|
'MockMethod has no attribute "%s". '
|
|
'Did you remember to put your mocks in replay mode?' % name)
|
|
|
|
def __iter__(self):
|
|
"""Raise a TypeError with a helpful message."""
|
|
raise TypeError(
|
|
'MockMethod cannot be iterated. '
|
|
'Did you remember to put your mocks in replay mode?')
|
|
|
|
def next(self):
|
|
"""Raise a TypeError with a helpful message."""
|
|
raise TypeError(
|
|
'MockMethod cannot be iterated. '
|
|
'Did you remember to put your mocks in replay mode?')
|
|
|
|
def __next__(self):
|
|
"""Raise a TypeError with a helpful message."""
|
|
raise TypeError(
|
|
'MockMethod cannot be iterated. '
|
|
'Did you remember to put your mocks in replay mode?')
|
|
|
|
def _PopNextMethod(self):
|
|
"""Pop the next method from our call queue."""
|
|
try:
|
|
return self._call_queue.popleft()
|
|
except IndexError:
|
|
raise UnexpectedMethodCallError(self, None)
|
|
|
|
def _VerifyMethodCall(self):
|
|
"""Verify the called method is expected.
|
|
|
|
This can be an ordered method, or part of an unordered set.
|
|
|
|
Returns:
|
|
The expected mock method.
|
|
|
|
Raises:
|
|
UnexpectedMethodCall if the method called was not expected.
|
|
"""
|
|
|
|
expected = self._PopNextMethod()
|
|
|
|
# Loop here, because we might have a MethodGroup followed by another
|
|
# group.
|
|
while isinstance(expected, MethodGroup):
|
|
expected, method = expected.MethodCalled(self)
|
|
if method is not None:
|
|
return method
|
|
|
|
# This is a mock method, so just check equality.
|
|
if expected != self:
|
|
raise UnexpectedMethodCallError(self, expected)
|
|
|
|
return expected
|
|
|
|
def __str__(self):
|
|
params = ', '.join(
|
|
[repr(p) for p in self._params or []] +
|
|
['%s=%r' % x for x in sorted((self._named_params or {}).items())])
|
|
full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
|
|
if self._description:
|
|
full_desc = "%s.%s" % (self._description, full_desc)
|
|
return full_desc
|
|
|
|
def __hash__(self):
|
|
return id(self)
|
|
|
|
def __eq__(self, rhs):
|
|
"""Test whether this MockMethod is equivalent to another MockMethod.
|
|
|
|
Args:
|
|
# rhs: the right hand side of the test
|
|
rhs: MockMethod
|
|
"""
|
|
|
|
return (isinstance(rhs, MockMethod) and
|
|
self._name == rhs._name and
|
|
self._params == rhs._params and
|
|
self._named_params == rhs._named_params)
|
|
|
|
def __ne__(self, rhs):
|
|
"""Test if this MockMethod is not equivalent to another MockMethod.
|
|
|
|
Args:
|
|
# rhs: the right hand side of the test
|
|
rhs: MockMethod
|
|
"""
|
|
|
|
return not self == rhs
|
|
|
|
def GetPossibleGroup(self):
|
|
"""Returns a possible group from the end of the call queue.
|
|
|
|
Return None if no other methods are on the stack.
|
|
"""
|
|
|
|
# Remove this method from the tail of the queue so we can add it
|
|
# to a group.
|
|
this_method = self._call_queue.pop()
|
|
assert this_method == self
|
|
|
|
# Determine if the tail of the queue is a group, or just a regular
|
|
# ordered mock method.
|
|
group = None
|
|
try:
|
|
group = self._call_queue[-1]
|
|
except IndexError:
|
|
pass
|
|
|
|
return group
|
|
|
|
def _CheckAndCreateNewGroup(self, group_name, group_class):
|
|
"""Checks if the last method (a possible group) is an instance of our
|
|
group_class. Adds the current method to this group or creates a
|
|
new one.
|
|
|
|
Args:
|
|
|
|
group_name: the name of the group.
|
|
group_class: the class used to create instance of this new group
|
|
"""
|
|
group = self.GetPossibleGroup()
|
|
|
|
# If this is a group, and it is the correct group, add the method.
|
|
if isinstance(group, group_class) and group.group_name() == group_name:
|
|
group.AddMethod(self)
|
|
return self
|
|
|
|
# Create a new group and add the method.
|
|
new_group = group_class(group_name)
|
|
new_group.AddMethod(self)
|
|
self._call_queue.append(new_group)
|
|
return self
|
|
|
|
def InAnyOrder(self, group_name="default"):
|
|
"""Move this method into a group of unordered calls.
|
|
|
|
A group of unordered calls must be defined together, and must be
|
|
executed in full before the next expected method can be called.
|
|
There can be multiple groups that are expected serially, if they are
|
|
given different group names. The same group name can be reused if there
|
|
is a standard method call, or a group with a different name, spliced
|
|
between usages.
|
|
|
|
Args:
|
|
group_name: the name of the unordered group.
|
|
|
|
Returns:
|
|
self
|
|
"""
|
|
return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
|
|
|
|
def MultipleTimes(self, group_name="default"):
|
|
"""Move method into group of calls which may be called multiple times.
|
|
|
|
A group of repeating calls must be defined together, and must be
|
|
executed in full before the next expected method can be called.
|
|
|
|
Args:
|
|
group_name: the name of the unordered group.
|
|
|
|
Returns:
|
|
self
|
|
"""
|
|
return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
|
|
|
|
def AndReturn(self, return_value):
|
|
"""Set the value to return when this method is called.
|
|
|
|
Args:
|
|
# return_value can be anything.
|
|
"""
|
|
|
|
self._return_value = return_value
|
|
return return_value
|
|
|
|
def AndRaise(self, exception):
|
|
"""Set the exception to raise when this method is called.
|
|
|
|
Args:
|
|
# exception: the exception to raise when this method is called.
|
|
exception: Exception
|
|
"""
|
|
|
|
self._exception = exception
|
|
|
|
def WithSideEffects(self, side_effects):
|
|
"""Set the side effects that are simulated when this method is called.
|
|
|
|
Args:
|
|
side_effects: A callable which modifies the parameters or other
|
|
relevant state which a given test case depends on.
|
|
|
|
Returns:
|
|
Self for chaining with AndReturn and AndRaise.
|
|
"""
|
|
self._side_effects = side_effects
|
|
return self
|
|
|
|
|
|
class Comparator:
|
|
"""Base class for all Mox comparators.
|
|
|
|
A Comparator can be used as a parameter to a mocked method when the exact
|
|
value is not known. For example, the code you are testing might build up
|
|
a long SQL string that is passed to your mock DAO. You're only interested
|
|
that the IN clause contains the proper primary keys, so you can set your
|
|
mock up as follows:
|
|
|
|
mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
|
|
|
|
Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
|
|
|
|
A Comparator may replace one or more parameters, for example:
|
|
# return at most 10 rows
|
|
mock_dao.RunQuery(StrContains('SELECT'), 10)
|
|
|
|
or
|
|
|
|
# Return some non-deterministic number of rows
|
|
mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
|
|
"""
|
|
|
|
def equals(self, rhs):
|
|
"""Special equals method that all comparators must implement.
|
|
|
|
Args:
|
|
rhs: any python object
|
|
"""
|
|
|
|
raise NotImplementedError('method must be implemented by a subclass.')
|
|
|
|
def __eq__(self, rhs):
|
|
return self.equals(rhs)
|
|
|
|
def __ne__(self, rhs):
|
|
return not self.equals(rhs)
|
|
|
|
|
|
class Is(Comparator):
|
|
"""Comparison class used to check identity, instead of equality."""
|
|
|
|
def __init__(self, obj):
|
|
self._obj = obj
|
|
|
|
def equals(self, rhs):
|
|
return rhs is self._obj
|
|
|
|
def __repr__(self):
|
|
return "<is %r (%s)>" % (self._obj, id(self._obj))
|
|
|
|
|
|
class IsA(Comparator):
|
|
"""This class wraps a basic Python type or class. It is used to verify
|
|
that a parameter is of the given type or class.
|
|
|
|
Example:
|
|
mock_dao.Connect(IsA(DbConnectInfo))
|
|
"""
|
|
|
|
def __init__(self, class_name):
|
|
"""Initialize IsA
|
|
|
|
Args:
|
|
class_name: basic python type or a class
|
|
"""
|
|
|
|
self._class_name = class_name
|
|
|
|
def equals(self, rhs):
|
|
"""Check to see if the RHS is an instance of class_name.
|
|
|
|
Args:
|
|
# rhs: the right hand side of the test
|
|
rhs: object
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return isinstance(rhs, self._class_name)
|
|
except TypeError:
|
|
# Check raw types if there was a type error. This is helpful for
|
|
# things like cStringIO.StringIO.
|
|
return type(rhs) == type(self._class_name)
|
|
|
|
def _IsSubClass(self, clazz):
|
|
"""Check to see if the IsA comparators class is a subclass of clazz.
|
|
|
|
Args:
|
|
# clazz: a class object
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return issubclass(self._class_name, clazz)
|
|
except TypeError:
|
|
# Check raw types if there was a type error. This is helpful for
|
|
# things like cStringIO.StringIO.
|
|
return type(clazz) == type(self._class_name)
|
|
|
|
def __repr__(self):
|
|
return 'mox.IsA(%s) ' % str(self._class_name)
|
|
|
|
|
|
class IsAlmost(Comparator):
|
|
"""Comparison class used to check whether a parameter is nearly equal
|
|
to a given value. Generally useful for floating point numbers.
|
|
|
|
Example mock_dao.SetTimeout((IsAlmost(3.9)))
|
|
"""
|
|
|
|
def __init__(self, float_value, places=7):
|
|
"""Initialize IsAlmost.
|
|
|
|
Args:
|
|
float_value: The value for making the comparison.
|
|
places: The number of decimal places to round to.
|
|
"""
|
|
|
|
self._float_value = float_value
|
|
self._places = places
|
|
|
|
def equals(self, rhs):
|
|
"""Check to see if RHS is almost equal to float_value
|
|
|
|
Args:
|
|
rhs: the value to compare to float_value
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return round(rhs - self._float_value, self._places) == 0
|
|
except Exception:
|
|
# Probably because either float_value or rhs is not a number.
|
|
return False
|
|
|
|
def __repr__(self):
|
|
return str(self._float_value)
|
|
|
|
|
|
class StrContains(Comparator):
|
|
"""Comparison class used to check whether a substring exists in a
|
|
string parameter. This can be useful in mocking a database with SQL
|
|
passed in as a string parameter, for example.
|
|
|
|
Example:
|
|
mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
|
|
"""
|
|
|
|
def __init__(self, search_string):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
# search_string: the string you are searching for
|
|
search_string: str
|
|
"""
|
|
|
|
self._search_string = search_string
|
|
|
|
def equals(self, rhs):
|
|
"""Check to see if the search_string is contained in the rhs string.
|
|
|
|
Args:
|
|
# rhs: the right hand side of the test
|
|
rhs: object
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return rhs.find(self._search_string) > -1
|
|
except Exception:
|
|
return False
|
|
|
|
def __repr__(self):
|
|
return '<str containing \'%s\'>' % self._search_string
|
|
|
|
|
|
class Regex(Comparator):
|
|
"""Checks if a string matches a regular expression.
|
|
|
|
This uses a given regular expression to determine equality.
|
|
"""
|
|
|
|
def __init__(self, pattern, flags=0):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
# pattern is the regular expression to search for
|
|
pattern: str
|
|
# flags passed to re.compile function as the second argument
|
|
flags: int
|
|
"""
|
|
self.flags = flags
|
|
self.regex = re.compile(pattern, flags=flags)
|
|
|
|
def equals(self, rhs):
|
|
"""Check to see if rhs matches regular expression pattern.
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return self.regex.search(rhs) is not None
|
|
except Exception:
|
|
return False
|
|
|
|
def __repr__(self):
|
|
s = '<regular expression \'%s\'' % self.regex.pattern
|
|
if self.flags:
|
|
s += ', flags=%d' % self.flags
|
|
s += '>'
|
|
return s
|
|
|
|
|
|
class In(Comparator):
|
|
"""Checks whether an item (or key) is in a list (or dict) parameter.
|
|
|
|
Example:
|
|
mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
|
|
"""
|
|
|
|
def __init__(self, key):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
# key is any thing that could be in a list or a key in a dict
|
|
"""
|
|
|
|
self._key = key
|
|
|
|
def equals(self, rhs):
|
|
"""Check to see whether key is in rhs.
|
|
|
|
Args:
|
|
rhs: dict
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return self._key in rhs
|
|
except Exception:
|
|
return False
|
|
|
|
def __repr__(self):
|
|
return '<sequence or map containing \'%s\'>' % str(self._key)
|
|
|
|
|
|
class Not(Comparator):
|
|
"""Checks whether a predicates is False.
|
|
|
|
Example:
|
|
mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm',
|
|
stevepm_user_info)))
|
|
"""
|
|
|
|
def __init__(self, predicate):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
# predicate: a Comparator instance.
|
|
"""
|
|
|
|
assert isinstance(predicate, Comparator), ("predicate %r must be a"
|
|
" Comparator." % predicate)
|
|
self._predicate = predicate
|
|
|
|
def equals(self, rhs):
|
|
"""Check to see whether the predicate is False.
|
|
|
|
Args:
|
|
rhs: A value that will be given in argument of the predicate.
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return not self._predicate.equals(rhs)
|
|
except Exception:
|
|
return False
|
|
|
|
def __repr__(self):
|
|
return '<not \'%s\'>' % self._predicate
|
|
|
|
|
|
class ContainsKeyValue(Comparator):
|
|
"""Checks whether a key/value pair is in a dict parameter.
|
|
|
|
Example:
|
|
mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
|
|
"""
|
|
|
|
def __init__(self, key, value):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
# key: a key in a dict
|
|
# value: the corresponding value
|
|
"""
|
|
|
|
self._key = key
|
|
self._value = value
|
|
|
|
def equals(self, rhs):
|
|
"""Check whether the given key/value pair is in the rhs dict.
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return rhs[self._key] == self._value
|
|
except Exception:
|
|
return False
|
|
|
|
def __repr__(self):
|
|
return '<map containing the entry \'%s: %s\'>' % (str(self._key),
|
|
str(self._value))
|
|
|
|
|
|
class ContainsAttributeValue(Comparator):
|
|
"""Checks whether passed parameter contains attributes with a given value.
|
|
|
|
Example:
|
|
mock_dao.UpdateSomething(ContainsAttribute('stevepm', stevepm_user_info))
|
|
"""
|
|
|
|
def __init__(self, key, value):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
# key: an attribute name of an object
|
|
# value: the corresponding value
|
|
"""
|
|
|
|
self._key = key
|
|
self._value = value
|
|
|
|
def equals(self, rhs):
|
|
"""Check if the given attribute has a matching value in the rhs object.
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
try:
|
|
return getattr(rhs, self._key) == self._value
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
class SameElementsAs(Comparator):
|
|
"""Checks whether sequences contain the same elements (ignoring order).
|
|
|
|
Example:
|
|
mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
|
|
"""
|
|
|
|
def __init__(self, expected_seq):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
expected_seq: a sequence
|
|
"""
|
|
# Store in case expected_seq is an iterator.
|
|
self._expected_list = list(expected_seq)
|
|
|
|
def equals(self, actual_seq):
|
|
"""Check to see whether actual_seq has same elements as expected_seq.
|
|
|
|
Args:
|
|
actual_seq: sequence
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
try:
|
|
# Store in case actual_seq is an iterator. We potentially iterate
|
|
# twice: once to make the dict, once in the list fallback.
|
|
actual_list = list(actual_seq)
|
|
except TypeError:
|
|
# actual_seq cannot be read as a sequence.
|
|
#
|
|
# This happens because Mox uses __eq__ both to check object
|
|
# equality (in MethodSignatureChecker) and to invoke Comparators.
|
|
return False
|
|
|
|
try:
|
|
return set(self._expected_list) == set(actual_list)
|
|
except TypeError:
|
|
# Fall back to slower list-compare if any of the objects
|
|
# are unhashable.
|
|
if len(self._expected_list) != len(actual_list):
|
|
return False
|
|
for el in actual_list:
|
|
if el not in self._expected_list:
|
|
return False
|
|
return True
|
|
|
|
def __repr__(self):
|
|
return '<sequence with same elements as \'%s\'>' % self._expected_list
|
|
|
|
|
|
class And(Comparator):
|
|
"""Evaluates one or more Comparators on RHS, returns an AND of the results.
|
|
"""
|
|
|
|
def __init__(self, *args):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
*args: One or more Comparator
|
|
"""
|
|
|
|
self._comparators = args
|
|
|
|
def equals(self, rhs):
|
|
"""Checks whether all Comparators are equal to rhs.
|
|
|
|
Args:
|
|
# rhs: can be anything
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
for comparator in self._comparators:
|
|
if not comparator.equals(rhs):
|
|
return False
|
|
|
|
return True
|
|
|
|
def __repr__(self):
|
|
return '<AND %s>' % str(self._comparators)
|
|
|
|
|
|
class Or(Comparator):
|
|
"""Evaluates one or more Comparators on RHS; returns OR of the results."""
|
|
|
|
def __init__(self, *args):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
*args: One or more Mox comparators
|
|
"""
|
|
|
|
self._comparators = args
|
|
|
|
def equals(self, rhs):
|
|
"""Checks whether any Comparator is equal to rhs.
|
|
|
|
Args:
|
|
# rhs: can be anything
|
|
|
|
Returns:
|
|
bool
|
|
"""
|
|
|
|
for comparator in self._comparators:
|
|
if comparator.equals(rhs):
|
|
return True
|
|
|
|
return False
|
|
|
|
def __repr__(self):
|
|
return '<OR %s>' % str(self._comparators)
|
|
|
|
|
|
class Func(Comparator):
|
|
"""Call a function that should verify the parameter passed in is correct.
|
|
|
|
You may need the ability to perform more advanced operations on the
|
|
parameter in order to validate it. You can use this to have a callable
|
|
validate any parameter. The callable should return either True or False.
|
|
|
|
|
|
Example:
|
|
|
|
def myParamValidator(param):
|
|
# Advanced logic here
|
|
return True
|
|
|
|
mock_dao.DoSomething(Func(myParamValidator), true)
|
|
"""
|
|
|
|
def __init__(self, func):
|
|
"""Initialize.
|
|
|
|
Args:
|
|
func: callable that takes one parameter and returns a bool
|
|
"""
|
|
|
|
self._func = func
|
|
|
|
def equals(self, rhs):
|
|
"""Test whether rhs passes the function test.
|
|
|
|
rhs is passed into func.
|
|
|
|
Args:
|
|
rhs: any python object
|
|
|
|
Returns:
|
|
the result of func(rhs)
|
|
"""
|
|
|
|
return self._func(rhs)
|
|
|
|
def __repr__(self):
|
|
return str(self._func)
|
|
|
|
|
|
class IgnoreArg(Comparator):
|
|
"""Ignore an argument.
|
|
|
|
This can be used when we don't care about an argument of a method call.
|
|
|
|
Example:
|
|
# Check if CastMagic is called with 3 as first arg and
|
|
# 'disappear' as third.
|
|
mymock.CastMagic(3, IgnoreArg(), 'disappear')
|
|
"""
|
|
|
|
def equals(self, unused_rhs):
|
|
"""Ignores arguments and returns True.
|
|
|
|
Args:
|
|
unused_rhs: any python object
|
|
|
|
Returns:
|
|
always returns True
|
|
"""
|
|
|
|
return True
|
|
|
|
def __repr__(self):
|
|
return '<IgnoreArg>'
|
|
|
|
|
|
class Value(Comparator):
|
|
"""Compares argument against a remembered value.
|
|
|
|
To be used in conjunction with Remember comparator. See Remember()
|
|
for example.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._value = None
|
|
self._has_value = False
|
|
|
|
def store_value(self, rhs):
|
|
self._value = rhs
|
|
self._has_value = True
|
|
|
|
def equals(self, rhs):
|
|
if not self._has_value:
|
|
return False
|
|
else:
|
|
return rhs == self._value
|
|
|
|
def __repr__(self):
|
|
if self._has_value:
|
|
return "<Value %r>" % self._value
|
|
else:
|
|
return "<Value>"
|
|
|
|
|
|
class Remember(Comparator):
|
|
"""Remembers the argument to a value store.
|
|
|
|
To be used in conjunction with Value comparator.
|
|
|
|
Example:
|
|
# Remember the argument for one method call.
|
|
users_list = Value()
|
|
mock_dao.ProcessUsers(Remember(users_list))
|
|
|
|
# Check argument against remembered value.
|
|
mock_dao.ReportUsers(users_list)
|
|
"""
|
|
|
|
def __init__(self, value_store):
|
|
if not isinstance(value_store, Value):
|
|
raise TypeError(
|
|
"value_store is not an instance of the Value class")
|
|
self._value_store = value_store
|
|
|
|
def equals(self, rhs):
|
|
self._value_store.store_value(rhs)
|
|
return True
|
|
|
|
def __repr__(self):
|
|
return "<Remember %d>" % id(self._value_store)
|
|
|
|
|
|
class MethodGroup(object):
|
|
"""Base class containing common behaviour for MethodGroups."""
|
|
|
|
def __init__(self, group_name):
|
|
self._group_name = group_name
|
|
|
|
def group_name(self):
|
|
return self._group_name
|
|
|
|
def __str__(self):
|
|
return '<%s "%s">' % (self.__class__.__name__, self._group_name)
|
|
|
|
def AddMethod(self, mock_method):
|
|
raise NotImplementedError
|
|
|
|
def MethodCalled(self, mock_method):
|
|
raise NotImplementedError
|
|
|
|
def IsSatisfied(self):
|
|
raise NotImplementedError
|
|
|
|
|
|
class UnorderedGroup(MethodGroup):
|
|
"""UnorderedGroup holds a set of method calls that may occur in any order.
|
|
|
|
This construct is helpful for non-deterministic events, such as iterating
|
|
over the keys of a dict.
|
|
"""
|
|
|
|
def __init__(self, group_name):
|
|
super(UnorderedGroup, self).__init__(group_name)
|
|
self._methods = []
|
|
|
|
def __str__(self):
|
|
return '%s "%s" pending calls:\n%s' % (
|
|
self.__class__.__name__,
|
|
self._group_name,
|
|
"\n".join(str(method) for method in self._methods))
|
|
|
|
def AddMethod(self, mock_method):
|
|
"""Add a method to this group.
|
|
|
|
Args:
|
|
mock_method: A mock method to be added to this group.
|
|
"""
|
|
|
|
self._methods.append(mock_method)
|
|
|
|
def MethodCalled(self, mock_method):
|
|
"""Remove a method call from the group.
|
|
|
|
If the method is not in the set, an UnexpectedMethodCallError will be
|
|
raised.
|
|
|
|
Args:
|
|
mock_method: a mock method that should be equal to a method in the
|
|
group.
|
|
|
|
Returns:
|
|
The mock method from the group
|
|
|
|
Raises:
|
|
UnexpectedMethodCallError if the mock_method was not in the group.
|
|
"""
|
|
|
|
# Check to see if this method exists, and if so, remove it from the set
|
|
# and return it.
|
|
for method in self._methods:
|
|
if method == mock_method:
|
|
# Remove the called mock_method instead of the method in the
|
|
# group. The called method will match any comparators when
|
|
# equality is checked during removal. The method in the group
|
|
# could pass a comparator to another comparator during the
|
|
# equality check.
|
|
self._methods.remove(mock_method)
|
|
|
|
# If group is not empty, put it back at the head of the queue.
|
|
if not self.IsSatisfied():
|
|
mock_method._call_queue.appendleft(self)
|
|
|
|
return self, method
|
|
|
|
raise UnexpectedMethodCallError(mock_method, self)
|
|
|
|
def IsSatisfied(self):
|
|
"""Return True if there are not any methods in this group."""
|
|
|
|
return len(self._methods) == 0
|
|
|
|
|
|
class MultipleTimesGroup(MethodGroup):
|
|
"""MultipleTimesGroup holds methods that may be called any number of times.
|
|
|
|
Note: Each method must be called at least once.
|
|
|
|
This is helpful, if you don't know or care how many times a method is
|
|
called.
|
|
"""
|
|
|
|
def __init__(self, group_name):
|
|
super(MultipleTimesGroup, self).__init__(group_name)
|
|
self._methods = set()
|
|
self._methods_left = set()
|
|
|
|
def AddMethod(self, mock_method):
|
|
"""Add a method to this group.
|
|
|
|
Args:
|
|
mock_method: A mock method to be added to this group.
|
|
"""
|
|
|
|
self._methods.add(mock_method)
|
|
self._methods_left.add(mock_method)
|
|
|
|
def MethodCalled(self, mock_method):
|
|
"""Remove a method call from the group.
|
|
|
|
If the method is not in the set, an UnexpectedMethodCallError will be
|
|
raised.
|
|
|
|
Args:
|
|
mock_method: a mock method that should be equal to a method in the
|
|
group.
|
|
|
|
Returns:
|
|
The mock method from the group
|
|
|
|
Raises:
|
|
UnexpectedMethodCallError if the mock_method was not in the group.
|
|
"""
|
|
|
|
# Check to see if this method exists, and if so add it to the set of
|
|
# called methods.
|
|
for method in self._methods:
|
|
if method == mock_method:
|
|
self._methods_left.discard(method)
|
|
# Always put this group back on top of the queue,
|
|
# because we don't know when we are done.
|
|
mock_method._call_queue.appendleft(self)
|
|
return self, method
|
|
|
|
if self.IsSatisfied():
|
|
next_method = mock_method._PopNextMethod()
|
|
return next_method, None
|
|
else:
|
|
raise UnexpectedMethodCallError(mock_method, self)
|
|
|
|
def IsSatisfied(self):
|
|
"""Return True if all methods in group are called at least once."""
|
|
return len(self._methods_left) == 0
|
|
|
|
|
|
class MoxMetaTestBase(type):
|
|
"""Metaclass to add mox cleanup and verification to every test.
|
|
|
|
As the mox unit testing class is being constructed (MoxTestBase or a
|
|
subclass), this metaclass will modify all test functions to call the
|
|
CleanUpMox method of the test class after they finish. This means that
|
|
unstubbing and verifying will happen for every test with no additional
|
|
code, and any failures will result in test failures as opposed to errors.
|
|
"""
|
|
|
|
def __init__(cls, name, bases, d):
|
|
type.__init__(cls, name, bases, d)
|
|
|
|
# also get all the attributes from the base classes to account
|
|
# for a case when test class is not the immediate child of MoxTestBase
|
|
for base in bases:
|
|
for attr_name in dir(base):
|
|
if attr_name not in d:
|
|
d[attr_name] = getattr(base, attr_name)
|
|
|
|
for func_name, func in d.items():
|
|
if func_name.startswith('test') and callable(func):
|
|
|
|
setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
|
|
|
|
@staticmethod
|
|
def CleanUpTest(cls, func):
|
|
"""Adds Mox cleanup code to any MoxTestBase method.
|
|
|
|
Always unsets stubs after a test. Will verify all mocks for tests that
|
|
otherwise pass.
|
|
|
|
Args:
|
|
cls: MoxTestBase or subclass; the class whose method we are
|
|
altering.
|
|
func: method; the method of the MoxTestBase test class we wish to
|
|
alter.
|
|
|
|
Returns:
|
|
The modified method.
|
|
"""
|
|
def new_method(self, *args, **kwargs):
|
|
mox_obj = getattr(self, 'mox', None)
|
|
stubout_obj = getattr(self, 'stubs', None)
|
|
cleanup_mox = False
|
|
cleanup_stubout = False
|
|
if mox_obj and isinstance(mox_obj, Mox):
|
|
cleanup_mox = True
|
|
if stubout_obj and isinstance(stubout_obj,
|
|
stubout.StubOutForTesting):
|
|
cleanup_stubout = True
|
|
try:
|
|
func(self, *args, **kwargs)
|
|
finally:
|
|
if cleanup_mox:
|
|
mox_obj.UnsetStubs()
|
|
if cleanup_stubout:
|
|
stubout_obj.UnsetAll()
|
|
stubout_obj.SmartUnsetAll()
|
|
if cleanup_mox:
|
|
mox_obj.VerifyAll()
|
|
new_method.__name__ = func.__name__
|
|
new_method.__doc__ = func.__doc__
|
|
new_method.__module__ = func.__module__
|
|
return new_method
|
|
|
|
|
|
_MoxTestBase = MoxMetaTestBase('_MoxTestBase', (unittest.TestCase, ), {})
|
|
|
|
|
|
class MoxTestBase(_MoxTestBase):
|
|
"""Convenience test class to make stubbing easier.
|
|
|
|
Sets up a "mox" attribute which is an instance of Mox (any mox tests will
|
|
want this), and a "stubs" attribute that is an instance of
|
|
StubOutForTesting (needed at times). Also automatically unsets any stubs
|
|
and verifies that all mock methods have been called at the end of each
|
|
test, eliminating boilerplate code.
|
|
"""
|
|
|
|
def setUp(self):
|
|
super(MoxTestBase, self).setUp()
|
|
self.mox = Mox()
|
|
self.stubs = stubout.StubOutForTesting()
|