Improved support for monkeypatching by eliminating race conditions in eventlet's own imports of patched modules.

This commit is contained in:
Ryan Williams
2010-02-17 00:05:45 -08:00
parent 60f14370d3
commit 15118d8075
13 changed files with 132 additions and 42 deletions

View File

@@ -6,6 +6,7 @@ _fileobject = __socket._fileobject
from eventlet.hubs import get_hub from eventlet.hubs import get_hub
from eventlet.greenio import GreenSocket as socket from eventlet.greenio import GreenSocket as socket
from eventlet.greenio import SSL as _SSL # for exceptions from eventlet.greenio import SSL as _SSL # for exceptions
from eventlet.greenio import _GLOBAL_DEFAULT_TIMEOUT
import os import os
import sys import sys
import warnings import warnings
@@ -54,9 +55,6 @@ def _gethostbyname_tpool(name):
# XXX there're few more blocking functions in socket # XXX there're few more blocking functions in socket
# XXX having a hub-independent way to access thread pool would be nice # XXX having a hub-independent way to access thread pool would be nice
_GLOBAL_DEFAULT_TIMEOUT = object()
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT):
"""Connect to *address* and return the socket object. """Connect to *address* and return the socket object.

View File

@@ -16,7 +16,7 @@ def get_ident(gr=None):
return id(gr) return id(gr)
def start_new_thread(function, args=(), kwargs={}): def start_new_thread(function, args=(), kwargs={}):
g = greenthread.spawn(function, *args, **kwargs) g = greenthread.spawn_n(function, *args, **kwargs)
return get_ident(g) return get_ident(g)
start_new = start_new_thread start_new = start_new_thread

View File

@@ -11,5 +11,15 @@ patcher.inject('threading',
del patcher del patcher
def _patch_main_thread(mod):
# this is some gnarly patching for the threading module;
# if threading is imported before we patch (it nearly always is),
# then the main thread will have the wrong key in therading._active,
# so, we try and replace that key with the correct one here
# this works best if there are no other threads besides the main one
curthread = mod._active.pop(mod.current_thread()._Thread__ident, None)
if curthread:
mod._active[thread.get_ident()] = curthread
if __name__ == '__main__': if __name__ == '__main__':
_test() _test()

View File

@@ -1,11 +1,12 @@
import select
import sys import sys
import threading
from eventlet.support import greenlets as greenlet from eventlet.support import greenlets as greenlet
_threadlocal = threading.local() from eventlet import patcher
__all__ = ["use_hub", "get_hub", "get_default_hub", "trampoline"] __all__ = ["use_hub", "get_hub", "get_default_hub", "trampoline"]
threading = patcher.original('threading')
_threadlocal = threading.local()
def get_default_hub(): def get_default_hub():
"""Select the default hub implementation based on what multiplexing """Select the default hub implementation based on what multiplexing
libraries are installed. The order that the hubs are tried is: libraries are installed. The order that the hubs are tried is:
@@ -33,6 +34,7 @@ def get_default_hub():
from eventlet.hubs import twistedr from eventlet.hubs import twistedr
return twistedr return twistedr
select = patcher.original('select')
try: try:
import eventlet.hubs.epolls import eventlet.hubs.epolls
return eventlet.hubs.epolls return eventlet.hubs.epolls

View File

@@ -1,11 +1,16 @@
from eventlet import patcher
time = patcher.original('time')
try: try:
# shoot for epoll module first # shoot for epoll module first
from epoll import poll as epoll from epoll import poll as epoll
except ImportError, e: except ImportError, e:
# if we can't import that, hope we're on 2.6 # if we can't import that, hope we're on 2.6
from select import epoll select = patcher.original('select')
try:
epoll = select.epoll
except AttributeError:
raise ImportError("No epoll on select module")
import time
from eventlet.hubs.hub import BaseHub from eventlet.hubs.hub import BaseHub
from eventlet.hubs import poll from eventlet.hubs import poll
from eventlet.hubs.poll import READ, WRITE from eventlet.hubs.poll import READ, WRITE

View File

@@ -1,10 +1,11 @@
import bisect import bisect
import sys import sys
import traceback import traceback
import time
from eventlet.support import greenlets as greenlet from eventlet.support import greenlets as greenlet
from eventlet.hubs import timer from eventlet.hubs import timer
from eventlet import patcher
time = patcher.original('time')
READ="read" READ="read"
WRITE="write" WRITE="write"

View File

@@ -1,8 +1,9 @@
import sys import sys
import select
import errno import errno
from time import sleep from eventlet import patcher
import time select = patcher.original('select')
time = patcher.original('time')
sleep = time.sleep
from eventlet.hubs.hub import BaseHub, READ, WRITE from eventlet.hubs.hub import BaseHub, READ, WRITE

View File

@@ -1,5 +1,4 @@
import sys import sys
import time
import traceback import traceback
import event import event
@@ -38,8 +37,8 @@ class Hub(BaseHub):
SYSTEM_EXCEPTIONS = (KeyboardInterrupt, SystemExit) SYSTEM_EXCEPTIONS = (KeyboardInterrupt, SystemExit)
def __init__(self, clock=time.time): def __init__(self):
super(Hub,self).__init__(clock) super(Hub,self).__init__()
event.init() event.init()
self.signal_exc_info = None self.signal_exc_info = None

View File

@@ -1,7 +1,8 @@
import sys import sys
import select
import errno import errno
import time from eventlet import patcher
select = patcher.original('select')
time = patcher.original('time')
from eventlet.hubs.hub import BaseHub, READ, WRITE from eventlet.hubs.hub import BaseHub, READ, WRITE

View File

@@ -97,6 +97,23 @@ def patch_function(func, *additional_modules):
del sys.modules[name] del sys.modules[name]
return patched return patched
_originals = {}
class DummyModule(object):
pass
def make_original(modname):
orig_mod = __import__(modname)
dummy_mod = DummyModule()
for attr in dir(orig_mod):
setattr(dummy_mod, attr, getattr(orig_mod, attr))
_originals[modname] = dummy_mod
def original(modname):
mod = _originals.get(modname)
if mod is None:
make_original(modname)
mod = _originals.get(modname)
return mod
already_patched = {} already_patched = {}
def monkey_patch(all=True, os=False, select=False, def monkey_patch(all=True, os=False, select=False,
socket=False, thread=False, time=False): socket=False, thread=False, time=False):
@@ -117,23 +134,32 @@ def monkey_patch(all=True, os=False, select=False,
modules_to_patch += _green_os_modules() modules_to_patch += _green_os_modules()
already_patched['os'] = True already_patched['os'] = True
if all or select and not already_patched.get('select'): if all or select and not already_patched.get('select'):
make_original('select')
modules_to_patch += _green_select_modules() modules_to_patch += _green_select_modules()
already_patched['select'] = True already_patched['select'] = True
if all or socket and not already_patched.get('socket'): if all or socket and not already_patched.get('socket'):
modules_to_patch += _green_socket_modules() modules_to_patch += _green_socket_modules()
already_patched['socket'] = True already_patched['socket'] = True
if all or thread and not already_patched.get('thread'): if all or thread and not already_patched.get('thread'):
make_original('threading')
# hacks ahead
threading = original('threading')
import eventlet.green.threading as greenthreading
greenthreading._patch_main_thread(threading)
modules_to_patch += _green_thread_modules() modules_to_patch += _green_thread_modules()
already_patched['thread'] = True already_patched['thread'] = True
if all or time and not already_patched.get('time'): if all or time and not already_patched.get('time'):
make_original('time')
modules_to_patch += _green_time_modules() modules_to_patch += _green_time_modules()
already_patched['time'] = True already_patched['time'] = True
for name, mod in modules_to_patch: for name, mod in modules_to_patch:
orig_mod = sys.modules.get(name)
for attr in mod.__patched__: for attr in mod.__patched__:
orig_attr = getattr(orig_mod, attr, None)
patched_attr = getattr(mod, attr, None) patched_attr = getattr(mod, attr, None)
if patched_attr is not None: if patched_attr is not None:
setattr(sys.modules[name], attr, patched_attr) setattr(orig_mod, attr, patched_attr)
def _green_os_modules(): def _green_os_modules():
from eventlet.green import os from eventlet.green import os
@@ -159,4 +185,4 @@ def _green_thread_modules():
def _green_time_modules(): def _green_time_modules():
from eventlet.green import time from eventlet.green import time
return [('time', time)] return [('time', time)]

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import os import os
import threading
import sys import sys
from Queue import Empty, Queue from Queue import Empty, Queue
@@ -22,6 +21,8 @@ from Queue import Empty, Queue
from eventlet import event from eventlet import event
from eventlet import greenio from eventlet import greenio
from eventlet import greenthread from eventlet import greenthread
from eventlet import patcher
threading = patcher.original('threading')
__all__ = ['execute', 'Proxy', 'killall'] __all__ = ['execute', 'Proxy', 'killall']

View File

@@ -1,5 +1,4 @@
import os import os
import select
import socket import socket
import errno import errno
import warnings import warnings

View File

@@ -42,18 +42,21 @@ class Patcher(LimitedTestCase):
fd = open(filename, "w") fd = open(filename, "w")
fd.write(contents) fd.write(contents)
fd.close() fd.close()
def launch_subprocess(self, filename):
python_path = os.pathsep.join(sys.path + [self.tempdir])
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, filename)],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=new_env)
return p
def test_patch_a_module(self): def test_patch_a_module(self):
self.write_to_tempfile("base", base_module_contents) self.write_to_tempfile("base", base_module_contents)
self.write_to_tempfile("patching", patching_module_contents) self.write_to_tempfile("patching", patching_module_contents)
self.write_to_tempfile("importing", import_module_contents) self.write_to_tempfile("importing", import_module_contents)
p = self.launch_subprocess('importing.py')
python_path = os.pathsep.join(sys.path + [self.tempdir])
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, "importing.py")],
stdout=subprocess.PIPE, env=new_env)
output = p.communicate() output = p.communicate()
lines = output[0].split("\n") lines = output[0].split("\n")
self.assert_(lines[0].startswith('patcher')) self.assert_(lines[0].startswith('patcher'))
@@ -73,12 +76,7 @@ base = patcher.import_patched('base')
print "newmod", base, base.socket, base.urllib.socket.socket print "newmod", base, base.socket, base.urllib.socket.socket
""" """
self.write_to_tempfile("newmod", new_mod) self.write_to_tempfile("newmod", new_mod)
python_path = os.pathsep.join(sys.path + [self.tempdir]) p = self.launch_subprocess('newmod.py')
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, "newmod.py")],
stdout=subprocess.PIPE, env=new_env)
output = p.communicate() output = p.communicate()
lines = output[0].split("\n") lines = output[0].split("\n")
self.assert_(lines[0].startswith('base')) self.assert_(lines[0].startswith('base'))
@@ -95,13 +93,62 @@ import urllib
print "newmod", socket.socket, urllib.socket.socket print "newmod", socket.socket, urllib.socket.socket
""" """
self.write_to_tempfile("newmod", new_mod) self.write_to_tempfile("newmod", new_mod)
python_path = os.pathsep.join(sys.path + [self.tempdir]) p = self.launch_subprocess('newmod.py')
new_env = os.environ.copy()
new_env['PYTHONPATH'] = python_path
p = subprocess.Popen([sys.executable,
os.path.join(self.tempdir, "newmod.py")],
stdout=subprocess.PIPE, env=new_env)
output = p.communicate() output = p.communicate()
print output[0]
lines = output[0].split("\n") lines = output[0].split("\n")
self.assert_(lines[0].startswith('newmod')) self.assert_(lines[0].startswith('newmod'))
self.assertEqual(lines[0].count('GreenSocket'), 2) self.assertEqual(lines[0].count('GreenSocket'), 2)
def test_early_patching(self):
new_mod = """
from eventlet import patcher
patcher.monkey_patch()
import eventlet
eventlet.sleep(0.01)
print "newmod"
"""
self.write_to_tempfile("newmod", new_mod)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assertEqual(len(lines), 2)
self.assert_(lines[0].startswith('newmod'))
def test_late_patching(self):
new_mod = """
import eventlet
eventlet.sleep(0.01)
from eventlet import patcher
patcher.monkey_patch()
eventlet.sleep(0.01)
print "newmod"
"""
self.write_to_tempfile("newmod", new_mod)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assertEqual(len(lines), 2)
self.assert_(lines[0].startswith('newmod'))
def test_tpool(self):
new_mod = """
import eventlet
from eventlet import patcher
patcher.monkey_patch()
from eventlet import tpool
print "newmod", tpool.execute(len, "hi")
print "newmod", tpool.execute(len, "hi2")
"""
self.write_to_tempfile("newmod", new_mod)
p = self.launch_subprocess('newmod.py')
output = p.communicate()
print output[0]
lines = output[0].split("\n")
self.assertEqual(len(lines), 3)
self.assert_(lines[0].startswith('newmod'))
self.assert_('2' in lines[0])
self.assert_('3' in lines[1])