Fixed main thread patching and moved the code into patcher.py. Implemented a sys.modules state saver that simplifies much of the code in patcher.py. Improved comments and docs in there too.

This commit is contained in:
Ryan Williams
2010-06-04 13:33:14 -07:00
parent 88bf9d0240
commit 0793b503eb
2 changed files with 79 additions and 56 deletions

View File

@@ -11,13 +11,3 @@ patcher.inject('threading',
('time', time)) ('time', time))
del patcher del patcher
def _patch_main_thread(mod):
# this is some gnarly patching for the threading module;
# if threading is imported before we patch (it nearly always is),
# then the main thread will have the wrong key in threading._active,
# so, we try and replace that key with the correct one here
# this works best if there are no other threads besides the main one
curthread = mod._active.pop(mod._get_ident(), None)
if curthread:
mod._active[thread.get_ident()] = curthread

View File

@@ -5,6 +5,33 @@ __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched']
__exclude = set(('__builtins__', '__file__', '__name__')) __exclude = set(('__builtins__', '__file__', '__name__'))
class SysModulesSaver(object):
    """Capture, and later restore, some subset of ``sys.modules``.

    Pass an iterable of module names to the constructor; further names
    can be added afterwards with :meth:`save`.  For every name the saver
    records whatever ``sys.modules`` currently holds for it, or ``None``
    when the name is not registered.  Note that an entry whose value is
    ``None`` is therefore treated the same as a missing entry.
    """

    def __init__(self, module_names=()):
        # Maps module name -> module object at save time (None = absent).
        self._saved = {}
        self.save(*module_names)

    def save(self, *module_names):
        """Record the current ``sys.modules`` entry for each named module.

        Saving a name a second time overwrites the earlier record.
        """
        for modname in module_names:
            self._saved[modname] = sys.modules.get(modname, None)

    def restore(self):
        """Put the recorded modules back into ``sys.modules``.

        Names that were recorded as absent are removed from
        ``sys.modules``; it is not an error if they are already gone.
        """
        # items() instead of the Python 2-only iteritems(): it behaves the
        # same here and keeps the class usable on Python 3 as well.
        for modname, mod in self._saved.items():
            if mod is not None:
                sys.modules[modname] = mod
            else:
                # Was not present when saved, so make sure it is not
                # present now; pop() with a default never raises.
                sys.modules.pop(modname, None)
def inject(module_name, new_globals, *additional_modules): def inject(module_name, new_globals, *additional_modules):
"""Base method for "injecting" greened modules into an imported module. It """Base method for "injecting" greened modules into an imported module. It
imports the module specified in *module_name*, arranging things so imports the module specified in *module_name*, arranging things so
@@ -34,16 +61,20 @@ def inject(module_name, new_globals, *additional_modules):
_green_socket_modules() + _green_socket_modules() +
_green_thread_modules() + _green_thread_modules() +
_green_time_modules()) _green_time_modules())
# after this we are gonna screw with sys.modules, so capture the
# state of all the modules we're going to mess with
saver = SysModulesSaver([name for name, m in additional_modules])
saver.save(module_name)
## Put the specified modules in sys.modules for the duration of the import # Cover the target modules so that when you import the module it
saved = {} # sees only the patched versions
for name, mod in additional_modules: for name, mod in additional_modules:
saved[name] = sys.modules.get(name, None)
sys.modules[name] = mod sys.modules[name] = mod
## Remove the old module from sys.modules and reimport it while ## Remove the old module from sys.modules and reimport it while
## the specified modules are in place ## the specified modules are in place
old_module = sys.modules.pop(module_name, None) sys.modules.pop(module_name, None)
try: try:
module = __import__(module_name, {}, {}, module_name.split('.')[:-1]) module = __import__(module_name, {}, {}, module_name.split('.')[:-1])
@@ -56,18 +87,7 @@ def inject(module_name, new_globals, *additional_modules):
## Keep a reference to the new module to prevent it from dying ## Keep a reference to the new module to prevent it from dying
sys.modules[patched_name] = module sys.modules[patched_name] = module
finally: finally:
## Put the original module back saver.restore() ## Put the original modules back
if old_module is not None:
sys.modules[module_name] = old_module
elif module_name in sys.modules:
del sys.modules[module_name]
## Put all the saved modules back
for name, mod in additional_modules:
if saved[name] is not None:
sys.modules[name] = saved[name]
else:
del sys.modules[name]
return module return module
@@ -86,8 +106,11 @@ def import_patched(module_name, *additional_modules, **kw_additional_modules):
def patch_function(func, *additional_modules): def patch_function(func, *additional_modules):
"""Huge hack here -- patches the specified modules for the """Decorator that returns a version of the function that patches
duration of the function call.""" some modules for the duration of the function call. This is
deeply gross and should only be used for functions that import
network libraries within their function bodies that there is no
way of getting around."""
if not additional_modules: if not additional_modules:
# supply some defaults # supply some defaults
additional_modules = ( additional_modules = (
@@ -98,50 +121,44 @@ def patch_function(func, *additional_modules):
_green_time_modules()) _green_time_modules())
def patched(*args, **kw): def patched(*args, **kw):
saved = {} saver = SysModulesSaver(additional_modules.keys())
for name, mod in additional_modules: for name, mod in additional_modules:
saved[name] = sys.modules.get(name, None)
sys.modules[name] = mod sys.modules[name] = mod
try: try:
return func(*args, **kw) return func(*args, **kw)
finally: finally:
## Put all the saved modules back saver.restore()
for name, mod in additional_modules:
if saved[name] is not None:
sys.modules[name] = saved[name]
else:
del sys.modules[name]
return patched return patched
def _original_patch_function(func, *module_names): def _original_patch_function(func, *module_names):
"""Kind of the opposite of patch_function; wraps a function such """Kind of the contrapositive of patch_function: decorates a
that sys.modules is populated only with the unpatched versions of function such that when it's called, sys.modules is populated only
the specified modules. Also a gross hack; tell your kids not to with the unpatched versions of the specified modules. Unlike
import inside function bodies!""" patch_function, only the names of the modules need be supplied,
and there are no defaults. This is a gross hack; tell your kids not
to import inside function bodies!"""
def patched(*args, **kw): def patched(*args, **kw):
saved = {} saver = SysModulesSaver(module_names)
for name in module_names: for name in module_names:
saved[name] = sys.modules.get(name, None)
sys.modules[name] = original(name) sys.modules[name] = original(name)
try: try:
return func(*args, **kw) return func(*args, **kw)
finally: finally:
for name in module_names: saver.restore()
if saved[name] is not None:
sys.modules[name] = saved[name]
else:
del sys.modules[name]
return patched return patched
def original(modname): def original(modname):
""" This returns an unpatched version of a module; this is useful for
Eventlet itself (i.e. tpool)."""
original_name = '__original_module_' + modname original_name = '__original_module_' + modname
if original_name in sys.modules: if original_name in sys.modules:
return sys.modules.get(original_name) return sys.modules.get(original_name)
# re-import the "pure" module and store it in the global _originals # re-import the "pure" module and store it in the global _originals
# dict; be sure to restore whatever module had that name already # dict; be sure to restore whatever module had that name already
current_mod = sys.modules.pop(modname, None) saver = SysModulesSaver((modname,))
sys.modules.pop(modname, None)
try: try:
real_mod = __import__(modname, {}, {}, modname.split('.')[:-1]) real_mod = __import__(modname, {}, {}, modname.split('.')[:-1])
# hacky hack: Queue's constructor imports threading; therefore # hacky hack: Queue's constructor imports threading; therefore
@@ -149,12 +166,11 @@ def original(modname):
# original threading # original threading
if modname == 'Queue': if modname == 'Queue':
real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading') real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading')
# save a reference to the unpatched module so it doesn't get lost
sys.modules[original_name] = real_mod sys.modules[original_name] = real_mod
finally: finally:
if current_mod is not None: saver.restore()
sys.modules[modname] = current_mod
else:
del sys.modules[modname]
return sys.modules[original_name] return sys.modules[original_name]
already_patched = {} already_patched = {}
@@ -183,6 +199,7 @@ def monkey_patch(**on):
on.setdefault(modname, default_on) on.setdefault(modname, default_on)
modules_to_patch = [] modules_to_patch = []
patched_thread = False
if on['os'] and not already_patched.get('os'): if on['os'] and not already_patched.get('os'):
modules_to_patch += _green_os_modules() modules_to_patch += _green_os_modules()
already_patched['os'] = True already_patched['os'] = True
@@ -193,10 +210,7 @@ def monkey_patch(**on):
modules_to_patch += _green_socket_modules() modules_to_patch += _green_socket_modules()
already_patched['socket'] = True already_patched['socket'] = True
if on['thread'] and not already_patched.get('thread'): if on['thread'] and not already_patched.get('thread'):
# hacks ahead patched_thread = True
threading = original('threading')
import eventlet.green.threading as greenthreading
greenthreading._patch_main_thread(threading)
modules_to_patch += _green_thread_modules() modules_to_patch += _green_thread_modules()
already_patched['thread'] = True already_patched['thread'] = True
if on['time'] and not already_patched.get('time'): if on['time'] and not already_patched.get('time'):
@@ -222,6 +236,25 @@ def monkey_patch(**on):
if patched_attr is not None: if patched_attr is not None:
setattr(orig_mod, attr_name, patched_attr) setattr(orig_mod, attr_name, patched_attr)
# hacks ahead; this is necessary to prevent a KeyError on program exit
if patched_thread:
_patch_main_thread(sys.modules['threading'])
def _patch_main_thread(mod):
    """Rekey the main thread's entry in the threading module's _active dict.

    threading gets initialized before monkeypatching happens, so its
    _active dict holds the main thread under the real thread id, while
    the patched module looks threads up by greenlet id.  This moves the
    entry stored under the unpatched ``thread.get_ident()`` to the key
    returned by ``eventlet.green.thread.get_ident()``.  Entries for any
    other threads are left untouched.
    """
    real_thread = original('thread')
    main_thread = mod._active.pop(real_thread.get_ident(), None)
    if not main_thread:
        # No entry under the real ident, so there is nothing to rekey.
        return
    import eventlet.green.thread
    green_ident = eventlet.green.thread.get_ident()
    mod._active[green_ident] = main_thread
def is_monkey_patched(module): def is_monkey_patched(module):
"""Returns True if the given module is monkeypatched currently, False if """Returns True if the given module is monkeypatched currently, False if
not. *module* can be either the module itself or its name. not. *module* can be either the module itself or its name.