# mock.py
# Test tools for mocking and patching.
# Copyright (C) 2007-2009 Michael Foord
# E-mail: fuzzyman AT voidspace DOT org DOT uk
# mock 0.5.0
# http://www.voidspace.org.uk/python/mock/
# Released subject to the BSD License
# Please see http://www.voidspace.org.uk/python/license.shtml
# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml
# Comments, suggestions and bug reports welcome.
# Names exported by a star-import of this module.
__all__ = (
    'Mock',
    'MakeMock',
    'patch',
    'patch_object',
    'sentinel',
    'DEFAULT',
)

# Release identifier of this module.
__version__ = '0.5.0 alpha'
class SentinelObject(object):
    """A named marker object whose only payload is its ``name``.

    The name is stored on the instance and shown by ``repr()``.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        label = self.name
        return '<SentinelObject "%s">' % label
class Sentinel(object):
    """Factory handing out one ``SentinelObject`` per attribute name.

    ``instance.NAME`` creates a ``SentinelObject('NAME')`` on first access
    and returns that same object on every later access.
    """

    def __init__(self):
        # Maps attribute name -> the SentinelObject created for it.
        self._sentinels = {}

    def __getattr__(self, name):
        """Return the sentinel registered under ``name``, creating it once.

        Raises:
            AttributeError: for ``_sentinels`` itself and for ``__dunder__``
                names.
        """
        # __getattr__ only runs when normal lookup failed.  Being asked for
        # '_sentinels' therefore means __init__ has not populated this
        # instance (e.g. an object rebuilt without calling __init__);
        # reading self._sentinels below would recurse without bound.
        # Dunder names are protocol probes (hasattr/getattr for hooks) and
        # must not be answered with a sentinel object.
        if name == '_sentinels' or (name.startswith('__')
                                    and name.endswith('__')):
            raise AttributeError(name)
        try:
            return self._sentinels[name]
        except KeyError:
            # Build the object only when it is actually missing.
            created = self._sentinels[name] = SentinelObject(name)
            return created
# Module-wide sentinel factory: attribute access on it yields a named
# SentinelObject, the same one for a given name each time.
sentinel = Sentinel()

# Marker compared by identity elsewhere in this module to mean
# "no explicit value was supplied".
DEFAULT = getattr(sentinel, 'DEFAULT')
class OldStyleClass:
    """Class declared without explicit bases; exists only so its type can be taken."""


# The type of a class declared without bases, used next to ``type`` in
# isinstance() checks that need to recognise class objects.
ClassType = type(OldStyleClass)
def _is_magic(name): | |
return '__%s__' % name[2:-2] == name | |
def _copy(value): | |
if type(value) in (dict, list, tuple, set): | |
return type(value)(value) | |
return value | |
class Mock(object):
    """Callable stand-in object that records how it is used.

    Calling a Mock records the arguments (``called``, ``call_count``,
    ``call_args``, ``call_args_list``).  Accessing an attribute creates a
    child Mock on first use; calls made on children are also appended to the
    ``method_calls`` list of their ancestors.

    Parameters:
        spec: limits the attributes that may be accessed.  A list is used
            as-is as the set of allowed names; any other non-None object
            contributes the non-magic names from ``dir(spec)``.  When given,
            it also determines which magic methods are supported.
        magics: whitespace-separated string or iterable of magic method
            short names (keys of ``magic_methods``).  Only consulted when
            ``spec`` is None.
        side_effect: optional callable invoked with the call arguments; its
            result is returned from the call unless it is ``DEFAULT``.
        return_value: value returned when the mock is called; if left as
            ``DEFAULT`` a fresh Mock is created on first use.
        name, parent: name of this mock on ``parent``, used to record calls
            in the parent's ``method_calls``.
        items: initial container backing the magic methods; only stored when
            ``_has_items()`` is true.
        wraps: object to delegate to; calls return ``wraps(*args, **kwargs)``
            and child mocks wrap the attribute of the same name.
    """

    def __new__(cls, spec=None, magics=None, *args, **kwargs):
        """Create the instance, switching to a magic-enabled subclass if needed."""
        if isinstance(spec, list):
            # Keep the magic names in the list that have an implementation.
            magics = [method[2:-2] for method in spec
                      if (_is_magic(method) and method[2:-2] in magic_methods)]
        elif spec is not None:
            # Support every known magic method the spec object provides.
            magics = [method for method in magic_methods
                      if hasattr(spec, '__%s__' % method)]
        elif isinstance(magics, basestring):
            magics = magics.split()
        if magics:
            # It might be magic, but I like it
            cls = MakeMock(magics)
        return object.__new__(cls)

    def __init__(self, spec=None, magics=None, side_effect=None,
                 return_value=DEFAULT, name=None, parent=None,
                 items=None, wraps=None):
        self._parent = parent
        self._name = name
        if spec is not None and not isinstance(spec, list):
            spec = [member for member in dir(spec) if not _is_magic(member)]
        self._methods = spec
        self._children = {}
        self._return_value = return_value
        self.side_effect = side_effect
        self._wraps = wraps
        # Pristine copy of the initial items, restored by reset_mock().
        self.__items = None
        self.reset_mock()
        if self._has_items():
            if items is None:
                items = {}
            self._items = items
            self.__items = _copy(items)

    def reset_mock(self):
        """Clear all recorded calls on this mock, its children and return value.

        When the mock has items, they are restored from the copy taken at
        construction time.
        """
        self.called = False
        self.call_args = None
        self.call_count = 0
        self.call_args_list = []
        self.method_calls = []
        # values() exists on every dict flavour, unlike itervalues().
        for child in self._children.values():
            child.reset_mock()
        if isinstance(self._return_value, Mock):
            self._return_value.reset_mock()
        if self._has_items():
            self._items = _copy(self.__items)

    def _has_items(self):
        # Overridden in MagicMock
        return False

    def __get_return_value(self):
        # Lazily create the default return value on first access.
        if self._return_value is DEFAULT:
            self._return_value = Mock()
        return self._return_value

    def __set_return_value(self, value):
        self._return_value = value

    return_value = property(__get_return_value, __set_return_value)

    def __call__(self, *args, **kwargs):
        """Record the call and return the configured result.

        Result precedence: ``wraps(*args, **kwargs)`` if wrapping, otherwise
        the ``side_effect`` result unless that is ``DEFAULT``, otherwise
        ``return_value``.  ``side_effect`` is still invoked when wrapping.
        """
        self.called = True
        self.call_count += 1
        self.call_args = (args, kwargs)
        self.call_args_list.append((args, kwargs))
        # Walk up the parent chain, recording the call on each ancestor under
        # a progressively longer dotted name.
        parent = self._parent
        name = self._name
        while parent is not None:
            parent.method_calls.append((name, args, kwargs))
            if parent._parent is None:
                break
            name = parent._name + '.' + name
            parent = parent._parent
        ret_val = self.return_value
        if self.side_effect is not None:
            ret_val = self.side_effect(*args, **kwargs)
            if ret_val is DEFAULT:
                ret_val = self.return_value
        if self._wraps is not None:
            return self._wraps(*args, **kwargs)
        return ret_val

    def __getattr__(self, name):
        """Return (creating on first use) the child mock called ``name``.

        Raises:
            AttributeError: if a spec is in force and ``name`` is not in it,
                if there is no spec and ``name`` is a magic name, or if
                ``name`` is one of this method's own bookkeeping attributes.
        """
        # __getattr__ only runs when normal lookup failed, so being asked
        # for one of the attributes read below means __init__ has not set it
        # yet; reading it again here would recurse without bound.
        if name in ('_methods', '_children', '_wraps'):
            raise AttributeError(name)
        if self._methods is not None:
            if name not in self._methods:
                raise AttributeError("Mock object has no attribute '%s'" % name)
        elif _is_magic(name):
            raise AttributeError(name)
        if name not in self._children:
            wraps = None
            if self._wraps is not None:
                wraps = getattr(self._wraps, name)
            self._children[name] = Mock(parent=self, name=name, wraps=wraps)
        return self._children[name]

    def assert_called_with(self, *args, **kwargs):
        """Check that the most recent call used exactly these arguments.

        Raises:
            AssertionError: if ``call_args`` differs from ``(args, kwargs)``.
        """
        # Raised explicitly rather than via an ``assert`` statement, which is
        # removed when Python runs with -O and would make this a no-op.
        if not (self.call_args == (args, kwargs)):
            raise AssertionError(
                'Expected: %s\nCalled with: %s' % ((args, kwargs), self.call_args))
def _args(name, *args):
    """Build a ``method_calls`` entry ``(name, args, {})`` for a magic call."""
    return (name, args, {})


# The functions below are written as methods (they take ``self``) but live at
# module level; MakeMock attaches the requested ones to a Mock subclass.  Each
# operates on ``self._items`` and logs itself in ``self.method_calls``.

def __getitem__(self, key):
    # The lookup runs first, so a failing lookup is not recorded.
    val = self._items[key]
    self.method_calls.append(_args('__getitem__', key))
    return val


def __setitem__(self, key, value):
    self.method_calls.append(_args('__setitem__', key, value))
    self._items[key] = value


def __delitem__(self, key):
    self.method_calls.append(_args('__delitem__', key))
    del self._items[key]


def __iter__(self):
    # Record at call time, like every other magic method.  A generator body
    # would postpone the logging until the first item is requested.
    self.method_calls.append(_args('__iter__'))
    # Iterate over a snapshot so later changes to _items do not affect it.
    return iter(list(self._items))


def __len__(self):
    self.method_calls.append(_args('__len__'))
    return len(self._items)


def __contains__(self, key):
    self.method_calls.append(_args('__contains__', key))
    return key in self._items


def __nonzero__(self):
    self.method_calls.append(_args('__nonzero__'))
    return bool(self._items)


# Registry of supported magic methods: short name -> implementation.
# The short name is the method name without its surrounding underscores.
magic_methods = {
    'delitem': __delitem__,
    'getitem': __getitem__,
    'setitem': __setitem__,
    'iter': __iter__,
    'len': __len__,
    'contains': __contains__,
    'nonzero': __nonzero__,
}
def MakeMock(members):
    """Build and return a new ``Mock`` subclass supporting the given magic methods.

    ``members`` holds short names that are keys of ``magic_methods``; if it
    contains ``'all'`` every registered magic method is installed.  The
    returned class reports ``_has_items()`` as true.

    Raises:
        NameError: if a requested name is not a key of ``magic_methods``.
    """
    class MagicMock(Mock):
        def _has_items(self):
            return True

    selected = members
    if 'all' in selected:
        selected = magic_methods.keys()
    for short_name in selected:
        if short_name in magic_methods:
            # Install the implementation under its full double-underscore name.
            setattr(MagicMock, '__%s__' % short_name, magic_methods[short_name])
            continue
        raise NameError("Unknown magic method %r" % short_name)
    return MagicMock
def _dot_lookup(thing, comp, import_path): | |
try: | |
return getattr(thing, comp) | |
except AttributeError: | |
__import__(import_path) | |
return getattr(thing, comp) | |
def _importer(target):
    """Resolve the dotted name ``target`` to the object it refers to.

    The first component is imported with ``__import__``; each further
    component is resolved on the previous result through ``_dot_lookup``,
    which is handed the dotted path accumulated so far.
    """
    parts = target.split('.')
    dotted = parts[0]
    current = __import__(dotted)
    for part in parts[1:]:
        dotted = dotted + ".%s" % part
        current = _dot_lookup(current, part, dotted)
    return current
class _patch(object):
    """Temporarily replace ``attribute`` on ``target``.

    Usable as a decorator (via ``__call__``) or through ``__enter__`` /
    ``__exit__``.  When ``new`` is ``DEFAULT`` a ``Mock`` is created as the
    replacement, configured with ``spec`` and ``magics``.  With ``create``
    true, a missing attribute is created and then deleted again on exit.
    """

    def __init__(self, target, attribute, new, spec, magics, create):
        self.target = target
        self.attribute = attribute
        self.new = new
        self.spec = spec
        self.magics = magics
        self.create = create

    def __call__(self, func):
        """Decorate ``func`` so this patch is active while it runs.

        If ``func`` is already a patched wrapper, this patch is appended to
        its ``patchings`` list and the same wrapper is returned.  Each patch
        whose ``new`` is ``DEFAULT`` passes the Mock it created to ``func``
        as an extra positional argument.
        """
        if hasattr(func, 'patchings'):
            func.patchings.append(self)
            return func

        def patched(*args, **keywargs):
            # don't use a with here (backwards compatibility with 2.5)
            extra_args = []
            # Patchers whose __enter__ succeeded; only these may be exited.
            entered = []
            try:
                # Entering inside the try ensures a failure in a later
                # __enter__ still undoes the patches already applied.
                for patching in patched.patchings:
                    arg = patching.__enter__()
                    entered.append(patching)
                    if patching.new is DEFAULT:
                        extra_args.append(arg)
                args += tuple(extra_args)
                return func(*args, **keywargs)
            finally:
                # Undo in reverse order so that stacked patches of the same
                # attribute end with the true original restored.
                for patching in reversed(entered):
                    patching.__exit__()

        patched.patchings = [self]
        patched.__name__ = func.__name__
        patched.__doc__ = func.__doc__
        # The code object is ``func_code`` or ``__code__`` depending on the
        # Python version; tolerate callables that expose neither.
        code = getattr(func, 'func_code', None)
        if code is None:
            code = getattr(func, '__code__', None)
        # NOTE(review): presumably read by a test runner to locate the
        # original function's first line - confirm.
        patched.compat_co_firstlineno = getattr(
            func, "compat_co_firstlineno",
            getattr(code, 'co_firstlineno', None))
        return patched

    def get_original(self):
        """Return the current value of the attribute being patched.

        Returns ``DEFAULT`` when the attribute is missing and ``create`` is
        true.

        Raises:
            AttributeError: if the attribute is missing and ``create`` is false.
        """
        try:
            return getattr(self.target, self.attribute)
        except AttributeError:
            if not self.create:
                raise
            return DEFAULT

    def __enter__(self):
        """Install the replacement and return it."""
        new, spec, magics = self.new, self.spec, self.magics
        original = self.get_original()
        if new is DEFAULT:
            inherit = False
            # Identity test: only the literal True requests "use the patched
            # object as spec"; an equality test would also match 1 and would
            # invoke a custom __eq__ on arbitrary spec objects.
            if spec is True:
                # set spec to the object we are replacing
                spec = original
                if isinstance(spec, (type, ClassType)):
                    inherit = True
            new = Mock(spec=spec, magics=magics)
            if inherit:
                # deliberately ignoring magics as we are using spec
                new.return_value = Mock(spec=spec)
        self.temp_original = original
        setattr(self.target, self.attribute, new)
        return new

    def __exit__(self, *_):
        """Restore the original attribute, or delete it if it was created."""
        if self.temp_original is not DEFAULT:
            setattr(self.target, self.attribute, self.temp_original)
        else:
            delattr(self.target, self.attribute)
        del self.temp_original
def patch_object(target, attribute, new=DEFAULT, spec=None, magics=None, create=False):
    """Return a ``_patch`` replacing ``attribute`` on the object ``target``.

    All arguments are handed to ``_patch`` unchanged.
    """
    return _patch(
        target=target,
        attribute=attribute,
        new=new,
        spec=spec,
        magics=magics,
        create=create,
    )
def patch(target, new=DEFAULT, spec=None, magics=None, create=False):
    """Return a ``_patch`` for the object named by the dotted string ``target``.

    ``target`` is split at its last dot: the left part is imported via
    ``_importer`` and the right part is the attribute to replace on it.
    The remaining arguments are handed to ``_patch`` unchanged.

    Raises:
        TypeError: if ``target`` cannot be split into a dotted path and an
            attribute name (no dot, or not a string at all).
    """
    try:
        target, attribute = target.rsplit('.', 1)
    # AttributeError covers objects without .rsplit (None, numbers, ...), so
    # they get the same explanatory error as a string without a dot.
    except (TypeError, ValueError, AttributeError):
        raise TypeError("Need a valid target to patch. You supplied: %r" % (target,))
    target = _importer(target)
    return _patch(target, attribute, new, spec, magics, create)