From 15118d8075988f63a1c955f35a3138640bc02254 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 17 Feb 2010 00:05:45 -0800 Subject: [PATCH] Improved support for monkeypatching by eliminating race conditions in eventlet's own imports of patched modules. --- eventlet/green/socket.py | 4 +- eventlet/green/thread.py | 2 +- eventlet/green/threading.py | 10 +++++ eventlet/hubs/__init__.py | 8 ++-- eventlet/hubs/epolls.py | 9 +++- eventlet/hubs/hub.py | 3 +- eventlet/hubs/poll.py | 7 +-- eventlet/hubs/pyevent.py | 5 +-- eventlet/hubs/selects.py | 5 ++- eventlet/patcher.py | 30 ++++++++++++- eventlet/tpool.py | 3 +- eventlet/util.py | 1 - tests/patcher_test.py | 87 ++++++++++++++++++++++++++++--------- 13 files changed, 132 insertions(+), 42 deletions(-) diff --git a/eventlet/green/socket.py b/eventlet/green/socket.py index e9cde35..4af8f05 100644 --- a/eventlet/green/socket.py +++ b/eventlet/green/socket.py @@ -6,6 +6,7 @@ _fileobject = __socket._fileobject from eventlet.hubs import get_hub from eventlet.greenio import GreenSocket as socket from eventlet.greenio import SSL as _SSL # for exceptions +from eventlet.greenio import _GLOBAL_DEFAULT_TIMEOUT import os import sys import warnings @@ -54,9 +55,6 @@ def _gethostbyname_tpool(name): # XXX there're few more blocking functions in socket # XXX having a hub-independent way to access thread pool would be nice - -_GLOBAL_DEFAULT_TIMEOUT = object() - def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): """Connect to *address* and return the socket object. diff --git a/eventlet/green/thread.py b/eventlet/green/thread.py index 9a6ac78..9c7ea5e 100644 --- a/eventlet/green/thread.py +++ b/eventlet/green/thread.py @@ -16,7 +16,7 @@ def get_ident(gr=None): return id(gr) def start_new_thread(function, args=(), kwargs={}): - g = greenthread.spawn(function, *args, **kwargs) + g = greenthread.spawn_n(function, *args, **kwargs) return get_ident(g) start_new = start_new_thread diff --git a/eventlet/green/threading.py b/eventlet/green/threading.py index faf918b..106f7d9 100644 --- a/eventlet/green/threading.py +++ b/eventlet/green/threading.py @@ -11,5 +11,15 @@ patcher.inject('threading', del patcher +def _patch_main_thread(mod): + # this is some gnarly patching for the threading module; + # if threading is imported before we patch (it nearly always is), + # then the main thread will have the wrong key in therading._active, + # so, we try and replace that key with the correct one here + # this works best if there are no other threads besides the main one + curthread = mod._active.pop(mod.current_thread()._Thread__ident, None) + if curthread: + mod._active[thread.get_ident()] = curthread + if __name__ == '__main__': _test() diff --git a/eventlet/hubs/__init__.py b/eventlet/hubs/__init__.py index aa6d56d..28543ce 100644 --- a/eventlet/hubs/__init__.py +++ b/eventlet/hubs/__init__.py @@ -1,11 +1,12 @@ -import select import sys -import threading from eventlet.support import greenlets as greenlet -_threadlocal = threading.local() +from eventlet import patcher __all__ = ["use_hub", "get_hub", "get_default_hub", "trampoline"] +threading = patcher.original('threading') +_threadlocal = threading.local() + def get_default_hub(): """Select the default hub implementation based on what multiplexing libraries are installed. The order that the hubs are tried is: @@ -33,6 +34,7 @@ def get_default_hub(): from eventlet.hubs import twistedr return twistedr + select = patcher.original('select') try: import eventlet.hubs.epolls return eventlet.hubs.epolls diff --git a/eventlet/hubs/epolls.py b/eventlet/hubs/epolls.py index e3e03d8..0de4585 100644 --- a/eventlet/hubs/epolls.py +++ b/eventlet/hubs/epolls.py @@ -1,11 +1,16 @@ +from eventlet import patcher +time = patcher.original('time') try: # shoot for epoll module first from epoll import poll as epoll except ImportError, e: # if we can't import that, hope we're on 2.6 - from select import epoll + select = patcher.original('select') + try: + epoll = select.epoll + except AttributeError: + raise ImportError("No epoll on select module") -import time from eventlet.hubs.hub import BaseHub from eventlet.hubs import poll from eventlet.hubs.poll import READ, WRITE diff --git a/eventlet/hubs/hub.py b/eventlet/hubs/hub.py index b2ef1ad..c5e885e 100644 --- a/eventlet/hubs/hub.py +++ b/eventlet/hubs/hub.py @@ -1,10 +1,11 @@ import bisect import sys import traceback -import time from eventlet.support import greenlets as greenlet from eventlet.hubs import timer +from eventlet import patcher +time = patcher.original('time') READ="read" WRITE="write" diff --git a/eventlet/hubs/poll.py b/eventlet/hubs/poll.py index c9925e3..e6b3c27 100644 --- a/eventlet/hubs/poll.py +++ b/eventlet/hubs/poll.py @@ -1,8 +1,9 @@ import sys -import select import errno -from time import sleep -import time +from eventlet import patcher +select = patcher.original('select') +time = patcher.original('time') +sleep = time.sleep from eventlet.hubs.hub import BaseHub, READ, WRITE diff --git a/eventlet/hubs/pyevent.py b/eventlet/hubs/pyevent.py index a067ae5..b67d5d3 100644 --- a/eventlet/hubs/pyevent.py +++ b/eventlet/hubs/pyevent.py @@ -1,5 +1,4 @@ import sys -import time import traceback import event @@ -38,8 +37,8 @@ class Hub(BaseHub): SYSTEM_EXCEPTIONS = (KeyboardInterrupt, SystemExit) - def __init__(self, clock=time.time): - super(Hub,self).__init__(clock) + def __init__(self): + super(Hub,self).__init__() event.init() self.signal_exc_info = None diff --git a/eventlet/hubs/selects.py b/eventlet/hubs/selects.py index c205ce3..d20f3cf 100644 --- a/eventlet/hubs/selects.py +++ b/eventlet/hubs/selects.py @@ -1,7 +1,8 @@ import sys -import select import errno -import time +from eventlet import patcher +select = patcher.original('select') +time = patcher.original('time') from eventlet.hubs.hub import BaseHub, READ, WRITE diff --git a/eventlet/patcher.py b/eventlet/patcher.py index 22264d8..94ea29f 100644 --- a/eventlet/patcher.py +++ b/eventlet/patcher.py @@ -97,6 +97,23 @@ def patch_function(func, *additional_modules): del sys.modules[name] return patched +_originals = {} +class DummyModule(object): + pass +def make_original(modname): + orig_mod = __import__(modname) + dummy_mod = DummyModule() + for attr in dir(orig_mod): + setattr(dummy_mod, attr, getattr(orig_mod, attr)) + _originals[modname] = dummy_mod + +def original(modname): + mod = _originals.get(modname) + if mod is None: + make_original(modname) + mod = _originals.get(modname) + return mod + already_patched = {} def monkey_patch(all=True, os=False, select=False, socket=False, thread=False, time=False): @@ -117,23 +134,32 @@ def monkey_patch(all=True, os=False, select=False, modules_to_patch += _green_os_modules() already_patched['os'] = True if all or select and not already_patched.get('select'): + make_original('select') modules_to_patch += _green_select_modules() already_patched['select'] = True if all or socket and not already_patched.get('socket'): modules_to_patch += _green_socket_modules() already_patched['socket'] = True if all or thread and not already_patched.get('thread'): + make_original('threading') + # hacks ahead + threading = original('threading') + import eventlet.green.threading as greenthreading + greenthreading._patch_main_thread(threading) modules_to_patch += _green_thread_modules() already_patched['thread'] = True if all or time and not already_patched.get('time'): + make_original('time') modules_to_patch += _green_time_modules() already_patched['time'] = True for name, mod in modules_to_patch: + orig_mod = sys.modules.get(name) for attr in mod.__patched__: + orig_attr = getattr(orig_mod, attr, None) patched_attr = getattr(mod, attr, None) if patched_attr is not None: - setattr(sys.modules[name], attr, patched_attr) + setattr(orig_mod, attr, patched_attr) def _green_os_modules(): from eventlet.green import os @@ -159,4 +185,4 @@ def _green_thread_modules(): def _green_time_modules(): from eventlet.green import time - return [('time', time)] + return [('time', time)] \ No newline at end of file diff --git a/eventlet/tpool.py b/eventlet/tpool.py index bb694db..b36fc7c 100644 --- a/eventlet/tpool.py +++ b/eventlet/tpool.py @@ -14,7 +14,6 @@ # limitations under the License. import os -import threading import sys from Queue import Empty, Queue @@ -22,6 +21,8 @@ from Queue import Empty, Queue from eventlet import event from eventlet import greenio from eventlet import greenthread +from eventlet import patcher +threading = patcher.original('threading') __all__ = ['execute', 'Proxy', 'killall'] diff --git a/eventlet/util.py b/eventlet/util.py index a3d76d8..a5ff43d 100644 --- a/eventlet/util.py +++ b/eventlet/util.py @@ -1,5 +1,4 @@ import os -import select import socket import errno import warnings diff --git a/tests/patcher_test.py b/tests/patcher_test.py index d1695a6..a874bfc 100644 --- a/tests/patcher_test.py +++ b/tests/patcher_test.py @@ -42,18 +42,21 @@ class Patcher(LimitedTestCase): fd = open(filename, "w") fd.write(contents) fd.close() + + def launch_subprocess(self, filename): + python_path = os.pathsep.join(sys.path + [self.tempdir]) + new_env = os.environ.copy() + new_env['PYTHONPATH'] = python_path + p = subprocess.Popen([sys.executable, + os.path.join(self.tempdir, filename)], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=new_env) + return p def test_patch_a_module(self): self.write_to_tempfile("base", base_module_contents) self.write_to_tempfile("patching", patching_module_contents) self.write_to_tempfile("importing", import_module_contents) - - python_path = os.pathsep.join(sys.path + [self.tempdir]) - new_env = os.environ.copy() - new_env['PYTHONPATH'] = python_path - p = subprocess.Popen([sys.executable, - os.path.join(self.tempdir, "importing.py")], - stdout=subprocess.PIPE, env=new_env) + p = self.launch_subprocess('importing.py') output = p.communicate() lines = output[0].split("\n") self.assert_(lines[0].startswith('patcher')) @@ -73,12 +76,7 @@ base = patcher.import_patched('base') print "newmod", base, base.socket, base.urllib.socket.socket """ self.write_to_tempfile("newmod", new_mod) - python_path = os.pathsep.join(sys.path + [self.tempdir]) - new_env = os.environ.copy() - new_env['PYTHONPATH'] = python_path - p = subprocess.Popen([sys.executable, - os.path.join(self.tempdir, "newmod.py")], - stdout=subprocess.PIPE, env=new_env) + p = self.launch_subprocess('newmod.py') output = p.communicate() lines = output[0].split("\n") self.assert_(lines[0].startswith('base')) @@ -95,13 +93,62 @@ import urllib print "newmod", socket.socket, urllib.socket.socket """ self.write_to_tempfile("newmod", new_mod) - python_path = os.pathsep.join(sys.path + [self.tempdir]) - new_env = os.environ.copy() - new_env['PYTHONPATH'] = python_path - p = subprocess.Popen([sys.executable, - os.path.join(self.tempdir, "newmod.py")], - stdout=subprocess.PIPE, env=new_env) + p = self.launch_subprocess('newmod.py') output = p.communicate() + print output[0] lines = output[0].split("\n") self.assert_(lines[0].startswith('newmod')) - self.assertEqual(lines[0].count('GreenSocket'), 2) \ No newline at end of file + self.assertEqual(lines[0].count('GreenSocket'), 2) + + def test_early_patching(self): + new_mod = """ +from eventlet import patcher +patcher.monkey_patch() +import eventlet +eventlet.sleep(0.01) +print "newmod" +""" + self.write_to_tempfile("newmod", new_mod) + p = self.launch_subprocess('newmod.py') + output = p.communicate() + print output[0] + lines = output[0].split("\n") + self.assertEqual(len(lines), 2) + self.assert_(lines[0].startswith('newmod')) + + def test_late_patching(self): + new_mod = """ +import eventlet +eventlet.sleep(0.01) +from eventlet import patcher +patcher.monkey_patch() +eventlet.sleep(0.01) +print "newmod" +""" + self.write_to_tempfile("newmod", new_mod) + p = self.launch_subprocess('newmod.py') + output = p.communicate() + print output[0] + lines = output[0].split("\n") + self.assertEqual(len(lines), 2) + self.assert_(lines[0].startswith('newmod')) + + def test_tpool(self): + new_mod = """ +import eventlet +from eventlet import patcher +patcher.monkey_patch() +from eventlet import tpool +print "newmod", tpool.execute(len, "hi") +print "newmod", tpool.execute(len, "hi2") +""" + self.write_to_tempfile("newmod", new_mod) + p = self.launch_subprocess('newmod.py') + output = p.communicate() + print output[0] + lines = output[0].split("\n") + self.assertEqual(len(lines), 3) + self.assert_(lines[0].startswith('newmod')) + self.assert_('2' in lines[0]) + self.assert_('3' in lines[1]) +