Adding import locks around the places where we mess with sys.modules in patcher, at schmir's suggestion.

This commit is contained in:
Ryan Williams
2010-08-04 22:31:00 -07:00
parent 3c8dc49421
commit b568332cb1
2 changed files with 30 additions and 23 deletions

View File

@@ -59,6 +59,6 @@ Thanks To
* Marcin Bachry, nice repro of a bug and good diagnosis leading to the fix
* David Ziegler, reporting issue #53
* Favo Yang, twisted hub patch
* Schmir, patch that fixes readline method with chunked encoding in wsgi.py
* Schmir, patch that fixes readline method with chunked encoding in wsgi.py, advice on patcher
* Slide, for open-sourcing gogreen
* Holger Krekel, websocket example small fix

View File

@@ -1,5 +1,5 @@
import sys
import imp
__all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched']
@@ -11,6 +11,7 @@ class SysModulesSaver(object):
constructor."""
def __init__(self, module_names=()):
self._saved = {}
imp.acquire_lock()
self.save(*module_names)
def save(self, *module_names):
@@ -22,14 +23,17 @@ class SysModulesSaver(object):
"""Restores the modules that the saver knows about into
sys.modules.
"""
for modname, mod in self._saved.iteritems():
if mod is not None:
sys.modules[modname] = mod
else:
try:
del sys.modules[modname]
except KeyError:
pass
try:
for modname, mod in self._saved.iteritems():
if mod is not None:
sys.modules[modname] = mod
else:
try:
del sys.modules[modname]
except KeyError:
pass
finally:
imp.release_lock()
def inject(module_name, new_globals, *additional_modules):
@@ -63,7 +67,7 @@ def inject(module_name, new_globals, *additional_modules):
_green_time_modules())
# after this we are gonna screw with sys.modules, so capture the
# state of all the modules we're going to mess with
# state of all the modules we're going to mess with, and lock
saver = SysModulesSaver([name for name, m in additional_modules])
saver.save(module_name)
@@ -244,19 +248,22 @@ def monkey_patch(**on):
# tell us whether or not we succeeded
pass
for name, mod in modules_to_patch:
orig_mod = sys.modules.get(name)
if orig_mod is None:
orig_mod = __import__(name)
for attr_name in mod.__patched__:
patched_attr = getattr(mod, attr_name, None)
if patched_attr is not None:
setattr(orig_mod, attr_name, patched_attr)
# hacks ahead; this is necessary to prevent a KeyError on program exit
if patched_thread:
_patch_main_thread(sys.modules['threading'])
imp.acquire_lock()
try:
for name, mod in modules_to_patch:
orig_mod = sys.modules.get(name)
if orig_mod is None:
orig_mod = __import__(name)
for attr_name in mod.__patched__:
patched_attr = getattr(mod, attr_name, None)
if patched_attr is not None:
setattr(orig_mod, attr_name, patched_attr)
# hacks ahead; this is necessary to prevent a KeyError on program exit
if patched_thread:
_patch_main_thread(sys.modules['threading'])
finally:
imp.release_lock()
def _patch_main_thread(mod):
"""This is some gnarly patching specific to the threading module;