Improved support for monkeypatching by eliminating race conditions in eventlet's own imports of patched modules.

This commit is contained in:
Ryan Williams
2010-02-17 00:05:45 -08:00
parent 60f14370d3
commit 15118d8075
13 changed files with 132 additions and 42 deletions

View File

@@ -6,6 +6,7 @@ _fileobject = __socket._fileobject
from eventlet.hubs import get_hub
from eventlet.greenio import GreenSocket as socket
from eventlet.greenio import SSL as _SSL # for exceptions
from eventlet.greenio import _GLOBAL_DEFAULT_TIMEOUT
import os
import sys
import warnings
@@ -54,9 +55,6 @@ def _gethostbyname_tpool(name):
# XXX there're few more blocking functions in socket
# XXX having a hub-independent way to access thread pool would be nice
_GLOBAL_DEFAULT_TIMEOUT = object()
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT):
"""Connect to *address* and return the socket object.

View File

@@ -16,7 +16,7 @@ def get_ident(gr=None):
return id(gr)
def start_new_thread(function, args=(), kwargs={}):
g = greenthread.spawn(function, *args, **kwargs)
g = greenthread.spawn_n(function, *args, **kwargs)
return get_ident(g)
start_new = start_new_thread

View File

@@ -11,5 +11,15 @@ patcher.inject('threading',
del patcher
def _patch_main_thread(mod):
# this is some gnarly patching for the threading module;
# if threading is imported before we patch (it nearly always is),
# then the main thread will have the wrong key in therading._active,
# so, we try and replace that key with the correct one here
# this works best if there are no other threads besides the main one
curthread = mod._active.pop(mod.current_thread()._Thread__ident, None)
if curthread:
mod._active[thread.get_ident()] = curthread
if __name__ == '__main__':
_test()

View File

@@ -1,11 +1,12 @@
import select
import sys
import threading
from eventlet.support import greenlets as greenlet
_threadlocal = threading.local()
from eventlet import patcher
__all__ = ["use_hub", "get_hub", "get_default_hub", "trampoline"]
threading = patcher.original('threading')
_threadlocal = threading.local()
def get_default_hub():
"""Select the default hub implementation based on what multiplexing
libraries are installed. The order that the hubs are tried is:
@@ -33,6 +34,7 @@ def get_default_hub():
from eventlet.hubs import twistedr
return twistedr
select = patcher.original('select')
try:
import eventlet.hubs.epolls
return eventlet.hubs.epolls

View File

@@ -1,11 +1,16 @@
from eventlet import patcher
time = patcher.original('time')
try:
# shoot for epoll module first
from epoll import poll as epoll
except ImportError, e:
# if we can't import that, hope we're on 2.6
from select import epoll
select = patcher.original('select')
try:
epoll = select.epoll
except AttributeError:
raise ImportError("No epoll on select module")
import time
from eventlet.hubs.hub import BaseHub
from eventlet.hubs import poll
from eventlet.hubs.poll import READ, WRITE

View File

@@ -1,10 +1,11 @@
import bisect
import sys
import traceback
import time
from eventlet.support import greenlets as greenlet
from eventlet.hubs import timer
from eventlet import patcher
time = patcher.original('time')
READ="read"
WRITE="write"

View File

@@ -1,8 +1,9 @@
import sys
import select
import errno
from time import sleep
import time
from eventlet import patcher
select = patcher.original('select')
time = patcher.original('time')
sleep = time.sleep
from eventlet.hubs.hub import BaseHub, READ, WRITE

View File

@@ -1,5 +1,4 @@
import sys
import time
import traceback
import event
@@ -38,8 +37,8 @@ class Hub(BaseHub):
SYSTEM_EXCEPTIONS = (KeyboardInterrupt, SystemExit)
def __init__(self, clock=time.time):
super(Hub,self).__init__(clock)
def __init__(self):
super(Hub,self).__init__()
event.init()
self.signal_exc_info = None

View File

@@ -1,7 +1,8 @@
import sys
import select
import errno
import time
from eventlet import patcher
select = patcher.original('select')
time = patcher.original('time')
from eventlet.hubs.hub import BaseHub, READ, WRITE

View File

@@ -97,6 +97,23 @@ def patch_function(func, *additional_modules):
del sys.modules[name]
return patched
_originals = {}
class DummyModule(object):
pass
def make_original(modname):
orig_mod = __import__(modname)
dummy_mod = DummyModule()
for attr in dir(orig_mod):
setattr(dummy_mod, attr, getattr(orig_mod, attr))
_originals[modname] = dummy_mod
def original(modname):
mod = _originals.get(modname)
if mod is None:
make_original(modname)
mod = _originals.get(modname)
return mod
already_patched = {}
def monkey_patch(all=True, os=False, select=False,
socket=False, thread=False, time=False):
@@ -117,23 +134,32 @@ def monkey_patch(all=True, os=False, select=False,
modules_to_patch += _green_os_modules()
already_patched['os'] = True
if all or select and not already_patched.get('select'):
make_original('select')
modules_to_patch += _green_select_modules()
already_patched['select'] = True
if all or socket and not already_patched.get('socket'):
modules_to_patch += _green_socket_modules()
already_patched['socket'] = True
if all or thread and not already_patched.get('thread'):
make_original('threading')
# hacks ahead
threading = original('threading')
import eventlet.green.threading as greenthreading
greenthreading._patch_main_thread(threading)
modules_to_patch += _green_thread_modules()
already_patched['thread'] = True
if all or time and not already_patched.get('time'):
make_original('time')
modules_to_patch += _green_time_modules()
already_patched['time'] = True
for name, mod in modules_to_patch:
orig_mod = sys.modules.get(name)
for attr in mod.__patched__:
orig_attr = getattr(orig_mod, attr, None)
patched_attr = getattr(mod, attr, None)
if patched_attr is not None:
setattr(sys.modules[name], attr, patched_attr)
setattr(orig_mod, attr, patched_attr)
def _green_os_modules():
from eventlet.green import os

View File

@@ -14,7 +14,6 @@
# limitations under the License.
import os
import threading
import sys
from Queue import Empty, Queue
@@ -22,6 +21,8 @@ from Queue import Empty, Queue
from eventlet import event
from eventlet import greenio
from eventlet import greenthread
from eventlet import patcher
threading = patcher.original('threading')
__all__ = ['execute', 'Proxy', 'killall']

View File

@@ -1,5 +1,4 @@
import os
import select
import socket
import errno
import warnings

View File

@@ -43,17 +43,20 @@ class Patcher(LimitedTestCase):
fd.write(contents)
fd.close()
def test_patch_a_module(self):
self.write_to_tempfile("base", base_module_contents)
self.write_to_tempfile("patching", patching_module_contents)
self.write_to_tempfile("importing", import_module_contents)
def launch_subprocess(self, filename):
python_path = os.pathsep.join(sys.path + [self.tempdir])
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, "importing.py")],
stdout=subprocess.PIPE, env=new_env)
os.path.join(self.tempdir, filename)],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=new_env)
return p
def test_patch_a_module(self):
self.write_to_tempfile("base", base_module_contents)
self.write_to_tempfile("patching", patching_module_contents)
self.write_to_tempfile("importing", import_module_contents)
p = self.launch_subprocess('importing.py')
output = p.communicate()
lines = output[0].split("\n")
self.assert_(lines[0].startswith('patcher'))
@@ -73,12 +76,7 @@ base = patcher.import_patched('base')
print "newmod", base, base.socket, base.urllib.socket.socket
"""
self.write_to_tempfile("newmod", new_mod)
python_path = os.pathsep.join(sys.path + [self.tempdir])
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, "newmod.py")],
stdout=subprocess.PIPE, env=new_env)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
lines = output[0].split("\n")
self.assert_(lines[0].startswith('base'))
@@ -95,13 +93,62 @@ import urllib
print "newmod", socket.socket, urllib.socket.socket
"""
self.write_to_tempfile("newmod", new_mod)
python_path = os.pathsep.join(sys.path + [self.tempdir])
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, "newmod.py")],
stdout=subprocess.PIPE, env=new_env)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assert_(lines[0].startswith('newmod'))
self.assertEqual(lines[0].count('GreenSocket'), 2)
def test_early_patching(self):
new_mod = """
from eventlet import patcher
patcher.monkey_patch()
import eventlet
eventlet.sleep(0.01)
print "newmod"
"""
self.write_to_tempfile("newmod", new_mod)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assertEqual(len(lines), 2)
self.assert_(lines[0].startswith('newmod'))
def test_late_patching(self):
new_mod = """
import eventlet
eventlet.sleep(0.01)
from eventlet import patcher
patcher.monkey_patch()
eventlet.sleep(0.01)
print "newmod"
"""
self.write_to_tempfile("newmod", new_mod)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assertEqual(len(lines), 2)
self.assert_(lines[0].startswith('newmod'))
def test_tpool(self):
new_mod = """
import eventlet
from eventlet import patcher
patcher.monkey_patch()
from eventlet import tpool
print "newmod", tpool.execute(len, "hi")
print "newmod", tpool.execute(len, "hi2")
"""
self.write_to_tempfile("newmod", new_mod)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assertEqual(len(lines), 3)
self.assert_(lines[0].startswith('newmod'))
self.assert_('2' in lines[0])
self.assert_('3' in lines[1])