From 0793b503ebe03ca82d1329d76d092329e9c29208 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Fri, 4 Jun 2010 13:33:14 -0700 Subject: [PATCH] Fixed main thread patching and moved the code into patcher.py. Implemented a sys.modules state saver that simplifies much of the code in patcher.py. Improved comments and docs in there too. --- eventlet/green/threading.py | 10 --- eventlet/patcher.py | 125 +++++++++++++++++++++++------------- 2 files changed, 79 insertions(+), 56 deletions(-) diff --git a/eventlet/green/threading.py b/eventlet/green/threading.py index 0c0a03c..7d61c58 100644 --- a/eventlet/green/threading.py +++ b/eventlet/green/threading.py @@ -11,13 +11,3 @@ patcher.inject('threading', ('time', time)) del patcher - -def _patch_main_thread(mod): - # this is some gnarly patching for the threading module; - # if threading is imported before we patch (it nearly always is), - # then the main thread will have the wrong key in threading._active, - # so, we try and replace that key with the correct one here - # this works best if there are no other threads besides the main one - curthread = mod._active.pop(mod._get_ident(), None) - if curthread: - mod._active[thread.get_ident()] = curthread diff --git a/eventlet/patcher.py b/eventlet/patcher.py index 7425cb2..77d8e42 100644 --- a/eventlet/patcher.py +++ b/eventlet/patcher.py @@ -5,6 +5,33 @@ __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched'] __exclude = set(('__builtins__', '__file__', '__name__')) +class SysModulesSaver(object): + """Class that captures some subset of the current state of + sys.modules. 
Pass in an iterator of module names to the + constructor.""" + def __init__(self, module_names=()): + self._saved = {} + self.save(*module_names) + + def save(self, *module_names): + """Saves the named modules to the object.""" + for modname in module_names: + self._saved[modname] = sys.modules.get(modname, None) + + def restore(self): + """Restores the modules that the saver knows about into + sys.modules. + """ + for modname, mod in self._saved.iteritems(): + if mod is not None: + sys.modules[modname] = mod + else: + try: + del sys.modules[modname] + except KeyError: + pass + + def inject(module_name, new_globals, *additional_modules): """Base method for "injecting" greened modules into an imported module. It imports the module specified in *module_name*, arranging things so @@ -34,16 +61,20 @@ def inject(module_name, new_globals, *additional_modules): _green_socket_modules() + _green_thread_modules() + _green_time_modules()) + + # after this we are gonna screw with sys.modules, so capture the + # state of all the modules we're going to mess with + saver = SysModulesSaver([name for name, m in additional_modules]) + saver.save(module_name) - ## Put the specified modules in sys.modules for the duration of the import - saved = {} + # Cover the target modules so that when you import the module it + # sees only the patched versions for name, mod in additional_modules: - saved[name] = sys.modules.get(name, None) sys.modules[name] = mod ## Remove the old module from sys.modules and reimport it while ## the specified modules are in place - old_module = sys.modules.pop(module_name, None) + sys.modules.pop(module_name, None) try: module = __import__(module_name, {}, {}, module_name.split('.')[:-1]) @@ -56,18 +87,7 @@ def inject(module_name, new_globals, *additional_modules): ## Keep a reference to the new module to prevent it from dying sys.modules[patched_name] = module finally: - ## Put the original module back - if old_module is not None: - sys.modules[module_name] = 
old_module - elif module_name in sys.modules: - del sys.modules[module_name] - - ## Put all the saved modules back - for name, mod in additional_modules: - if saved[name] is not None: - sys.modules[name] = saved[name] - else: - del sys.modules[name] + saver.restore() ## Put the original modules back return module @@ -86,8 +106,11 @@ def import_patched(module_name, *additional_modules, **kw_additional_modules): def patch_function(func, *additional_modules): - """Huge hack here -- patches the specified modules for the - duration of the function call.""" + """Decorator that returns a version of the function that patches + some modules for the duration of the function call. This is + deeply gross and should only be used for functions that import + network libraries within their function bodies that there is no + way of getting around.""" if not additional_modules: # supply some defaults additional_modules = ( @@ -98,50 +121,44 @@ def patch_function(func, *additional_modules): _green_time_modules()) def patched(*args, **kw): - saved = {} + saver = SysModulesSaver([name for name, m in additional_modules]) for name, mod in additional_modules: - saved[name] = sys.modules.get(name, None) sys.modules[name] = mod try: return func(*args, **kw) finally: - ## Put all the saved modules back - for name, mod in additional_modules: - if saved[name] is not None: - sys.modules[name] = saved[name] - else: - del sys.modules[name] + saver.restore() return patched def _original_patch_function(func, *module_names): - """Kind of the opposite of patch_function; wraps a function such - that sys.modules is populated only with the unpatched versions of - the specified modules. Also a gross hack; tell your kids not to - import inside function bodies!""" + """Kind of the contrapositive of patch_function: decorates a + function such that when it's called, sys.modules is populated only + with the unpatched versions of the specified modules. 
Unlike + patch_function, only the names of the modules need be supplied, + and there are no defaults. This is a gross hack; tell your kids not + to import inside function bodies!""" def patched(*args, **kw): - saved = {} + saver = SysModulesSaver(module_names) for name in module_names: - saved[name] = sys.modules.get(name, None) sys.modules[name] = original(name) try: return func(*args, **kw) finally: - for name in module_names: - if saved[name] is not None: - sys.modules[name] = saved[name] - else: - del sys.modules[name] + saver.restore() return patched def original(modname): + """ This returns an unpatched version of a module; this is useful for + Eventlet itself (i.e. tpool).""" original_name = '__original_module_' + modname if original_name in sys.modules: return sys.modules.get(original_name) # re-import the "pure" module and store it in the global _originals # dict; be sure to restore whatever module had that name already - current_mod = sys.modules.pop(modname, None) + saver = SysModulesSaver((modname,)) + sys.modules.pop(modname, None) try: real_mod = __import__(modname, {}, {}, modname.split('.')[:-1]) # hacky hack: Queue's constructor imports threading; therefore @@ -149,12 +166,11 @@ def original(modname): # original threading if modname == 'Queue': real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading') + # save a reference to the unpatched module so it doesn't get lost sys.modules[original_name] = real_mod finally: - if current_mod is not None: - sys.modules[modname] = current_mod - else: - del sys.modules[modname] + saver.restore() + return sys.modules[original_name] already_patched = {} @@ -183,6 +199,7 @@ def monkey_patch(**on): on.setdefault(modname, default_on) modules_to_patch = [] + patched_thread = False if on['os'] and not already_patched.get('os'): modules_to_patch += _green_os_modules() already_patched['os'] = True @@ -193,10 +210,7 @@ def monkey_patch(**on): modules_to_patch += _green_socket_modules() 
already_patched['socket'] = True if on['thread'] and not already_patched.get('thread'): - # hacks ahead - threading = original('threading') - import eventlet.green.threading as greenthreading - greenthreading._patch_main_thread(threading) + patched_thread = True modules_to_patch += _green_thread_modules() already_patched['thread'] = True if on['time'] and not already_patched.get('time'): @@ -222,6 +236,25 @@ def monkey_patch(**on): if patched_attr is not None: setattr(orig_mod, attr_name, patched_attr) + # hacks ahead; this is necessary to prevent a KeyError on program exit + if patched_thread: + _patch_main_thread(sys.modules['threading']) + + +def _patch_main_thread(mod): + """This is some gnarly patching specific to the threading module; + threading will always be initialized prior to monkeypatching, and + its _active dict will have the wrong key (it uses the real thread + id but once it's patched it will use the greenlet ids); so what we + do is rekey the _active dict so that the main thread's entry uses + the greenthread key. Other threads' keys are ignored.""" + thread = original('thread') + curthread = mod._active.pop(thread.get_ident(), None) + if curthread: + import eventlet.green.thread + mod._active[eventlet.green.thread.get_ident()] = curthread + + def is_monkey_patched(module): """Returns True if the given module is monkeypatched currently, False if not. *module* can be either the module itself or its name.